mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-29 02:46:24 +02:00
Merge pull request #991 from garylin2099/code_interpreter
Solve conflicts
This commit is contained in:
commit
25fbab1cd0
146 changed files with 4466 additions and 1375 deletions
1 .github/workflows/fulltest.yaml (vendored)

@@ -54,7 +54,6 @@ jobs:
         export ALLOW_OPENAI_API_CALL=0
-        echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
         mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
         echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml
         pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
     - name: Show coverage report
       run: |
2 .github/workflows/unittest.yaml (vendored)

@@ -31,7 +31,7 @@ jobs:
     - name: Test with pytest
       run: |
         export ALLOW_OPENAI_API_CALL=0
-        mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml
+        mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
         pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
     - name: Show coverage report
       run: |
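Both workflows above materialize config files by base64-decoding repository secrets (`base64 -d > ~/.metagpt/config2.yaml`). The round-trip they rely on can be sketched in Python; the payload here is illustrative, not a real secret:

```python
import base64

# A config payload as it would exist before being stored in a CI secret.
payload = 'llm:\n  api_key: "YOUR_API_KEY"\n'

# What gets stored in the repository secret (equivalent of base64 encoding).
secret = base64.b64encode(payload.encode()).decode()

# What the workflow step recovers with `base64 -d`.
restored = base64.b64decode(secret).decode()

assert restored == payload
```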
43 README.md

@@ -26,7 +26,9 @@ # MetaGPT: The Multi-Agent Framework
 </p>

 ## News
-🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/mi/README.md), a powerful agent capable of solving a wide range of real-world problems.
+🚀 Mar. 01, 2024: Our Data Interpreter paper is on arxiv. Find all design and benchmark details [here](https://arxiv.org/abs/2402.18679)!
+
+🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Data Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/di/README.md), a powerful agent capable of solving a wide range of real-world problems.

 🚀 Jan. 16, 2024: Our paper [MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework](https://arxiv.org/abs/2308.00352) accepted for oral presentation **(top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category.

@@ -97,6 +99,45 @@ ### Usage
 For detailed installation, please refer to [cli_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-stable-version)
 or [docker_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-with-docker)

+### Docker installation
+<details><summary><strong>⏬ Step 1: Download metagpt image and prepare config2.yaml</strong> <i>:: click to expand ::</i></summary>
+<div>
+
+```bash
+docker pull metagpt/metagpt:latest
+mkdir -p /opt/metagpt/{config,workspace}
+docker run --rm metagpt/metagpt:latest cat /app/metagpt/config/config2.yaml > /opt/metagpt/config/config2.yaml
+vim /opt/metagpt/config/config2.yaml  # Change the config
+```
+
+</div>
+</details>
+
+<details><summary><strong>⏬ Step 2: Run metagpt container</strong> <i>:: click to expand ::</i></summary>
+<div>
+
+```bash
+docker run --name metagpt -d \
+    --privileged \
+    -v /opt/metagpt/config/config2.yaml:/app/metagpt/config/config2.yaml \
+    -v /opt/metagpt/workspace:/app/metagpt/workspace \
+    metagpt/metagpt:latest
+```
+
+</div>
+</details>
+
+<details><summary><strong>⏬ Step 3: Use metagpt</strong> <i>:: click to expand ::</i></summary>
+<div>
+
+```bash
+docker exec -it metagpt /bin/bash
+$ metagpt "Create a 2048 game"  # this will create a repo in ./workspace
+```
+
+</div>
+</details>
+
 ### QuickStart & Demo Video
 - Try it on [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT)
 - [Matthew Berman: How To Install MetaGPT - Build A Startup With One Prompt!!](https://youtu.be/uT75J_KG_aY)
14 SECURITY.md (new file)

@@ -0,0 +1,14 @@
+# Security Policy
+
+## Supported Versions
+
+| Version | Supported |
+|---------|-----------|
+| 7.x     | :x:       |
+| 6.x     | :x:       |
+| < 6.x   | :x:       |
+
+
+## Reporting a Vulnerability
+
+If you have any vulnerability reports, please contact alexanderwu@deepwisdom.ai.

@@ -3,8 +3,16 @@ llm:
   base_url: "YOUR_BASE_URL"
   api_key: "YOUR_API_KEY"
   model: "gpt-4-turbo-preview"  # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
-  repair_llm_output: true  # when the output is not a valid json, try to repair it
-  proxy: "YOUR_PROXY"  # for LLM API requests
+  pricing_plan: ""  # Optional. If invalid, it will be automatically filled in with the value of the `model`.
+  # Azure-exclusive pricing plan mappings:
+  # - gpt-3.5-turbo 4k: "gpt-3.5-turbo-1106"
+  # - gpt-4-turbo: "gpt-4-turbo-preview"
+  # - gpt-4-turbo-vision: "gpt-4-vision-preview"
+  # - gpt-4 8k: "gpt-4"
+  # See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
+
+  repair_llm_output: true  # when the output is not a valid json, try to repair it
+
+  proxy: "YOUR_PROXY"  # for tools like requests, playwright, selenium, etc.
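The `pricing_plan` comment above describes a fallback: when the configured plan is empty or unrecognized, the value of `model` is used instead. A small sketch of that fallback logic (a hypothetical helper for illustration, not MetaGPT's actual implementation):

```python
# Plans the accounting layer knows how to price (illustrative subset).
KNOWN_PRICING_PLANS = {"gpt-4-turbo-preview", "gpt-3.5-turbo-1106", "gpt-4"}


def resolve_pricing_plan(pricing_plan: str, model: str) -> str:
    """Fall back to the model name when pricing_plan is empty or unrecognized."""
    if pricing_plan in KNOWN_PRICING_PLANS:
        return pricing_plan
    return model
```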

@@ -5,6 +5,7 @@ Author: garylin2099
 @Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `send_to`
     value of the `Message` object; modify the argument type of `get_by_actions`.
 """

 import asyncio
 import platform
+from typing import Any

@@ -105,4 +106,4 @@ def main(idea: str, investment: float = 3.0, n_round: int = 10):


 if __name__ == "__main__":
-    fire.Fire(main)
+    fire.Fire(main)  # run as python debate.py --idea="TOPIC" --investment=3.0 --n_round=5

@@ -8,14 +8,17 @@
 import asyncio

 from metagpt.actions import Action
+from metagpt.config2 import Config
 from metagpt.environment import Environment
 from metagpt.roles import Role
 from metagpt.team import Team

-action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
-action1.llm.model = "gpt-4-1106-preview"
-action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
-action2.llm.model = "gpt-3.5-turbo-1106"
+gpt35 = Config.default()
+gpt35.llm.model = "gpt-3.5-turbo-1106"
+gpt4 = Config.default()
+gpt4.llm.model = "gpt-4-1106-preview"
+action1 = Action(config=gpt4, name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
+action2 = Action(config=gpt35, name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
 alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
 bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
 env = Environment(desc="US election live broadcast")
18 examples/di/README.md (new file)

@@ -0,0 +1,18 @@
+# Data Interpreter (DI)
+
+## What is Data Interpreter
+Data Interpreter is an agent that solves problems through code. It understands user requirements, makes plans, writes code for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios; please check out the examples below.
+
+## Example List
+- Data visualization
+- Machine learning modeling
+- Image background removal
+- Solve math problems
+- Receipt OCR
+- Tool usage: web page imitation
+- Tool usage: web crawling
+- Tool usage: text2image
+- Tool usage: email summarization and response
+- More on the way!
+
+Please see [here](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html) for detailed explanation.

@@ -5,15 +5,15 @@
 @File : crawl_webpage.py
 """

-from metagpt.roles.mi.interpreter import Interpreter
+from metagpt.roles.di.data_interpreter import DataInterpreter


 async def main():
     prompt = """Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
     and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables*"""
-    mi = Interpreter(use_tools=True)
+    di = DataInterpreter(use_tools=True)

-    await mi.run(prompt)
+    await di.run(prompt)


 if __name__ == "__main__":
@@ -1,11 +1,11 @@
 import asyncio

-from metagpt.roles.mi.interpreter import Interpreter
+from metagpt.roles.di.data_interpreter import DataInterpreter


 async def main(requirement: str = ""):
-    mi = Interpreter(use_tools=False)
-    await mi.run(requirement)
+    di = DataInterpreter(use_tools=False)
+    await di.run(requirement)


 if __name__ == "__main__":
@@ -6,7 +6,7 @@
 """
 import os

-from metagpt.roles.mi.interpreter import Interpreter
+from metagpt.roles.di.data_interpreter import DataInterpreter


 async def main():

@@ -22,9 +22,9 @@ async def main():
     Firstly, Please help me fetch the latest 5 senders and full letter contents.
     Then, summarize each of the 5 emails into one sentence (you can do this by yourself, no need to import other models to do this) and output them in a markdown format."""

-    mi = Interpreter(use_tools=True)
+    di = DataInterpreter(use_tools=True)

-    await mi.run(prompt)
+    await di.run(prompt)


 if __name__ == "__main__":
@@ -5,7 +5,7 @@
 @Author : mannaandpoem
 @File : imitate_webpage.py
 """
-from metagpt.roles.mi.interpreter import Interpreter
+from metagpt.roles.di.data_interpreter import DataInterpreter


 async def main():

@@ -15,9 +15,9 @@ Firstly, utilize Selenium and WebDriver for rendering.
 Secondly, convert image to a webpage including HTML, CSS and JS in one go.
 Finally, save webpage in a text file.
 Note: All required dependencies and environments have been fully installed and configured."""
-    mi = Interpreter(use_tools=True)
+    di = DataInterpreter(use_tools=True)

-    await mi.run(prompt)
+    await di.run(prompt)


 if __name__ == "__main__":
13 examples/di/machine_learning.py (new file)

@@ -0,0 +1,13 @@
+import fire
+
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def main(auto_run: bool = True):
+    requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
+    di = DataInterpreter(auto_run=auto_run)
+    await di.run(requirement)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
@@ -1,6 +1,6 @@
 import asyncio

-from metagpt.roles.mi.ml_engineer import MLEngineer
+from metagpt.roles.di.ml_engineer import MLEngineer


 async def main(requirement: str):
21 examples/di/ocr_receipt.py (new file)

@@ -0,0 +1,21 @@
+from metagpt.roles.di.data_interpreter import DataInterpreter
+
+
+async def main():
+    # Notice: pip install metagpt[ocr] before using this example
+    image_path = "image.jpg"
+    language = "English"
+    requirement = f"""This is a {language} receipt image.
+Your goal is to perform OCR on images using PaddleOCR, output text content from the OCR results and discard
+coordinates and confidence levels, then recognize the total amount from ocr text content, and finally save as table.
+Image path: {image_path}.
+NOTE: The environments for Paddle and PaddleOCR are all ready and has been fully installed."""
+    di = DataInterpreter()
+
+    await di.run(requirement)
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())
@@ -1,11 +1,11 @@
 import asyncio

-from metagpt.roles.mi.interpreter import Interpreter
+from metagpt.roles.di.data_interpreter import DataInterpreter


 async def main(requirement: str = ""):
-    mi = Interpreter(use_tools=False)
-    await mi.run(requirement)
+    di = DataInterpreter(use_tools=False)
+    await di.run(requirement)


 if __name__ == "__main__":
@@ -4,12 +4,12 @@
 # @Desc :
 import asyncio

-from metagpt.roles.mi.interpreter import Interpreter
+from metagpt.roles.di.data_interpreter import DataInterpreter


 async def main(requirement: str = ""):
-    mi = Interpreter(use_tools=True, goal=requirement)
-    await mi.run(requirement)
+    di = DataInterpreter(use_tools=True, goal=requirement)
+    await di.run(requirement)


 if __name__ == "__main__":
@@ -1,11 +1,11 @@
 import asyncio

-from metagpt.roles.mi.interpreter import Interpreter
+from metagpt.roles.di.data_interpreter import DataInterpreter


 async def main(requirement: str = ""):
-    mi = Interpreter(use_tools=False)
-    await mi.run(requirement)
+    di = DataInterpreter(use_tools=False)
+    await di.run(requirement)


 if __name__ == "__main__":
@@ -13,7 +13,18 @@ from metagpt.logs import logger

 async def main():
     llm = LLM()
     logger.info(await llm.aask("hello world"))
     # llm type check
     question = "what's your name"
     logger.info(f"{question}: ")
     logger.info(await llm.aask(question))
     logger.info("\n\n")

     logger.info(
         await llm.aask(
             "who are you", system_msgs=["act as a robot, just answer 'I'am robot' if the question is 'who are you'"]
         )
     )

     logger.info(await llm.aask_batch(["hi", "write python hello world."]))

     hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]
@ -1,18 +0,0 @@
|
|||
# MetaGPT Interpreter (MI)
|
||||
|
||||
## What is Interpreter
|
||||
Interpreter is an agent who solves problems through codes. It understands user requirements, makes plans, writes codes for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios, please check out the examples below.
|
||||
|
||||
## Example List
|
||||
- Data visualization
|
||||
- Machine learning modeling
|
||||
- Image background removal
|
||||
- Solve math problems
|
||||
- Receipt OCR
|
||||
- Tool usage: web page imitation
|
||||
- Tool usage: web crawling
|
||||
- Tool usage: text2image
|
||||
- Tool usage: email summarization and response
|
||||
- More on the way!
|
||||
|
||||
Please see [here](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/mi_intro.html) for detailed explanation.
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
import fire
|
||||
|
||||
from metagpt.roles.mi.interpreter import Interpreter
|
||||
|
||||
WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
|
||||
DATA_DIR = "path/to/your/data"
|
||||
# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data
|
||||
SALES_FORECAST_REQ = f"""Train a model to predict sales for each department in every store (split the last 40 weeks records as validation dataset, the others is train dataset), include plot total sales trends, print metric and plot scatter plots of
|
||||
groud truth and predictions on validation data. Dataset is {DATA_DIR}/train.csv, the metric is weighted mean absolute error (WMAE) for test data. Notice: *print* key variables to get more information for next task step.
|
||||
"""
|
||||
|
||||
REQUIREMENTS = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ}
|
||||
|
||||
|
||||
async def main(auto_run: bool = True, use_case: str = "wine"):
|
||||
mi = Interpreter(auto_run=auto_run)
|
||||
requirement = REQUIREMENTS[use_case]
|
||||
await mi.run(requirement)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
from metagpt.roles.mi.interpreter import Interpreter
|
||||
|
||||
|
||||
async def main():
|
||||
# Notice: pip install metagpt[ocr] before using this example
|
||||
image_path = "image.jpg"
|
||||
language = "English"
|
||||
requirement = f"""This is a {language} receipt image.
|
||||
Your goal is to perform OCR on images using PaddleOCR, then extract the total amount from ocr text results, and finally save as table. Image path: {image_path}.
|
||||
NOTE: The environments for Paddle and PaddleOCR are all ready and has been fully installed."""
|
||||
mi = Interpreter()
|
||||
|
||||
await mi.run(requirement)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
72 examples/reverse_engineering.py (new file)

@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import asyncio
+import shutil
+from pathlib import Path
+
+import typer
+
+from metagpt.actions.rebuild_class_view import RebuildClassView
+from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
+from metagpt.context import Context
+from metagpt.llm import LLM
+from metagpt.logs import logger
+from metagpt.utils.git_repository import GitRepository
+from metagpt.utils.project_repo import ProjectRepo
+
+app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
+
+
+@app.command("", help="Python project reverse engineering.")
+def startup(
+    project_root: str = typer.Argument(
+        default="",
+        help="Specify the root directory of the existing project for reverse engineering.",
+    ),
+    output_dir: str = typer.Option(default="", help="Specify the output directory path for reverse engineering."),
+):
+    package_root = Path(project_root)
+    if not package_root.exists():
+        raise FileNotFoundError(f"{project_root} not exists")
+    if not _is_python_package_root(package_root):
+        raise FileNotFoundError(f'There are no "*.py" files under "{project_root}".')
+    init_file = package_root / "__init__.py"  # used by pyreverse
+    init_file_exists = init_file.exists()
+    if not init_file_exists:
+        init_file.touch()
+
+    if not output_dir:
+        output_dir = package_root / "../reverse_engineering_output"
+    logger.info(f"output dir:{output_dir}")
+    try:
+        asyncio.run(reverse_engineering(package_root, Path(output_dir)))
+    finally:
+        if not init_file_exists:
+            init_file.unlink(missing_ok=True)
+        tmp_dir = package_root / "__dot__"
+        if tmp_dir.exists():
+            shutil.rmtree(tmp_dir, ignore_errors=True)
+
+
+def _is_python_package_root(package_root: Path) -> bool:
+    for file_path in package_root.iterdir():
+        if file_path.is_file():
+            if file_path.suffix == ".py":
+                return True
+    return False
+
+
+async def reverse_engineering(package_root: Path, output_dir: Path):
+    ctx = Context()
+    ctx.git_repo = GitRepository(output_dir)
+    ctx.repo = ProjectRepo(ctx.git_repo)
+    action = RebuildClassView(name="ReverseEngineering", i_context=str(package_root), llm=LLM(), context=ctx)
+    await action.run()
+
+    action = RebuildSequenceView(name="ReverseEngineering", llm=LLM(), context=ctx)
+    await action.run()
+
+
+if __name__ == "__main__":
+    app()
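The `_is_python_package_root` helper added above iterates the directory by hand; the same check can be written more compactly with `Path.glob`. A standalone sketch of that equivalent (illustrative, not part of the commit):

```python
import tempfile
from pathlib import Path


def is_python_package_root(package_root: Path) -> bool:
    # True if the directory directly contains at least one regular .py file.
    return any(p.is_file() for p in package_root.glob("*.py"))


with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    before = is_python_package_root(root)  # empty dir: no .py files yet
    (root / "mod.py").write_text("x = 1")
    after = is_python_package_root(root)   # now contains mod.py
```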
@@ -4,21 +4,17 @@
 """
 import asyncio

+from metagpt.config2 import Config
 from metagpt.roles import Searcher
-from metagpt.tools.search_engine import SearchEngine, SearchEngineType
+from metagpt.tools.search_engine import SearchEngine


 async def main():
     question = "What are the most interesting human facts?"
-    kwargs = {"api_key": "", "cse_id": "", "proxy": None}
-    # Serper API
-    # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPER_GOOGLE, **kwargs)).run(question)
-    # SerpAPI
-    # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPAPI_GOOGLE, **kwargs)).run(question)
-    # Google API
-    # await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DIRECT_GOOGLE, **kwargs)).run(question)
-    # DDG API
-    await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO, **kwargs)).run(question)
+    search = Config.default().search
+    kwargs = {"api_key": search.api_key, "cse_id": search.cse_id, "proxy": None}
+    await Searcher(search_engine=SearchEngine(engine=search.api_type, **kwargs)).run(question)


 if __name__ == "__main__":
@@ -14,6 +14,22 @@ from metagpt.actions.action_node import ActionNode
 from metagpt.llm import LLM


+class Chapter(BaseModel):
+    name: str = Field(default="Chapter 1", description="The name of the chapter.")
+    content: str = Field(default="...", description="The content of the chapter. No more than 1000 words.")
+
+
+class Chapters(BaseModel):
+    chapters: List[Chapter] = Field(
+        default=[
+            {"name": "Chapter 1", "content": "..."},
+            {"name": "Chapter 2", "content": "..."},
+            {"name": "Chapter 3", "content": "..."},
+        ],
+        description="The chapters of the novel.",
+    )
+
+
 class Novel(BaseModel):
     name: str = Field(default="The Lord of the Rings", description="The name of the novel.")
     user_group: str = Field(default="...", description="The user group of the novel.")

@@ -28,22 +44,17 @@ class Novel(BaseModel):
     ending: str = Field(default="...", description="The ending of the novel.")


-class Chapter(BaseModel):
-    name: str = Field(default="Chapter 1", description="The name of the chapter.")
-    content: str = Field(default="...", description="The content of the chapter. No more than 1000 words.")
-
-
 async def generate_novel():
     instruction = (
-        "Write a novel named 'Harry Potter in The Lord of the Rings'. "
+        "Write a novel named 'Reborn in Skyrim'. "
         "Fill the empty nodes with your own ideas. Be creative! Use your own words!"
         "I will tip you $100,000 if you write a good novel."
     )
     novel_node = await ActionNode.from_pydantic(Novel).fill(context=instruction, llm=LLM())
-    chap_node = await ActionNode.from_pydantic(Chapter).fill(
+    chap_node = await ActionNode.from_pydantic(Chapters).fill(
         context=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM()
     )
-    print(chap_node.content)
+    print(chap_node.instruct_content)


 asyncio.run(generate_novel())
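The `Chapters` wrapper moved above `Novel` exists so `ActionNode.from_pydantic` can fill a whole list of chapters in one call. A standalone sketch of the same model shape (using `default_factory` with instances rather than raw dicts, so the defaults are valid without coercion; assumes pydantic is installed):

```python
from typing import List

from pydantic import BaseModel, Field


class Chapter(BaseModel):
    name: str = Field(default="Chapter 1", description="The name of the chapter.")
    content: str = Field(default="...", description="The content of the chapter.")


class Chapters(BaseModel):
    # Instances (not dicts) as defaults, so no validation of defaults is required.
    chapters: List[Chapter] = Field(
        default_factory=lambda: [Chapter(), Chapter(name="Chapter 2"), Chapter(name="Chapter 3")],
        description="The chapters of the novel.",
    )


doc = Chapters()
```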
@@ -22,9 +22,9 @@ from metagpt.actions.write_code_review import WriteCodeReview
 from metagpt.actions.write_prd import WritePRD
 from metagpt.actions.write_prd_review import WritePRDReview
 from metagpt.actions.write_test import WriteTest
-from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
-from metagpt.actions.mi.write_analysis_code import WriteCodeWithTools
-from metagpt.actions.mi.write_plan import WritePlan
+from metagpt.actions.di.execute_nb_code import ExecuteNbCode
+from metagpt.actions.di.write_analysis_code import WriteCodeWithTools
+from metagpt.actions.di.write_plan import WritePlan


 class ActionType(Enum):
@@ -8,7 +8,6 @@
 from typing import List

 from metagpt.actions.action_node import ActionNode
-from metagpt.logs import logger
 from metagpt.utils.mermaid import MMC1, MMC2

 IMPLEMENTATION_APPROACH = ActionNode(

@@ -109,14 +108,3 @@ REFINED_NODES = [

 DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
 REFINED_DESIGN_NODE = ActionNode.from_children("RefinedDesignAPI", REFINED_NODES)
-
-
-def main():
-    prompt = DESIGN_API_NODE.compile(context="")
-    logger.info(prompt)
-    prompt = REFINED_DESIGN_NODE.compile(context="")
-    logger.info(prompt)
-
-
-if __name__ == "__main__":
-    main()
@@ -9,7 +9,7 @@ from __future__ import annotations
 import json

 from metagpt.actions import Action
-from metagpt.prompts.mi.write_analysis_code import (
+from metagpt.prompts.di.write_analysis_code import (
     CHECK_DATA_PROMPT,
     DEBUG_REFLECTION_EXAMPLE,
     INTERPRETER_SYSTEM_MSG,
@ -8,7 +8,6 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
|
||||
REQUIRED_PYTHON_PACKAGES = ActionNode(
|
||||
key="Required Python packages",
|
||||
|
|
@ -119,14 +118,3 @@ REFINED_NODES = [
|
|||
|
||||
PM_NODE = ActionNode.from_children("PM_NODE", NODES)
|
||||
REFINED_PM_NODE = ActionNode.from_children("REFINED_PM_NODE", REFINED_NODES)
|
||||
|
||||
|
||||
def main():
|
||||
prompt = PM_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
prompt = REFINED_PM_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@@ -4,10 +4,12 @@
 @Time : 2023/12/19
 @Author : mashenquan
 @File : rebuild_class_view.py
-@Desc : Rebuild class view info
+@Desc : Reconstructs class diagram from a source code project.
+    Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
 """
 import re

 from pathlib import Path
 from typing import Optional, Set, Tuple

 import aiofiles
@@ -21,86 +23,144 @@ from metagpt.const import (
     GRAPH_REPO_FILE_REPO,
 )
 from metagpt.logs import logger
-from metagpt.repo_parser import RepoParser
-from metagpt.schema import ClassAttribute, ClassMethod, ClassView
-from metagpt.utils.common import split_namespace
+from metagpt.repo_parser import DotClassInfo, RepoParser
+from metagpt.schema import UMLClassView
+from metagpt.utils.common import concat_namespace, split_namespace
 from metagpt.utils.di_graph_repository import DiGraphRepository
 from metagpt.utils.graph_repository import GraphKeyword, GraphRepository


 class RebuildClassView(Action):
+    """
+    Reconstructs a graph repository about class diagram from a source code project.
+
+    Attributes:
+        graph_db (Optional[GraphRepository]): The optional graph repository.
+    """
+
+    graph_db: Optional[GraphRepository] = None
+
     async def run(self, with_messages=None, format=config.prompt_schema):
+        """
+        Implementation of `Action`'s `run` method.
+
+        Args:
+            with_messages (Optional[Type]): An optional argument specifying messages to react to.
+            format (str): The format for the prompt schema.
+        """
         graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
-        graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
+        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
         repo_parser = RepoParser(base_directory=Path(self.i_context))
+        # use pylint
         class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
-        await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
-        await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
+        await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
+        await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
+        await GraphRepository.rebuild_composition_relationship(self.graph_db)
+        # use ast
         direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
         symbols = repo_parser.generate_symbols()
         for file_info in symbols:
+            # Align to the same root directory in accordance with `class_views`.
             file_info.file = self._align_root(file_info.file, direction, diff_path)
-            await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
-        await self._create_mermaid_class_views(graph_db=graph_db)
-        await graph_db.save()
+            await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info)
+        await self._create_mermaid_class_views()
+        await self.graph_db.save()

-    async def _create_mermaid_class_views(self, graph_db):
-        path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
+    async def _create_mermaid_class_views(self) -> str:
+        """Creates a Mermaid class diagram using data from the `graph_db` graph repository.
+
+        This method utilizes information stored in the graph repository to generate a Mermaid class diagram.
+        Returns:
+            mermaid class diagram file name.
+        """
+        path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
         path.mkdir(parents=True, exist_ok=True)
         pathname = path / self.context.git_repo.workdir.name
-        async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
+        filename = str(pathname.with_suffix(".class_diagram.mmd"))
+        async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
             content = "classDiagram\n"
             logger.debug(content)
             await writer.write(content)
             # class names
-            rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
+            rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
             class_distinct = set()
             relationship_distinct = set()
             for r in rows:
-                await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
+                content = await self._create_mermaid_class(r.subject)
+                if content:
+                    await writer.write(content)
+                    class_distinct.add(r.subject)
             for r in rows:
-                await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
+                content, distinct = await self._create_mermaid_relationship(r.subject)
+                if content:
+                    logger.debug(content)
+                    await writer.write(content)
+                    relationship_distinct.update(distinct)
         logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")

-    @staticmethod
-    async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
+        if self.i_context:
+            r_filename = Path(filename).relative_to(self.context.git_repo.workdir)
+            await self.graph_db.insert(
+                subject=self.i_context, predicate="hasMermaidClassDiagramFile", object_=str(r_filename)
+            )
+            logger.info(f"{self.i_context} hasMermaidClassDiagramFile (unknown)")
+        return filename
+
+    async def _create_mermaid_class(self, ns_class_name) -> str:
+        """Generates a Mermaid class diagram for a specific class using data from the `graph_db` graph repository.
+
+        Args:
+            ns_class_name (str): The namespace-prefixed name of the class for which the Mermaid class diagram is to be created.
+
+        Returns:
+            str: A Mermaid code block object in markdown representing the class diagram.
+        """
         fields = split_namespace(ns_class_name)
         if len(fields) > 2:
             # Ignore sub-class
-            return
+            return ""

-        class_view = ClassView(name=fields[1])
-        rows = await graph_db.select(subject=ns_class_name)
-        for r in rows:
-            name = split_namespace(r.object_)[-1]
-            name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
-            if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
-                var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
-                attribute = ClassAttribute(
-                    name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
|
||||
)
|
||||
class_view.attributes.append(attribute)
|
||||
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
|
||||
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
|
||||
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
|
||||
class_view.methods.append(method)
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
|
||||
if not rows:
|
||||
return ""
|
||||
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
|
||||
class_view = UMLClassView.load_dot_class_info(dot_class_info)
|
||||
|
||||
# update graph db
|
||||
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
# update uml view
|
||||
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
# update uml isCompositeOf
|
||||
for c in dot_class_info.compositions:
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name,
|
||||
predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF,
|
||||
object_=concat_namespace("?", c),
|
||||
)
|
||||
|
||||
# update uml isAggregateOf
|
||||
for a in dot_class_info.aggregations:
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name,
|
||||
predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF,
|
||||
object_=concat_namespace("?", a),
|
||||
)
|
||||
|
||||
content = class_view.get_mermaid(align=1)
|
||||
logger.debug(content)
|
||||
await file_writer.write(content)
|
||||
distinct.add(ns_class_name)
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
|
||||
async def _create_mermaid_relationship(self, ns_class_name: str) -> Tuple[Optional[str], Optional[Set]]:
|
||||
"""Generates a Mermaid class relationship diagram for a specific class using data from the `graph_db` graph repository.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which the Mermaid relationship diagram is to be created.
|
||||
|
||||
Returns:
|
||||
Tuple[str, Set]: A tuple containing the relationship diagram as a string and a set of deduplication.
|
||||
"""
|
||||
s_fields = split_namespace(ns_class_name)
|
||||
if len(s_fields) > 2:
|
||||
# Ignore sub-class
|
||||
return
|
||||
return None, None
|
||||
|
||||
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
|
||||
mappings = {
|
||||
|
|
@@ -109,8 +169,9 @@ class RebuildClassView(Action):
            AGGREGATION: " o-- ",
        }
        content = ""
        distinct = set()
        for p, v in predicates.items():
            rows = await graph_db.select(subject=ns_class_name, predicate=p)
            rows = await self.graph_db.select(subject=ns_class_name, predicate=p)
            for r in rows:
                o_fields = split_namespace(r.object_)
                if len(o_fields) > 2:
@@ -121,86 +182,26 @@ class RebuildClassView(Action):
                distinct.add(link)
                content += f"\t{link}\n"

        if content:
            logger.debug(content)
            await file_writer.write(content)

    @staticmethod
    def _parse_name(name: str, language="python"):
        pattern = re.compile(r"<I>(.*?)<\/I>")
        result = re.search(pattern, name)

        abstraction = ""
        if result:
            name = result.group(1)
            abstraction = "*"
        if name.startswith("__"):
            visibility = "-"
        elif name.startswith("_"):
            visibility = "#"
        else:
            visibility = "+"
        return name, visibility, abstraction

    @staticmethod
    async def _parse_variable_type(ns_name, graph_db) -> str:
        rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
        if not rows:
            return ""
        vals = rows[0].object_.replace("'", "").split(":")
        if len(vals) == 1:
            return ""
        val = vals[-1].strip()
        return "" if val == "NoneType" else val + " "

    @staticmethod
    async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
        rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
        if not rows:
            return
        info = rows[0].object_.replace("'", "")

        fs_tag = "("
        ix = info.find(fs_tag)
        fe_tag = "):"
        eix = info.rfind(fe_tag)
        if eix < 0:
            fe_tag = ")"
            eix = info.rfind(fe_tag)
        args_info = info[ix + len(fs_tag) : eix].strip()
        method.return_type = info[eix + len(fe_tag) :].strip()
        if method.return_type == "None":
            method.return_type = ""
        if "(" in method.return_type:
            method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")

        # parse args
        if not args_info:
            return
        splitter_ixs = []
        cost = 0
        for i in range(len(args_info)):
            if args_info[i] == "[":
                cost += 1
            elif args_info[i] == "]":
                cost -= 1
            if args_info[i] == "," and cost == 0:
                splitter_ixs.append(i)
        splitter_ixs.append(len(args_info))
        args = []
        ix = 0
        for eix in splitter_ixs:
            args.append(args_info[ix:eix])
            ix = eix + 1
        for arg in args:
            parts = arg.strip().split(":")
            if len(parts) == 1:
                method.args.append(ClassAttribute(name=parts[0].strip()))
                continue
            method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
        return content, distinct

    @staticmethod
    def _diff_path(path_root: Path, package_root: Path) -> (str, str):
        """Returns the difference between the root path and the path information represented in the package name.

        Args:
            path_root (Path): The root path.
            package_root (Path): The package root path.

        Returns:
            Tuple[str, str]: A tuple containing the representation of the difference ("+", "-", "=") and the path detail of the differing part.

        Example:
            >>> _diff_path(path_root=Path("/Users/x/github/MetaGPT"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
            "-", "metagpt"

            >>> _diff_path(path_root=Path("/Users/x/github/MetaGPT/metagpt"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
            "=", "."
        """
        if len(str(path_root)) > len(str(package_root)):
            return "+", str(path_root.relative_to(package_root))
        if len(str(path_root)) < len(str(package_root)):
@@ -208,7 +209,24 @@ class RebuildClassView(Action):
        return "=", "."

    @staticmethod
    def _align_root(path: str, direction: str, diff_path: str):
    def _align_root(path: str, direction: str, diff_path: str) -> str:
        """Aligns the path to the same root represented by `diff_path`.

        Args:
            path (str): The path to be aligned.
            direction (str): The direction of alignment ('+', '-', '=').
            diff_path (str): The path representing the difference.

        Returns:
            str: The aligned path.

        Example:
            >>> _align_root(path="metagpt/software_company.py", direction="+", diff_path="MetaGPT")
            "MetaGPT/metagpt/software_company.py"

            >>> _align_root(path="metagpt/software_company.py", direction="-", diff_path="metagpt")
            "software_company.py"
        """
        if direction == "=":
            return path
        if direction == "+":
@@ -4,34 +4,214 @@
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view.py
@Desc : Rebuild sequence view info
@Desc : Reconstruct sequence view information through reverse engineering.
    Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
"""
from __future__ import annotations

import re
from datetime import datetime
from pathlib import Path
from typing import List
from typing import List, Optional, Set

from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential

from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
from metagpt.schema import UMLClassView
from metagpt.utils.common import (
    add_affix,
    aread,
    auto_namespace,
    concat_namespace,
    general_after_log,
    list_files,
    parse_json_code_block,
    read_file_block,
    split_namespace,
)
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository


class ReverseUseCase(BaseModel):
    """
    Represents a reverse engineered use case.

    Attributes:
        description (str): A description of the reverse use case.
        inputs (List[str]): List of inputs for the reverse use case.
        outputs (List[str]): List of outputs for the reverse use case.
        actors (List[str]): List of actors involved in the reverse use case.
        steps (List[str]): List of steps for the reverse use case.
        reason (str): The reason behind the reverse use case.
    """

    description: str
    inputs: List[str]
    outputs: List[str]
    actors: List[str]
    steps: List[str]
    reason: str


class ReverseUseCaseDetails(BaseModel):
    """
    Represents details of a reverse engineered use case.

    Attributes:
        description (str): A description of the reverse use case details.
        use_cases (List[ReverseUseCase]): List of reverse use cases.
        relationship (List[str]): List of relationships associated with the reverse use case details.
    """

    description: str
    use_cases: List[ReverseUseCase]
    relationship: List[str]


class RebuildSequenceView(Action):
    async def run(self, with_messages=None, format=config.prompt_schema):
        graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
        graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
        entries = await RebuildSequenceView._search_main_entry(graph_db)
        for entry in entries:
            await self._rebuild_sequence_view(entry, graph_db)
        await graph_db.save()
    """
    Represents an action to reconstruct sequence view through reverse engineering.

    @staticmethod
    async def _search_main_entry(graph_db) -> List:
        rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
    Attributes:
        graph_db (Optional[GraphRepository]): An optional instance of GraphRepository for graph database operations.
    """

    graph_db: Optional[GraphRepository] = None

    async def run(self, with_messages=None, format=config.prompt_schema):
        """
        Implementation of `Action`'s `run` method.

        Args:
            with_messages (Optional[Type]): An optional argument specifying messages to react to.
            format (str): The format for the prompt schema.
        """
        graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
        self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
        if not self.i_context:
            entries = await self._search_main_entry()
        else:
            entries = [SPO(subject=self.i_context, predicate="", object_="")]
        for entry in entries:
            await self._rebuild_main_sequence_view(entry)
            while await self._merge_sequence_view(entry):
                pass
        await self.graph_db.save()

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
        after=general_after_log(logger),
    )
    async def _rebuild_main_sequence_view(self, entry: SPO):
        """
        Reconstruct the sequence diagram for the __main__ entry of the source code through reverse engineering.

        Args:
            entry (SPO): The SPO (Subject, Predicate, Object) object in the graph database that is related to the
                subject `__name__:__main__`.
        """
        filename = entry.subject.split(":", 1)[0]
        rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        classes = []
        prefix = filename + ":"
        for r in rows:
            if prefix in r.subject:
                classes.append(r)
                await self._rebuild_use_case(r.subject)
        participants = await self._search_participants(split_namespace(entry.subject)[0])
        class_details = []
        class_views = []
        for c in classes:
            detail = await self._get_class_detail(c.subject)
            if not detail:
                continue
            class_details.append(detail)
            view = await self._get_uml_class_view(c.subject)
            if view:
                class_views.append(view)

            actors = await self._get_participants(c.subject)
            participants.update(set(actors))

        use_case_blocks = []
        for c in classes:
            use_cases = await self._get_class_use_cases(c.subject)
            use_case_blocks.append(use_cases)
        prompt_blocks = ["## Use Cases\n" + "\n".join(use_case_blocks)]
        block = "## Participants\n"
        for p in participants:
            block += f"- {p}\n"
        prompt_blocks.append(block)
        block = "## Mermaid Class Views\n```mermaid\n"
        block += "\n\n".join([c.get_mermaid() for c in class_views])
        block += "\n```\n"
        prompt_blocks.append(block)
        block = "## Source Code\n```python\n"
        block += await self._get_source_code(filename)
        block += "\n```\n"
        prompt_blocks.append(block)
        prompt = "\n---\n".join(prompt_blocks)

        rsp = await self.llm.aask(
            msg=prompt,
            system_msgs=[
                "You are a python code to Mermaid Sequence Diagram translator in function detail.",
                "Translate the given markdown text to a Mermaid Sequence Diagram.",
                "Return the merged Mermaid sequence diagram in a markdown code block format.",
            ],
            stream=False,
        )
        sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
        rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        for r in rows:
            if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
                await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
        await self.graph_db.insert(
            subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
        )
        await self.graph_db.insert(
            subject=entry.subject,
            predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
            object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
        )
        for c in classes:
            await self.graph_db.insert(
                subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
            )
        await self._save_sequence_view(subject=entry.subject, content=sequence_view)

    async def _merge_sequence_view(self, entry: SPO) -> bool:
        """
        Augments additional information to the provided SPO (Subject, Predicate, Object) entry in the sequence diagram.

        Args:
            entry (SPO): The SPO object representing the relationship in the graph database.

        Returns:
            bool: True if additional information has been augmented, otherwise False.
        """
        new_participant = await self._search_new_participant(entry)
        if not new_participant:
            return False

        await self._merge_participant(entry, new_participant)
        return True

    async def _search_main_entry(self) -> List:
        """
        Asynchronously searches for the SPO object that is related to `__name__:__main__`.

        Returns:
            List: A list containing information about the main entry in the sequence diagram.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
        tag = "__name__:__main__"
        entries = []
        for r in rows:
@@ -39,24 +219,395 @@ class RebuildSequenceView(Action):
entries.append(r)
|
||||
return entries
|
||||
|
||||
async def _rebuild_sequence_view(self, entry, graph_db):
|
||||
filename = entry.subject.split(":", 1)[0]
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _rebuild_use_case(self, ns_class_name: str):
|
||||
"""
|
||||
Asynchronously reconstructs the use case for the provided namespace-prefixed class name.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which the use case is to be reconstructed.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
|
||||
if rows:
|
||||
return
|
||||
content = await aread(filename=src_filename, encoding="utf-8")
|
||||
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
|
||||
data = await self.llm.aask(
|
||||
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
|
||||
|
||||
detail = await self._get_class_detail(ns_class_name)
|
||||
if not detail:
|
||||
return
|
||||
participants = set()
|
||||
participants.update(set(detail.compositions))
|
||||
participants.update(set(detail.aggregations))
|
||||
class_view = await self._get_uml_class_view(ns_class_name)
|
||||
source_code = await self._get_source_code(ns_class_name)
|
||||
|
||||
# prompt_blocks = [
|
||||
# "## Instruction\n"
|
||||
# "You are a python code to UML 2.0 Use Case translator.\n"
|
||||
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
|
||||
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
|
||||
# 'conflict with the information in "Mermaid Class Views".\n'
|
||||
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
|
||||
# "system interactions with the internal system.\n"
|
||||
# ]
|
||||
prompt_blocks = []
|
||||
block = "## Participants\n"
|
||||
for p in participants:
|
||||
block += f"- {p}\n"
|
||||
prompt_blocks.append(block)
|
||||
block = "## Mermaid Class Views\n```mermaid\n"
|
||||
block += class_view.get_mermaid()
|
||||
block += "\n```\n"
|
||||
prompt_blocks.append(block)
|
||||
block = "## Source Code\n```python\n"
|
||||
block += source_code
|
||||
block += "\n```\n"
|
||||
prompt_blocks.append(block)
|
||||
prompt = "\n---\n".join(prompt_blocks)
|
||||
|
||||
rsp = await self.llm.aask(
|
||||
msg=prompt,
|
||||
system_msgs=[
|
||||
"You are a python code to UML 2.0 Use Case translator.",
|
||||
'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".',
|
||||
"The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
|
||||
'conflict with the information in "Mermaid Class Views".',
|
||||
'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
|
||||
"system interactions with the internal system.",
|
||||
"Return a markdown JSON object with:\n"
|
||||
'- a "description" key to explain what the whole source code want to do;\n'
|
||||
'- a "use_cases" key list all use cases, each use case in the list should including a `description` '
|
||||
"key describes about what the use case to do, a `inputs` key lists the input names of the use case "
|
||||
"from external sources, a `outputs` key lists the output names of the use case to external sources, "
|
||||
"a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how "
|
||||
"the use case works step by step, a `reason` key explaining under what circumstances would the "
|
||||
"external system execute this use case.\n"
|
||||
'- a "relationship" key lists all the descriptions of relationship among these use cases.\n',
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
code_blocks = parse_json_code_block(rsp)
|
||||
for block in code_blocks:
|
||||
detail = ReverseUseCaseDetails.model_validate_json(block)
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
|
||||
)
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _rebuild_sequence_view(self, ns_class_name: str):
|
||||
"""
|
||||
Asynchronously reconstructs the sequence diagram for the provided namespace-prefixed class name.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which the sequence diagram is to be reconstructed.
|
||||
"""
|
||||
await self._rebuild_use_case(ns_class_name)
|
||||
|
||||
prompts_blocks = []
|
||||
use_case_markdown = await self._get_class_use_cases(ns_class_name)
|
||||
if not use_case_markdown: # external class
|
||||
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
|
||||
return
|
||||
block = f"## Use Cases\n{use_case_markdown}"
|
||||
prompts_blocks.append(block)
|
||||
|
||||
participants = await self._get_participants(ns_class_name)
|
||||
block = "## Participants\n" + "\n".join([f"- {s}" for s in participants])
|
||||
prompts_blocks.append(block)
|
||||
|
||||
view = await self._get_uml_class_view(ns_class_name)
|
||||
block = "## Mermaid Class Views\n```mermaid\n"
|
||||
block += view.get_mermaid()
|
||||
block += "\n```\n"
|
||||
prompts_blocks.append(block)
|
||||
|
||||
block = "## Source Code\n```python\n"
|
||||
block += await self._get_source_code(ns_class_name)
|
||||
block += "\n```\n"
|
||||
prompts_blocks.append(block)
|
||||
prompt = "\n---\n".join(prompts_blocks)
|
||||
|
||||
rsp = await self.llm.aask(
|
||||
prompt,
|
||||
system_msgs=[
|
||||
"You are a Mermaid Sequence Diagram translator in function detail.",
|
||||
"Translate the markdown text to a Mermaid Sequence Diagram.",
|
||||
"Return a markdown mermaid code block.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
|
||||
async def _get_participants(self, ns_class_name: str) -> List[str]:
|
||||
"""
|
||||
Asynchronously returns the participants list of the sequence diagram for the provided namespace-prefixed SPO
|
||||
object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve the participants list.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of participants in the sequence diagram.
|
||||
"""
|
||||
participants = set()
|
||||
detail = await self._get_class_detail(ns_class_name)
|
||||
if not detail:
|
||||
return []
|
||||
participants.update(set(detail.compositions))
|
||||
participants.update(set(detail.aggregations))
|
||||
return list(participants)
|
||||
|
||||
async def _get_class_use_cases(self, ns_class_name: str) -> str:
|
||||
"""
|
||||
Asynchronously assembles the context about the use case information of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve use case information.
|
||||
|
||||
Returns:
|
||||
str: A string containing the assembled context about the use case information.
|
||||
"""
|
||||
block = ""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
|
||||
for i, r in enumerate(rows):
|
||||
detail = ReverseUseCaseDetails.model_validate_json(r.object_)
|
||||
block += f"\n### {i + 1}. {detail.description}"
|
||||
for j, use_case in enumerate(detail.use_cases):
|
||||
block += f"\n#### {i + 1}.{j + 1}. {use_case.description}\n"
|
||||
block += "\n##### Inputs\n" + "\n".join([f"- {s}" for s in use_case.inputs])
|
||||
block += "\n##### Outputs\n" + "\n".join([f"- {s}" for s in use_case.outputs])
|
||||
block += "\n##### Actors\n" + "\n".join([f"- {s}" for s in use_case.actors])
|
||||
block += "\n##### Steps\n" + "\n".join([f"- {s}" for s in use_case.steps])
|
||||
block += "\n#### Use Case Relationship\n" + "\n".join([f"- {s}" for s in detail.relationship])
|
||||
return block + "\n"
|
||||
|
||||
async def _get_class_detail(self, ns_class_name: str) -> DotClassInfo | None:
|
||||
"""
|
||||
Asynchronously retrieves the dot format class details of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve class details.
|
||||
|
||||
Returns:
|
||||
Union[DotClassInfo, None]: A DotClassInfo object representing the dot format class details,
|
||||
or None if the details are not available.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
|
||||
if not rows:
|
||||
return None
|
||||
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
|
||||
return dot_class_info
|
||||
|
||||
async def _get_uml_class_view(self, ns_class_name: str) -> UMLClassView | None:
|
||||
"""
|
||||
Asynchronously retrieves the UML 2.0 format class details of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve UML class details.
|
||||
|
||||
Returns:
|
||||
Union[UMLClassView, None]: A UMLClassView object representing the UML 2.0 format class details,
|
||||
or None if the details are not available.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
|
||||
if not rows:
|
||||
return None
|
||||
class_view = UMLClassView.model_validate_json(rows[0].object_)
|
||||
return class_view
|
||||
|
||||
async def _get_source_code(self, ns_class_name: str) -> str:
|
||||
"""
|
||||
Asynchronously retrieves the source code of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve the source code.
|
||||
|
||||
Returns:
|
||||
str: A string containing the source code of the specified namespace-prefixed class.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
filename = split_namespace(ns_class_name=ns_class_name)[0]
|
||||
if not rows:
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
return ""
|
||||
return await aread(filename=src_filename, encoding="utf-8")
|
||||
code_block_info = CodeBlockInfo.model_validate_json(rows[0].object_)
|
||||
return await read_file_block(
|
||||
filename=filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno
|
||||
)
|
||||
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
|
||||
logger.info(data)
|
||||
|
||||
@staticmethod
|
||||
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
|
||||
"""
|
||||
Convert package name to the full path of the module.
|
||||
|
||||
Args:
|
||||
root (Union[str, Path]): The root path or string representing the package.
|
||||
pathname (Union[str, Path]): The pathname or string representing the module.
|
||||
|
||||
Returns:
|
||||
Union[Path, None]: The full path of the module, or None if the path cannot be determined.
|
||||
|
||||
Examples:
|
||||
If `root`(workdir) is "/User/xxx/github/MetaGPT/metagpt", and the `pathname` is
|
||||
"metagpt/management/skill_manager.py", then the returned value will be
|
||||
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
|
||||
"""
|
||||
if re.match(r"^/.+", pathname):
|
||||
return pathname
|
||||
files = list_files(root=root)
|
||||
postfix = "/" + str(pathname)
|
||||
for i in files:
|
||||
if str(i).endswith(postfix):
|
||||
return i
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
|
||||
"""
|
||||
Parses the provided Mermaid sequence diagram and returns the list of participants.
|
||||
|
||||
Args:
|
||||
mermaid_sequence_diagram (str): The Mermaid sequence diagram string to be parsed.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of participants extracted from the sequence diagram.
|
||||
"""
|
||||
pattern = r"participant ([a-zA-Z\.0-9_]+)"
|
||||
matches = re.findall(pattern, mermaid_sequence_diagram)
|
||||
matches = [re.sub(r"[\\/'\"]+", "", i) for i in matches]
|
||||
return matches
|
||||
|
||||
async def _search_new_participant(self, entry: SPO) -> str | None:
    """
    Asynchronously retrieves a participant whose sequence diagram has not been augmented.

    Args:
        entry (SPO): The SPO object representing the relationship in the graph database.

    Returns:
        Union[str, None]: A participant whose sequence diagram has not been augmented, or None if not found.
    """
    rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
    if not rows:
        return None
    sequence_view = rows[0].object_
    rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
    merged_participants = []
    for r in rows:
        name = split_namespace(r.object_)[-1]
        merged_participants.append(name)
    participants = self.parse_participant(sequence_view)
    for p in participants:
        if p in merged_participants:
            continue
        return p
    return None

@retry(
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(6),
    after=general_after_log(logger),
)
async def _merge_participant(self, entry: SPO, class_name: str):
    """
    Augments the sequence diagram of `class_name` to the sequence diagram of `entry`.

    Args:
        entry (SPO): The SPO object representing the base sequence diagram.
        class_name (str): The class name whose sequence diagram is to be augmented.
    """
    rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
    participants = []
    for r in rows:
        name = split_namespace(r.subject)[-1]
        if name == class_name:
            participants.append(r)
    if len(participants) == 0:  # external participants
        await self.graph_db.insert(
            subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
        )
        return
    if len(participants) > 1:
        for r in participants:
            await self.graph_db.insert(
                subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
            )
        return

    participant = participants[0]
    await self._rebuild_sequence_view(participant.subject)
    sequence_views = await self.graph_db.select(
        subject=participant.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW
    )
    if not sequence_views:  # external class
        return
    rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
    prompt = f"```mermaid\n{sequence_views[0].object_}\n```\n---\n```mermaid\n{rows[0].object_}\n```"

    rsp = await self.llm.aask(
        prompt,
        system_msgs=[
            "You are a tool to merge sequence diagrams into one.",
            "Participants with the same name are considered identical.",
            "Return the merged Mermaid sequence diagram in a markdown code block format.",
        ],
        stream=False,
    )

    sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
    rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
    for r in rows:
        await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
    await self.graph_db.insert(
        subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
    )
    await self.graph_db.insert(
        subject=entry.subject,
        predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
        object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
    )
    await self.graph_db.insert(
        subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
    )
    await self._save_sequence_view(subject=entry.subject, content=sequence_view)
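`_merge_participant` is wrapped in tenacity's `@retry`, which gives up after six attempts and sleeps a randomized exponential backoff of at most 1 to 20 seconds between tries. A minimal stdlib sketch of that policy (the helper names here are illustrative, not the tenacity API):

```python
import random
import time

def retry_with_backoff(func, attempts=6, wait_min=1.0, wait_max=20.0):
    """Call func(); on an exception, sleep a capped random exponential backoff and retry."""
    for attempt in range(attempts):
        try:
            return func()
        except Exception:
            if attempt == attempts - 1:
                raise  # attempts exhausted: propagate the last error
            # exponential cap 2**attempt, clamped into [wait_min, wait_max]
            cap = min(wait_max, max(wait_min, 2.0 ** attempt))
            time.sleep(random.uniform(0, cap))

calls = {"n": 0}

def flaky():
    # Fails twice, then succeeds -- stands in for a flaky LLM call
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

result = retry_with_backoff(flaky, wait_min=0.0, wait_max=0.01)
print(result)  # ok
```

The randomized wait spreads out retries so that many concurrent callers do not hammer the backend in lockstep.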
async def _save_sequence_view(self, subject: str, content: str):
    pattern = re.compile(r"[^a-zA-Z0-9]")
    name = re.sub(pattern, "_", subject)
    filename = Path(name).with_suffix(".sequence_diagram.mmd")
    await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content)
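`_save_sequence_view` derives a filesystem-safe name by replacing every non-alphanumeric character in the subject with `_` before attaching the diagram suffix. A standalone sketch of that step, with a hypothetical subject string:

```python
import re
from pathlib import Path

# Hypothetical graph subject, for illustration only
subject = "metagpt/actions/rebuild_sequence_view.py:RebuildSequenceView"
name = re.sub(r"[^a-zA-Z0-9]", "_", subject)
filename = Path(name).with_suffix(".sequence_diagram.mmd")
print(filename)  # metagpt_actions_rebuild_sequence_view_py_RebuildSequenceView.sequence_diagram.mmd
```

Because the sanitized name contains no dots, `with_suffix` simply appends the suffix rather than replacing an existing one.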
async def _search_participants(self, filename: str) -> Set:
    content = await self._get_source_code(filename)

    rsp = await self.llm.aask(
        msg=content,
        system_msgs=[
            "You are a tool for listing all class names used in a source file.",
            "Return a markdown JSON object with: "
            '- a "class_names" key containing the list of class names used in the file; '
            '- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
        ],
    )

    class _Data(BaseModel):
        class_names: List[str]
        reasons: List

    json_blocks = parse_json_code_block(rsp)
    data = _Data.model_validate_json(json_blocks[0])
    return set(data.class_names)
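`_search_participants` expects the LLM reply to carry a JSON object inside a fenced markdown block, which `parse_json_code_block` extracts before pydantic validation. A rough stdlib approximation of that extract-and-parse step, using `json` in place of pydantic and a fabricated reply:

```python
import json
import re

fence = "`" * 3  # build the markdown fence without literal triple backticks
# Fabricated LLM reply embedding a JSON object in a markdown code block
rsp = (
    "Here are the classes used in the file:\n"
    + fence + "json\n"
    + '{"class_names": ["GraphRepository", "SPO"], "reasons": []}\n'
    + fence + "\n"
)

# Extract every fenced json block, then parse the first one
json_blocks = re.findall(fence + r"json\s*(.*?)" + fence, rsp, flags=re.DOTALL)
data = json.loads(json_blocks[0])
class_names = set(data["class_names"])
print(sorted(class_names))  # ['GraphRepository', 'SPO']
```

The real method additionally validates the shape of the object against an inline pydantic model, so malformed replies fail loudly and trigger a retry upstream.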

@@ -1,16 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2024/1/4
@Author  : mashenquan
@File    : rebuild_sequence_view_an.py
"""
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC2

CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
    key="Program call flow",
    expected_type=str,
    instruction='Translate the "context" content into "format example" format.',
    example=MMC2,
)

@@ -50,6 +50,7 @@ class ArgumentsParingAction(Action):
        rsp = await self.llm.aask(
            msg=prompt,
            system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
            stream=False,
        )
        logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
        self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)

@@ -92,7 +92,7 @@ class TalkAction(Action):

    async def run(self, with_message=None, **kwargs) -> Message:
        msg, format_msgs, system_msgs = self.aask_args
        rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs)
        rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs, stream=False)
        self.rsp = Message(content=rsp, role="assistant", cause_by=self)
        return self.rsp

@@ -23,11 +23,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE
from metagpt.const import (
    BUGFIX_FILENAME,
    CODE_PLAN_AND_CHANGE_FILENAME,
    REQUIREMENT_FILENAME,
)
from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.schema import CodingContext, Document, RunCodeResult
from metagpt.utils.common import CodeParser

@@ -98,8 +94,6 @@ class WriteCode(Action):
        bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
        coding_context = CodingContext.loads(self.i_context.content)
        test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
        code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
        code_plan_and_change = code_plan_and_change_doc.content if code_plan_and_change_doc else ""
        requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
        summary_doc = None
        if coding_context.design_doc and coding_context.design_doc.filename:

@@ -111,7 +105,7 @@ class WriteCode(Action):

        if bug_feedback:
            code_context = coding_context.code_doc.content
        elif code_plan_and_change:
        elif self.config.inc:
            code_context = await self.get_codes(
                coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
            )

@@ -122,10 +116,10 @@ class WriteCode(Action):
                project_repo=self.repo.with_src_path(self.context.src_workspace),
            )

        if code_plan_and_change:
        if self.config.inc:
            prompt = REFINED_TEMPLATE.format(
                user_requirement=requirement_doc.content if requirement_doc else "",
                code_plan_and_change=code_plan_and_change,
                code_plan_and_change=str(coding_context.code_plan_and_change_doc),
                design=coding_context.design_doc.content if coding_context.design_doc else "",
                task=coding_context.task_doc.content if coding_context.task_doc else "",
                code=code_context,

@@ -6,30 +6,44 @@
@File    : write_code_plan_and_change_an.py
"""
import os
from typing import List

from pydantic import Field

from metagpt.actions.action import Action
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
from metagpt.schema import CodePlanAndChangeContext

CODE_PLAN_AND_CHANGE = ActionNode(
    key="Code Plan And Change",
    expected_type=str,
    instruction="Developing comprehensive and step-by-step incremental development plan, and write Incremental "
    "Change by making a code draft that how to implement incremental development including detailed steps based on the "
    "context. Note: Track incremental changes using mark of '+' or '-' for add/modify/delete code, and conforms to the "
    "output format of git diff",
    example="""
1. Plan for calculator.py: Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, multiplication, and division. Additionally, implement robust error handling for the division operation to mitigate potential issues related to division by zero.
```python
DEVELOPMENT_PLAN = ActionNode(
    key="Development Plan",
    expected_type=List[str],
    instruction="Develop a comprehensive and step-by-step incremental development plan, providing the detail "
    "changes to be implemented at each step based on the order of 'Task List'",
    example=[
        "Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, ...",
        "Update the existing codebase in main.py to incorporate new API endpoints for subtraction, ...",
    ],
)

INCREMENTAL_CHANGE = ActionNode(
    key="Incremental Change",
    expected_type=List[str],
    instruction="Write Incremental Change by making a code draft that how to implement incremental development "
    "including detailed steps based on the context. Note: Track incremental changes using the marks `+` and `-` to "
    "indicate additions and deletions, and ensure compliance with the output format of `git diff`",
    example=[
        '''```diff
--- Old/calculator.py
+++ New/calculator.py

class Calculator:
        self.result = number1 + number2
        return self.result

- def sub(self, number1, number2) -> float:
+ def subtract(self, number1: float, number2: float) -> float:
+ '''
+ """
+ Subtracts the second number from the first and returns the result.
+
+ Args:

@@ -38,13 +52,13 @@ class Calculator:
+
+ Returns:
+ float: The difference of number1 and number2.
+ '''
+ """
+ self.result = number1 - number2
+ return self.result
+
    def multiply(self, number1: float, number2: float) -> float:
- pass
+ '''
+ """
+ Multiplies two numbers and returns the result.
+
+ Args:

@@ -53,15 +67,15 @@ class Calculator:
+
+ Returns:
+ float: The product of number1 and number2.
+ '''
+ """
+ self.result = number1 * number2
+ return self.result
+
    def divide(self, number1: float, number2: float) -> float:
- pass
+ '''
+ """
+ ValueError: If the second number is zero.
+ '''
+ """
+ if number2 == 0:
+ raise ValueError('Cannot divide by zero')
+ self.result = number1 / number2

@@ -75,10 +89,11 @@ class Calculator:
+ print("Result is already zero, no need to clear.")
+
        self.result = 0.0
```
```''',
        """```diff
--- Old/main.py
+++ New/main.py

2. Plan for main.py: Integrate new API endpoints for subtraction, multiplication, and division into the existing codebase of `main.py`. Then, ensure seamless integration with the overall application architecture and maintain consistency with coding standards.
```python
def add_numbers():
    result = calculator.add_numbers(num1, num2)
    return jsonify({'result': result}), 200

@@ -106,6 +121,7 @@ def add_numbers():
if __name__ == '__main__':
    app.run()
```""",
    ],
)

CODE_PLAN_AND_CHANGE_CONTEXT = """

@@ -172,14 +188,16 @@ Role: You are a professional engineer; The main goal is to complete incremental
2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.
3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import.
4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
5. Follow Code Plan And Change: If there is any Incremental Change that is marked by the git diff format using '+' and '-' for add/modify/delete code, or Legacy Code files contain "(unknown) to be rewritten", you must merge it into the code file according to the plan.
5. Follow Code Plan And Change: If there is any "Incremental Change" that is marked by the git diff format with '+' and '-' symbols, or Legacy Code files contain "(unknown) to be rewritten", you must merge it into the code file according to the "Development Plan".
6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
7. Before using a external variable/module, make sure you import it first.
8. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.
9. Attention: Retain details that are not related to incremental development but are important for maintaining the consistency and clarity of the old code.
"""

WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", [CODE_PLAN_AND_CHANGE])
CODE_PLAN_AND_CHANGE = [DEVELOPMENT_PLAN, INCREMENTAL_CHANGE]

WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", CODE_PLAN_AND_CHANGE)


class WriteCodePlanAndChange(Action):

@@ -192,14 +210,14 @@ class WriteCodePlanAndChange(Action):
        prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename)
        design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
        task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
        code_text = await self.get_old_codes()
        context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
            requirement=self.i_context.requirement,
            prd=prd_doc.content,
            design=design_doc.content,
            task=task_doc.content,
            code=code_text,
            code=await self.get_old_codes(),
        )
        logger.info("Writing code plan and change..")
        return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")

    async def get_old_codes(self) -> str:

@@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential

from metagpt.actions import WriteCode
from metagpt.actions.action import Action
from metagpt.const import CODE_PLAN_AND_CHANGE_FILENAME, REQUIREMENT_FILENAME
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser

@@ -149,29 +149,21 @@ class WriteCodeReview(Action):
            use_inc=self.config.inc,
        )

        if not self.config.inc:
            context = "\n".join(
                [
                    "## System Design\n" + str(self.i_context.design_doc) + "\n",
                    "## Task\n" + task_content + "\n",
                    "## Code Files\n" + code_context + "\n",
                ]
            )
        else:
        ctx_list = [
            "## System Design\n" + str(self.i_context.design_doc) + "\n",
            "## Task\n" + task_content + "\n",
            "## Code Files\n" + code_context + "\n",
        ]
        if self.config.inc:
            requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
            code_plan_and_change_doc = await self.repo.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
            context = "\n".join(
                [
                    "## User New Requirements\n" + str(requirement_doc) + "\n",
                    "## Code Plan And Change\n" + str(code_plan_and_change_doc) + "\n",
                    "## System Design\n" + str(self.i_context.design_doc) + "\n",
                    "## Task\n" + task_content + "\n",
                    "## Code Files\n" + code_context + "\n",
                ]
            )
            insert_ctx_list = [
                "## User New Requirements\n" + str(requirement_doc) + "\n",
                "## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n",
            ]
            ctx_list = insert_ctx_list + ctx_list

        context_prompt = PROMPT_TEMPLATE.format(
            context=context,
            context="\n".join(ctx_list),
            code=iterative_code,
            filename=self.i_context.code_doc.filename,
        )

@@ -56,7 +56,7 @@ REFINED_PRODUCT_GOALS = ActionNode(
    key="Refined Product Goals",
    expected_type=List[str],
    instruction="Update and expand the original product goals to reflect the evolving needs due to incremental "
    "development.Ensure that the refined goals align with the current project direction and contribute to its success.",
    "development. Ensure that the refined goals align with the current project direction and contribute to its success.",
    example=[
        "Enhance user engagement through new features",
        "Optimize performance for scalability",

@@ -16,6 +16,7 @@ from metagpt.utils.yaml_model import YamlModel
class LLMType(Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    CLAUDE = "claude"  # alias name of anthropic
    SPARK = "spark"
    ZHIPUAI = "zhipuai"
    FIREWORKS = "fireworks"

@@ -24,6 +25,10 @@ class LLMType(Enum):
    METAGPT = "metagpt"
    AZURE = "azure"
    OLLAMA = "ollama"
    QIANFAN = "qianfan"  # Baidu BCE
    DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
    MOONSHOT = "moonshot"
    MISTRAL = "mistral"

    def __missing__(self, key):
        return self.OPENAI

@@ -36,12 +41,18 @@ class LLMConfig(YamlModel):
    Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
    """

    api_key: str
    api_key: str = "sk-"
    api_type: LLMType = LLMType.OPENAI
    base_url: str = "https://api.openai.com/v1"
    api_version: Optional[str] = None

    model: Optional[str] = None  # also stands for DEPLOYMENT_NAME
    pricing_plan: Optional[str] = None  # Cost Settlement Plan Parameters.

    # For Cloud Service Provider like Baidu/ Alibaba
    access_key: Optional[str] = None
    secret_key: Optional[str] = None
    endpoint: Optional[str] = None  # for self-deployed model on the cloud

    # For Spark(Xunfei), maybe remove later
    app_id: Optional[str] = None

@@ -84,7 +84,6 @@ MESSAGE_ROUTE_TO_NONE = "<none>"
REQUIREMENT_FILENAME = "requirement.txt"
BUGFIX_FILENAME = "bugfix.txt"
PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt"
CODE_PLAN_AND_CHANGE_FILENAME = "code_plan_and_change.json"

DOCS_FILE_REPO = "docs"
PRDS_FILE_REPO = "docs/prd"

@@ -105,6 +104,7 @@ CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/sd_output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db"
CLASS_VIEW_FILE_REPO = "docs/class_view"

YAPI_URL = "http://yapi.deepwisdomai.com/"

@@ -12,10 +12,14 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict

from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import (
    CostManager,
    FireworksCostManager,
    TokenCostManager,
)
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo

@@ -80,12 +84,21 @@ class Context(BaseModel):
    # self._llm = None
    # return self._llm

    def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
        """Return a CostManager instance"""
        if llm_config.api_type == LLMType.FIREWORKS:
            return FireworksCostManager()
        elif llm_config.api_type == LLMType.OPEN_LLM:
            return TokenCostManager()
        else:
            return self.cost_manager

    def llm(self) -> BaseLLM:
        """Return a LLM instance, fixme: support cache"""
        # if self._llm is None:
        self._llm = create_llm_instance(self.config.llm)
        if self._llm.cost_manager is None:
            self._llm.cost_manager = self.cost_manager
            self._llm.cost_manager = self._select_costmanager(self.config.llm)
        return self._llm

    def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
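The `_select_costmanager` change above routes each provider to a matching cost tracker based on `api_type`, so providers with non-standard pricing (Fireworks) or no token pricing (open-source LLMs) get their own accounting. The dispatch pattern in isolation, with stand-in classes (the names below are illustrative, not MetaGPT's):

```python
from enum import Enum

class ApiType(Enum):
    OPENAI = "openai"
    FIREWORKS = "fireworks"
    OPEN_LLM = "open_llm"

class CostManager: ...
class FireworksCostManager(CostManager): ...
class TokenCostManager(CostManager): ...

DEFAULT = CostManager()  # shared tracker for providers with standard pricing

def select_costmanager(api_type: ApiType) -> CostManager:
    """Pick a cost tracker for the provider; fall back to the shared default."""
    if api_type == ApiType.FIREWORKS:
        return FireworksCostManager()
    if api_type == ApiType.OPEN_LLM:
        return TokenCostManager()
    return DEFAULT

print(type(select_costmanager(ApiType.FIREWORKS)).__name__)  # FireworksCostManager
```

Dispatching here rather than inside the provider keeps the providers unaware of billing policy.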

@@ -93,5 +106,5 @@ class Context(BaseModel):
        # if self._llm is None:
        llm = create_llm_instance(llm_config)
        if llm.cost_manager is None:
            llm.cost_manager = self.cost_manager
            llm.cost_manager = self._select_costmanager(llm_config)
        return llm

@@ -11,15 +11,16 @@ from pathlib import Path
from typing import Optional, Union

import pandas as pd
from langchain.document_loaders import (
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import (
    TextLoader,
    UnstructuredPDFLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm

from metagpt.logs import logger
from metagpt.repo_parser import RepoParser

@@ -130,9 +131,12 @@ class IndexableDocument(Document):
        if isinstance(data, pd.DataFrame):
            validate_cols(content_col, data)
            return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col)
        else:
            try:
                content = data_path.read_text()
                return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
            except Exception as e:
                logger.debug(f"Load {str(data_path)} error: {e}")
                content = ""
            return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)

    def _get_docs_and_metadatas_by_df(self) -> (list, list):
        df = self.data

@@ -186,7 +186,7 @@ class BrainMemory(BaseModel):
        summaries = [summary, command]
        msg = "\n".join(summaries)
        logger.debug(f"title ask:{msg}")
        response = await llm.aask(msg=msg, system_msgs=[])
        response = await llm.aask(msg=msg, system_msgs=[], stream=False)
        logger.debug(f"title rsp: {response}")
        return response

@@ -201,11 +201,15 @@ class BrainMemory(BaseModel):

    @staticmethod
    async def _openai_is_related(text1, text2, llm, **kwargs):
        command = (
            f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there "
            "any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
        context = f"## Paragraph 1\n{text2}\n---\n## Paragraph 2\n{text1}\n"
        rsp = await llm.aask(
            msg=context,
            system_msgs=[
                "You are a tool capable of determining whether two paragraphs are semantically related."
                'Return "TRUE" if "Paragraph 1" is semantically relevant to "Paragraph 2", otherwise return "FALSE".'
            ],
            stream=False,
        )
        rsp = await llm.aask(msg=command, system_msgs=[])
        result = True if "TRUE" in rsp else False
        p2 = text2.replace("\n", "")
        p1 = text1.replace("\n", "")

@@ -223,12 +227,17 @@ class BrainMemory(BaseModel):

    @staticmethod
    async def _openai_rewrite(sentence: str, context: str, llm):
        command = (
            f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly "
            f"supplement or rewrite the following text in brief and clear:\n{sentence}"
        prompt = f"## Context\n{context}\n---\n## Sentence\n{sentence}\n"
        rsp = await llm.aask(
            msg=prompt,
            system_msgs=[
                'You are a tool augmenting the "Sentence" with information from the "Context".',
                "Do not supplement the context with information that is not present, especially regarding the subject and object.",
                "Return the augmented sentence.",
            ],
            stream=False,
        )
        rsp = await llm.aask(msg=command, system_msgs=[])
        logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
        logger.info(f"REWRITE:\nCommand: {prompt}\nRESULT: {rsp}\n")
        return rsp

    @staticmethod

@@ -293,14 +302,14 @@ class BrainMemory(BaseModel):
        """Generate text summary"""
        if len(text) < max_words:
            return text
        system_msgs = [
            "You are a tool for summarizing and abstracting text.",
            f"Return the summarized text to less than {max_words} words.",
        ]
        if keep_language:
            command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
        else:
            command = f"Translate the above content into a summary of less than {max_words} words."
        msg = text + "\n\n" + command
        logger.debug(f"summary ask:{msg}")
        response = await self.llm.aask(msg=msg, system_msgs=[])
        logger.debug(f"summary rsp: {response}")
            system_msgs.append("The generated summary should be in the same language as the original text.")
        response = await self.llm.aask(msg=text, system_msgs=system_msgs, stream=False)
        logger.debug(f"{text}\nsummary rsp: {response}")
        return response

    @staticmethod

@@ -7,7 +7,6 @@
from pathlib import Path
from typing import Optional

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings

@@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.embedding import get_embedding
from metagpt.utils.serialize import deserialize_message, serialize_message

@@ -30,7 +30,7 @@ class MemoryStorage(FaissStore):
        self.threshold: float = 0.1  # experience value. TODO The threshold to filter similar memories
        self._initialized: bool = False

        self.embedding = embedding or OpenAIEmbeddings()
        self.embedding = embedding or get_embedding()
        self.store: FAISS = None  # Faiss engine

    @property

@@ -6,21 +6,20 @@
@File    : __init__.py
"""

from metagpt.provider.fireworks_api import FireworksLLM
from metagpt.provider.google_gemini_api import GeminiLLM
from metagpt.provider.ollama_api import OllamaLLM
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM

__all__ = [
    "FireworksLLM",
    "GeminiLLM",
    "OpenLLM",
    "OpenAILLM",
    "ZhiPuAILLM",
    "AzureOpenAILLM",

@@ -28,4 +27,7 @@ __all__ = [
    "OllamaLLM",
    "HumanProvider",
    "SparkLLM",
    "QianFanLLM",
    "DashScopeLLM",
    "AnthropicLLM",
]

@@ -1,37 +1,71 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/7/21 11:15
@Author  : Leo Xiao
@File    : anthropic_api.py
"""

import anthropic
from anthropic import Anthropic, AsyncAnthropic
from anthropic import AsyncAnthropic
from anthropic.types import Message, Usage

from metagpt.configs.llm_config import LLMConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider


class Claude2:
@register_provider([LLMType.ANTHROPIC, LLMType.CLAUDE])
class AnthropicLLM(BaseLLM):
    def __init__(self, config: LLMConfig):
        self.config = config
        self.__init_anthropic()

    def ask(self, prompt: str) -> str:
        client = Anthropic(api_key=self.config.api_key)
    def __init_anthropic(self):
        self.model = self.config.model
        self.aclient: AsyncAnthropic = AsyncAnthropic(api_key=self.config.api_key, base_url=self.config.base_url)

        res = client.completions.create(
            model="claude-2",
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
            max_tokens_to_sample=1000,
        )
        return res.completion
    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        kwargs = {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.config.max_token,
            "stream": stream,
        }
        if self.use_system_prompt:
            # if the model support system prompt, extract and pass it
            if messages[0]["role"] == "system":
                kwargs["messages"] = messages[1:]
                kwargs["system"] = messages[0]["content"]  # set system prompt here
        return kwargs

    async def aask(self, prompt: str) -> str:
        aclient = AsyncAnthropic(api_key=self.config.api_key)
    def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
        usage = {"prompt_tokens": usage.input_tokens, "completion_tokens": usage.output_tokens}
        super()._update_costs(usage, model)

        res = await aclient.completions.create(
            model="claude-2",
            prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
            max_tokens_to_sample=1000,
        )
        return res.completion
    def get_choice_text(self, resp: Message) -> str:
        return resp.content[0].text

    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
        resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
        self._update_costs(resp.usage, self.model)
        return resp

    async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
        return await self._achat_completion(messages, timeout=timeout)

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
        stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
        collected_content = []
        usage = Usage(input_tokens=0, output_tokens=0)
        async for event in stream:
            event_type = event.type
            if event_type == "message_start":
                usage.input_tokens = event.message.usage.input_tokens
                usage.output_tokens = event.message.usage.output_tokens
            elif event_type == "content_block_delta":
                content = event.delta.text
                log_llm_stream(content)
                collected_content.append(content)
            elif event_type == "message_delta":
                usage.output_tokens = event.usage.output_tokens  # update final output_tokens

        log_llm_stream("\n")
        self._update_costs(usage)
        full_content = "".join(collected_content)
        return full_content
|
|||
|
|
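The streaming handler above folds three Anthropic event types into one running text buffer and token tally: `message_start` seeds the usage, `content_block_delta` carries text, and the final `message_delta` holds the authoritative output-token count. A minimal stand-alone sketch of that accumulation logic, using plain stand-in event objects instead of the `anthropic` SDK types (`Usage`, `Event`, and `collect_stream` here are illustrative names, not SDK API):

```python
from dataclasses import dataclass, field


@dataclass
class Usage:  # stand-in for anthropic.types.Usage
    input_tokens: int = 0
    output_tokens: int = 0


@dataclass
class Event:  # stand-in for the SDK's stream events
    type: str
    payload: dict = field(default_factory=dict)


def collect_stream(events):
    """Accumulate text and token usage the way _achat_completion_stream does."""
    collected, usage = [], Usage()
    for event in events:
        if event.type == "message_start":
            usage.input_tokens = event.payload["input_tokens"]
            usage.output_tokens = event.payload["output_tokens"]
        elif event.type == "content_block_delta":
            collected.append(event.payload["text"])
        elif event.type == "message_delta":
            # the final event carries the authoritative output token count
            usage.output_tokens = event.payload["output_tokens"]
    return "".join(collected), usage


events = [
    Event("message_start", {"input_tokens": 12, "output_tokens": 1}),
    Event("content_block_delta", {"text": "Hello, "}),
    Event("content_block_delta", {"text": "world"}),
    Event("message_delta", {"output_tokens": 5}),
]
text, usage = collect_stream(events)
```

Note how the interim `output_tokens` from `message_start` is overwritten by `message_delta`, so billing uses the final count rather than the first estimate.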
metagpt/provider/azure_openai_api.py
@@ -6,8 +6,6 @@
 @Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
 @Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
 """
 
 from openai import AsyncAzureOpenAI
 from openai._base_client import AsyncHttpxClientWrapper
 
@@ -27,6 +25,7 @@ class AzureOpenAILLM(OpenAILLM):
         # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
         self.aclient = AsyncAzureOpenAI(**kwargs)
         self.model = self.config.model  # Used in _calc_usage & _cons_kwargs
+        self.pricing_plan = self.config.pricing_plan
 
     def _make_client_kwargs(self) -> dict:
         kwargs = dict(
metagpt/provider/base_llm.py
@@ -6,16 +6,29 @@
 @File : base_llm.py
 @Desc : mashenquan, 2023/8/22. + try catch
 """
 from __future__ import annotations
 
 import json
 from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
 from openai import AsyncOpenAI
+from openai.types import CompletionUsage
+from pydantic import BaseModel
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 
 from metagpt.configs.llm_config import LLMConfig
 from metagpt.logs import logger
 from metagpt.schema import Message
-from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.common import log_and_reraise
+from metagpt.utils.cost_manager import CostManager, Costs
+from metagpt.utils.exceptions import handle_exception
 
 
 class BaseLLM(ABC):
@@ -29,6 +42,7 @@ class BaseLLM(ABC):
     aclient: Optional[Union[AsyncOpenAI]] = None
     cost_manager: Optional[CostManager] = None
     model: Optional[str] = None
+    pricing_plan: Optional[str] = None
 
     @abstractmethod
     def __init__(self, config: LLMConfig):
@@ -67,6 +81,28 @@ class BaseLLM(ABC):
     def _default_system_msg(self):
         return self._system_msg(self.system_prompt)
 
+    def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
+        """update each request's token cost
+        Args:
+            model (str): model name or in some scenarios called endpoint
+            local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
+        """
+        calc_usage = self.config.calc_usage and local_calc_usage
+        model = model or self.model
+        usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
+        if calc_usage and self.cost_manager:
+            try:
+                prompt_tokens = int(usage.get("prompt_tokens", 0))
+                completion_tokens = int(usage.get("completion_tokens", 0))
+                self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
+            except Exception as e:
+                logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
+
+    def get_costs(self) -> Costs:
+        if not self.cost_manager:
+            return Costs(0, 0, 0, 0)
+        return self.cost_manager.get_costs()
+
     async def aask(
         self,
         msg: Union[str, list[dict[str, str]]],
@@ -108,6 +144,10 @@ class BaseLLM(ABC):
     async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3, **kwargs) -> dict:
         raise NotImplementedError
 
+    @abstractmethod
+    async def _achat_completion(self, messages: list[dict], timeout=3):
+        """_achat_completion implemented by inherited class"""
+
     @abstractmethod
     async def acompletion(self, messages: list[dict], timeout=3):
         """Asynchronous version of completion
@@ -120,8 +160,22 @@ class BaseLLM(ABC):
         """
 
     @abstractmethod
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        """_achat_completion_stream implemented by inherited class"""
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(min=1, max=60),
+        after=after_log(logger, logger.level("WARNING").name),
+        retry=retry_if_exception_type(ConnectionError),
+        retry_error_callback=log_and_reraise,
+    )
+    async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
+        """Asynchronous version of completion. Return str. Support stream-print"""
+        if stream:
+            return await self._achat_completion_stream(messages, timeout=timeout)
+        resp = await self._achat_completion(messages, timeout=timeout)
+        return self.get_choice_text(resp)
 
     def get_choice_text(self, rsp: dict) -> str:
         """Required to provide the first text of choice"""
@@ -171,6 +225,20 @@ class BaseLLM(ABC):
         """
         return json.loads(self.get_choice_function(rsp)["arguments"], strict=False)
 
+    @handle_exception
+    def _update_costs(self, usage: CompletionUsage | Dict):
+        """
+        Updates the costs based on the provided usage information.
+        """
+        if self.config.calc_usage and usage and self.cost_manager:
+            if isinstance(usage, Dict):
+                prompt_tokens = int(usage.get("prompt_tokens", 0))
+                completion_tokens = int(usage.get("completion_tokens", 0))
+            else:
+                prompt_tokens = usage.prompt_tokens
+                completion_tokens = usage.completion_tokens
+            self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
+
     def messages_to_prompt(self, messages: list[dict]):
         """[{"role": "user", "content": msg}] to user: <msg> etc."""
         return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
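The shared `_update_costs` in `BaseLLM` accepts either a plain usage dict or a pydantic model and normalizes both to token counts before charging the cost manager. A stand-alone sketch of that normalization, with a plain stand-in class in place of the pydantic model (`FakeUsage` and `normalize_usage` are illustrative names, and `hasattr` is used here instead of an `isinstance(..., BaseModel)` check to keep the sketch dependency-free):

```python
class FakeUsage:
    """Stand-in for a pydantic usage model; only model_dump() matters here."""

    def __init__(self, prompt_tokens: int, completion_tokens: int):
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens

    def model_dump(self) -> dict:
        return {"prompt_tokens": self.prompt_tokens, "completion_tokens": self.completion_tokens}


def normalize_usage(usage) -> tuple[int, int]:
    """Dict passes through; model-like objects are converted via model_dump(),
    mirroring the `usage.model_dump() if isinstance(usage, BaseModel) else usage` step."""
    usage = usage.model_dump() if hasattr(usage, "model_dump") else usage
    return int(usage.get("prompt_tokens", 0)), int(usage.get("completion_tokens", 0))
```

Missing keys default to 0, so providers that report only partial usage still update the cost manager without raising.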
metagpt/provider/dashscope_api.py (new file, 227 lines)
@@ -0,0 +1,227 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc :
+
+import json
+from http import HTTPStatus
+from typing import Any, AsyncGenerator, Dict, List, Union
+
+import dashscope
+from dashscope.aigc.generation import Generation
+from dashscope.api_entities.aiohttp_request import AioHttpRequest
+from dashscope.api_entities.api_request_data import ApiRequestData
+from dashscope.api_entities.api_request_factory import _get_protocol_params
+from dashscope.api_entities.dashscope_response import (
+    GenerationOutput,
+    GenerationResponse,
+    Message,
+)
+from dashscope.client.base_api import BaseAioApi
+from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol
+from dashscope.common.error import (
+    InputDataRequired,
+    InputRequired,
+    ModelRequired,
+    UnsupportedApiProtocol,
+)
+
+from metagpt.logs import log_llm_stream
+from metagpt.provider.base_llm import BaseLLM, LLMConfig
+from metagpt.provider.llm_provider_registry import LLMType, register_provider
+from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS
+
+
+def build_api_arequest(
+    model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs
+):
+    (
+        api_protocol,
+        ws_stream_mode,
+        is_binary_input,
+        http_method,
+        stream,
+        async_request,
+        query,
+        headers,
+        request_timeout,
+        form,
+        resources,
+    ) = _get_protocol_params(kwargs)
+    task_id = kwargs.pop("task_id", None)
+    if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
+        if not dashscope.base_http_api_url.endswith("/"):
+            http_url = dashscope.base_http_api_url + "/"
+        else:
+            http_url = dashscope.base_http_api_url
+
+        if is_service:
+            http_url = http_url + SERVICE_API_PATH + "/"
+
+        if task_group:
+            http_url += "%s/" % task_group
+        if task:
+            http_url += "%s/" % task
+        if function:
+            http_url += function
+        request = AioHttpRequest(
+            url=http_url,
+            api_key=api_key,
+            http_method=http_method,
+            stream=stream,
+            async_request=async_request,
+            query=query,
+            timeout=request_timeout,
+            task_id=task_id,
+        )
+    else:
+        raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol)
+
+    if headers is not None:
+        request.add_headers(headers=headers)
+
+    if input is None and form is None:
+        raise InputDataRequired("There is no input data and form data")
+
+    request_data = ApiRequestData(
+        model,
+        task_group=task_group,
+        task=task,
+        function=function,
+        input=input,
+        form=form,
+        is_binary_input=is_binary_input,
+        api_protocol=api_protocol,
+    )
+    request_data.add_resources(resources)
+    request_data.add_parameters(**kwargs)
+    request.data = request_data
+    return request
+
+
+class AGeneration(Generation, BaseAioApi):
+    @classmethod
+    async def acall(
+        cls,
+        model: str,
+        prompt: Any = None,
+        history: list = None,
+        api_key: str = None,
+        messages: List[Message] = None,
+        plugins: Union[str, Dict[str, Any]] = None,
+        **kwargs,
+    ) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
+        if (prompt is None or not prompt) and (messages is None or not messages):
+            raise InputRequired("prompt or messages is required!")
+        if model is None or not model:
+            raise ModelRequired("Model is required!")
+        task_group, function = "aigc", "generation"  # fixed value
+        if plugins is not None:
+            headers = kwargs.pop("headers", {})
+            if isinstance(plugins, str):
+                headers["X-DashScope-Plugin"] = plugins
+            else:
+                headers["X-DashScope-Plugin"] = json.dumps(plugins)
+            kwargs["headers"] = headers
+        input, parameters = cls._build_input_parameters(model, prompt, history, messages, **kwargs)
+
+        api_key, model = BaseAioApi._validate_params(api_key, model)
+        request = build_api_arequest(
+            model=model,
+            input=input,
+            task_group=task_group,
+            task=Generation.task,
+            function=function,
+            api_key=api_key,
+            **kwargs,
+        )
+        response = await request.aio_call()
+        is_stream = kwargs.get("stream", False)
+        if is_stream:
+
+            async def aresp_iterator(response):
+                async for resp in response:
+                    yield GenerationResponse.from_api_response(resp)
+
+            return aresp_iterator(response)
+        else:
+            return GenerationResponse.from_api_response(response)
+
+
+@register_provider(LLMType.DASHSCOPE)
+class DashScopeLLM(BaseLLM):
+    def __init__(self, llm_config: LLMConfig):
+        self.config = llm_config
+        self.use_system_prompt = False  # only some models support system_prompt
+        self.__init_dashscope()
+        self.cost_manager = CostManager(token_costs=self.token_costs)
+
+    def __init_dashscope(self):
+        self.model = self.config.model
+        self.api_key = self.config.api_key
+        self.token_costs = DASHSCOPE_TOKEN_COSTS
+        self.aclient: AGeneration = AGeneration
+
+        # check support system_message models
+        support_system_models = [
+            "qwen-",  # all support
+            "llama2-",  # all support
+            "baichuan2-7b-chat-v1",
+            "chatglm3-6b",
+        ]
+        for support_model in support_system_models:
+            if support_model in self.model:
+                self.use_system_prompt = True
+
+    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
+        kwargs = {
+            "api_key": self.api_key,
+            "model": self.model,
+            "messages": messages,
+            "stream": stream,
+            "result_format": "message",
+        }
+        if self.config.temperature > 0:
+            # different model has default temperature. only set when it's specified.
+            kwargs["temperature"] = self.config.temperature
+        if stream:
+            kwargs["incremental_output"] = True
+        return kwargs
+
+    def _check_response(self, resp: GenerationResponse):
+        if resp.status_code != HTTPStatus.OK:
+            raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}")
+
+    def get_choice_text(self, output: GenerationOutput) -> str:
+        return output.get("choices", [{}])[0].get("message", {}).get("content", "")
+
+    def completion(self, messages: list[dict]) -> GenerationOutput:
+        resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False))
+        self._check_response(resp)
+
+        self._update_costs(dict(resp.usage))
+        return resp.output
+
+    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
+        resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
+        self._check_response(resp)
+        self._update_costs(dict(resp.usage))
+        return resp.output
+
+    async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
+        return await self._achat_completion(messages, timeout=timeout)
+
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
+        collected_content = []
+        usage = {}
+        async for chunk in resp:
+            self._check_response(chunk)
+            content = chunk.output.choices[0]["message"]["content"]
+            usage = dict(chunk.usage)  # each chunk has usage
+            log_llm_stream(content)
+            collected_content.append(content)
+        log_llm_stream("\n")
+        self._update_costs(usage)
+        full_content = "".join(collected_content)
+        return full_content
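`DashScopeLLM.__init_dashscope` decides whether to forward a system prompt with a simple substring check against a list of model-name prefixes. That check can be sketched stand-alone (`supports_system_prompt` is an illustrative name; the prefix list is copied from the diff above):

```python
SUPPORT_SYSTEM_MODELS = [
    "qwen-",  # all qwen- models support system prompts
    "llama2-",  # all llama2- models support system prompts
    "baichuan2-7b-chat-v1",
    "chatglm3-6b",
]


def supports_system_prompt(model: str) -> bool:
    """Substring check mirroring DashScopeLLM's use_system_prompt detection."""
    return any(prefix in model for prefix in SUPPORT_SYSTEM_MODELS)
```

Because it is a substring match rather than a prefix match, `"qwen-"` covers every `qwen-*` variant, while an exact entry like `"baichuan2-7b-chat-v1"` covers only that one model.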
metagpt/provider/fireworks_api.py (deleted)
@@ -1,131 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc : fireworks.ai's api
-
-import re
-
-from openai import APIConnectionError, AsyncStream
-from openai.types import CompletionUsage
-from openai.types.chat import ChatCompletionChunk
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-
-from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import logger
-from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
-from metagpt.utils.cost_manager import CostManager, Costs
-
-MODEL_GRADE_TOKEN_COSTS = {
-    "-1": {"prompt": 0.0, "completion": 0.0},  # abnormal condition
-    "16": {"prompt": 0.2, "completion": 0.8},  # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
-    "80": {"prompt": 0.7, "completion": 2.8},  # 80 means 16B < model size <= 80B
-    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
-}
-
-
-class FireworksCostManager(CostManager):
-    def model_grade_token_costs(self, model: str) -> dict[str, float]:
-        def _get_model_size(model: str) -> float:
-            size = re.findall(".*-([0-9.]+)b", model)
-            size = float(size[0]) if len(size) > 0 else -1
-            return size
-
-        if "mixtral-8x7b" in model:
-            token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
-        else:
-            model_size = _get_model_size(model)
-            if 0 < model_size <= 16:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
-            elif 16 < model_size <= 80:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
-            else:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
-        return token_costs
-
-    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
-        """
-        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
-        Update the total cost, prompt tokens, and completion tokens.
-
-        Args:
-            prompt_tokens (int): The number of tokens used in the prompt.
-            completion_tokens (int): The number of tokens used in the completion.
-            model (str): The model used for the API call.
-        """
-        self.total_prompt_tokens += prompt_tokens
-        self.total_completion_tokens += completion_tokens
-
-        token_costs = self.model_grade_token_costs(model)
-        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
-        self.total_cost += cost
-        logger.info(
-            f"Total running cost: ${self.total_cost:.4f}"
-            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
-        )
-
-
-@register_provider(LLMType.FIREWORKS)
-class FireworksLLM(OpenAILLM):
-    def __init__(self, config: LLMConfig):
-        super().__init__(config=config)
-        self.auto_max_tokens = False
-        self.cost_manager = FireworksCostManager()
-
-    def _make_client_kwargs(self) -> dict:
-        kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
-        return kwargs
-
-    def _update_costs(self, usage: CompletionUsage):
-        if self.config.calc_usage and usage:
-            try:
-                # use FireworksCostManager not context.cost_manager
-                self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
-            except Exception as e:
-                logger.error(f"updating costs failed!, exp: {e}")
-
-    def get_costs(self) -> Costs:
-        return self.cost_manager.get_costs()
-
-    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
-        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
-            **self._cons_kwargs(messages), stream=True
-        )
-
-        collected_content = []
-        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        # iterate through the stream of events
-        async for chunk in response:
-            if chunk.choices:
-                choice = chunk.choices[0]
-                choice_delta = choice.delta
-                finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
-                if choice_delta.content:
-                    collected_content.append(choice_delta.content)
-                    print(choice_delta.content, end="")
-                if finish_reason:
-                    # fireworks api return usage when finish_reason is not None
-                    usage = CompletionUsage(**chunk.usage)
-
-        full_content = "".join(collected_content)
-        self._update_costs(usage)
-        return full_content
-
-    @retry(
-        wait=wait_random_exponential(min=1, max=60),
-        stop=stop_after_attempt(6),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(APIConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        """when streaming, print each token in place."""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        rsp = await self._achat_completion(messages)
-        return self.get_choice_text(rsp)
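The deleted `FireworksCostManager` priced requests by extracting the parameter count from the model name with a regex and mapping it to a pricing tier. A stand-alone sketch of that grading logic, with the tier table copied from the removed file (`grade_costs` and `request_cost` are illustrative names):

```python
import re

# Tier table from the removed fireworks_api.py; prices are $ per 1M tokens.
MODEL_GRADE_TOKEN_COSTS = {
    "-1": {"prompt": 0.0, "completion": 0.0},  # size could not be parsed
    "16": {"prompt": 0.2, "completion": 0.8},  # model size <= 16B
    "80": {"prompt": 0.7, "completion": 2.8},  # 16B < model size <= 80B
    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}


def grade_costs(model: str) -> dict:
    """Pick a pricing tier from the parameter count embedded in the model name."""
    if "mixtral-8x7b" in model:
        return MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
    sizes = re.findall(r".*-([0-9.]+)b", model)  # e.g. "llama-v2-13b-chat" -> "13"
    size = float(sizes[0]) if sizes else -1
    if 0 < size <= 16:
        return MODEL_GRADE_TOKEN_COSTS["16"]
    if 16 < size <= 80:
        return MODEL_GRADE_TOKEN_COSTS["80"]
    return MODEL_GRADE_TOKEN_COSTS["-1"]


def request_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    """Dollar cost of one request under the graded per-1M-token prices."""
    costs = grade_costs(model)
    return (prompt_tokens * costs["prompt"] + completion_tokens * costs["completion"]) / 1_000_000
```

The mixtral check must come first: the size regex would otherwise parse `mixtral-8x7b` as a 7B model and bill it in the wrong tier.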
metagpt/provider/gemini_api.py
@@ -13,19 +13,11 @@ from google.generativeai.types.generation_types import (
     GenerateContentResponse,
     GenerationConfig,
 )
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import log_llm_stream, logger
+from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import log_and_reraise
 
 
 class GeminiGenerativeModel(GenerativeModel):
@@ -55,6 +47,7 @@ class GeminiLLM(BaseLLM):
         self.__init_gemini(config)
         self.config = config
         self.model = "gemini-pro"  # so far only one model
+        self.pricing_plan = self.config.pricing_plan or self.model
         self.llm = GeminiGenerativeModel(model_name=self.model)
 
     def __init_gemini(self, config: LLMConfig):
@@ -72,16 +65,6 @@ class GeminiLLM(BaseLLM):
         kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
         return kwargs
 
-    def _update_costs(self, usage: dict):
-        """update each request's token cost"""
-        if self.config.calc_usage:
-            try:
-                prompt_tokens = int(usage.get("prompt_tokens", 0))
-                completion_tokens = int(usage.get("completion_tokens", 0))
-                self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
-            except Exception as e:
-                logger.error(f"google gemini updats costs failed! exp: {e}")
-
     def get_choice_text(self, resp: GenerateContentResponse) -> str:
         return resp.text
 
@@ -105,16 +88,16 @@ class GeminiLLM(BaseLLM):
         self._update_costs(usage)
         return resp
 
-    async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
+    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
         usage = await self.aget_usage(messages, resp.text)
         self._update_costs(usage)
         return resp
 
     async def acompletion(self, messages: list[dict], timeout=3) -> dict:
-        return await self._achat_completion(messages)
+        return await self._achat_completion(messages, timeout=timeout)
 
-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
             **self._const_kwargs(messages, stream=True)
         )
@@ -129,17 +112,3 @@ class GeminiLLM(BaseLLM):
         usage = await self.aget_usage(messages, full_content)
         self._update_costs(usage)
         return full_content
-
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_random_exponential(min=1, max=60),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(ConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        """response in async with stream or non-stream mode"""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        resp = await self._achat_completion(messages)
-        return self.get_choice_text(resp)
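The per-provider `acompletion_text` wrappers removed here (and from ollama) all carried the same tenacity policy: retry on connection errors, stop after a fixed number of attempts, wait a random exponential interval between tries. That policy now lives once in `BaseLLM`. A stdlib sketch of the same stop-after-N, random-exponential-wait behavior, with waits shrunk to milliseconds for illustration (`retry_on` and the helper names are assumptions, not tenacity API):

```python
import random
import time


def retry_on(exc_type, attempts: int = 3, base: float = 0.001, cap: float = 0.01):
    """Minimal stand-in for tenacity's retry + stop_after_attempt + wait_random_exponential."""

    def decorator(fn):
        def wrapper(*args, **kwargs):
            for attempt in range(1, attempts + 1):
                try:
                    return fn(*args, **kwargs)
                except exc_type:
                    if attempt == attempts:
                        raise  # tenacity's retry_error_callback would log-and-reraise here
                    # random exponential backoff, capped
                    time.sleep(min(cap, random.uniform(0, base * 2**attempt)))

        return wrapper

    return decorator


calls = {"n": 0}


@retry_on(ConnectionError)
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient")
    return "ok"
```

Centralizing the decorator in the base class means a provider only implements `_achat_completion` / `_achat_completion_stream`, and the retry policy cannot drift between providers.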
metagpt/provider/human_provider.py
@@ -35,10 +35,16 @@ class HumanProvider(BaseLLM):
     ) -> str:
         return self.ask(msg, timeout=timeout)
 
+    async def _achat_completion(self, messages: list[dict], timeout=3):
+        pass
+
     async def acompletion(self, messages: list[dict], timeout=3):
         """dummy implementation of abstract method in base"""
         return []
 
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        pass
+
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """dummy implementation of abstract method in base"""
         return ""
metagpt/provider/llm_provider_registry.py
@@ -21,11 +21,15 @@ class LLMProviderRegistry:
         return self.providers[enum]
 
 
-def register_provider(key):
+def register_provider(keys):
     """register provider to registry"""
 
     def decorator(cls):
-        LLM_REGISTRY.register(key, cls)
+        if isinstance(keys, list):
+            for key in keys:
+                LLM_REGISTRY.register(key, cls)
+        else:
+            LLM_REGISTRY.register(keys, cls)
         return cls
 
     return decorator
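The registry change above lets one provider class register under several `LLMType` keys, which is what allows `@register_provider([LLMType.ANTHROPIC, LLMType.CLAUDE])` on `AnthropicLLM`. A self-contained sketch of the same pattern with string keys in place of the `LLMType` enum (`ProviderRegistry`, `REGISTRY`, and `DemoLLM` are illustrative names):

```python
class ProviderRegistry:
    def __init__(self):
        self.providers = {}

    def register(self, key, cls):
        self.providers[key] = cls


REGISTRY = ProviderRegistry()


def register_provider(keys):
    """Accept a single key or a list of keys, as the patched decorator does."""

    def decorator(cls):
        for key in keys if isinstance(keys, list) else [keys]:
            REGISTRY.register(key, cls)
        return cls  # return the class unchanged so the decorator is transparent

    return decorator


@register_provider(["anthropic", "claude"])
class DemoLLM:
    pass
```

Normalizing `keys` to a list inside the loop keeps a single registration path, so the single-key and multi-key cases cannot diverge.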
metagpt/provider/metagpt_api.py
@@ -5,6 +5,8 @@
 @File : metagpt_api.py
 @Desc : MetaGPT LLM provider.
 """
+from openai.types import CompletionUsage
+
 from metagpt.configs.llm_config import LLMType
 from metagpt.provider import OpenAILLM
 from metagpt.provider.llm_provider_registry import register_provider
 
@@ -12,4 +14,7 @@ from metagpt.provider.llm_provider_registry import register_provider
 
 @register_provider(LLMType.METAGPT)
 class MetaGPTLLM(OpenAILLM):
-    pass
+    def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
+        # The current billing is based on usage frequency. If there is a future billing logic based on the
+        # number of tokens, please refine the logic here accordingly.
+        return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
metagpt/provider/ollama_api.py
@@ -4,22 +4,12 @@
 
 import json
 
-from requests import ConnectionError
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
 
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import LLM_API_TIMEOUT
-from metagpt.logs import log_llm_stream, logger
+from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.general_api_requestor import GeneralAPIRequestor
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import log_and_reraise
 from metagpt.utils.cost_manager import TokenCostManager
 
 
@@ -36,26 +26,17 @@ class OllamaLLM(BaseLLM):
         self.suffix_url = "/chat"
         self.http_method = "post"
         self.use_system_prompt = False
-        self._cost_manager = TokenCostManager()
+        self.cost_manager = TokenCostManager()
 
     def __init_ollama(self, config: LLMConfig):
         assert config.base_url, "ollama base url is required!"
         self.model = config.model
+        self.pricing_plan = self.model
 
     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
         return kwargs
 
-    def _update_costs(self, usage: dict):
-        """update each request's token cost"""
-        if self.config.calc_usage:
-            try:
-                prompt_tokens = int(usage.get("prompt_tokens", 0))
-                completion_tokens = int(usage.get("completion_tokens", 0))
-                self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
-            except Exception as e:
-                logger.error(f"ollama updats costs failed! exp: {e}")
-
     def get_choice_text(self, resp: dict) -> str:
         """get the resp content from llm response"""
         assist_msg = resp.get("message", {})
@@ -69,7 +50,7 @@ class OllamaLLM(BaseLLM):
             chunk = chunk.decode(encoding)
         return json.loads(chunk)
 
-    async def _achat_completion(self, messages: list[dict]) -> dict:
+    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
@@ -82,9 +63,9 @@ class OllamaLLM(BaseLLM):
         return resp
 
     async def acompletion(self, messages: list[dict], timeout=3) -> dict:
-        return await self._achat_completion(messages)
+        return await self._achat_completion(messages, timeout=timeout)
 
-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
         stream_resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
@@ -110,17 +91,3 @@ class OllamaLLM(BaseLLM):
         self._update_costs(usage)
         full_content = "".join(collected_content)
         return full_content
-
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_random_exponential(min=1, max=60),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(ConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        """response in async with stream or non-stream mode"""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        resp = await self._achat_completion(messages)
-        return self.get_choice_text(resp)
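Ollama's chat endpoint streams one JSON object per chunk, so `OllamaLLM` decodes each raw chunk to text, parses it, and pulls the reply out of `message.content`. A stand-alone sketch of that decode-and-extract step (`decode_chunk` and `choice_text` are illustrative names; the payload shape follows the `resp.get("message", {})` access in the diff above):

```python
import json


def decode_chunk(chunk, encoding: str = "utf-8") -> dict:
    """One streamed chunk is one JSON object; decode bytes first, then parse."""
    if isinstance(chunk, bytes):
        chunk = chunk.decode(encoding)
    return json.loads(chunk)


def choice_text(resp: dict) -> str:
    """Mirror OllamaLLM.get_choice_text: the reply text lives under message.content."""
    return resp.get("message", {}).get("content", "")


resp = decode_chunk(b'{"message": {"role": "assistant", "content": "hi"}, "done": false}')
```

Using `.get(..., {})` at each level means a malformed or final `done` chunk without a message yields an empty string instead of a `KeyError`.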
metagpt/provider/open_llm_api.py (deleted)
@@ -1,47 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc : self-host open llm model with openai-compatible interface
-
-from openai.types import CompletionUsage
-
-from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import logger
-from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAILLM
-from metagpt.utils.cost_manager import Costs, TokenCostManager
-from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
-
-
-@register_provider(LLMType.OPEN_LLM)
-class OpenLLM(OpenAILLM):
-    def __init__(self, config: LLMConfig):
-        super().__init__(config)
-        self._cost_manager = TokenCostManager()
-
-    def _make_client_kwargs(self) -> dict:
-        kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
-        return kwargs
-
-    def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
-        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        if not self.config.calc_usage:
-            return usage
-
-        try:
-            usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
-            usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
-        except Exception as e:
-            logger.error(f"usage calculation failed!: {e}")
-
-        return usage
-
-    def _update_costs(self, usage: CompletionUsage):
-        if self.config.calc_usage and usage:
-            try:
-                # use OpenLLMCostManager not CONFIG.cost_manager
-                self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
-            except Exception as e:
-                logger.error(f"updating costs failed!, exp: {e}")
-
-    def get_costs(self) -> Costs:
-        return self._cost_manager.get_costs()
@@ -6,10 +6,11 @@
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from __future__ import annotations

import json
import re
from typing import AsyncIterator, Optional, Union
from typing import Optional, Union

from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
@@ -28,8 +29,13 @@ from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.common import CodeParser, decode_image, process_message
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.common import (
    CodeParser,
    decode_image,
    log_and_reraise,
    process_message,
)
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
    count_message_tokens,
@@ -38,33 +44,20 @@ from metagpt.utils.token_counter import (
)


def log_and_reraise(retry_state):
    logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
    logger.warning(
        """
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
"""
    )
    raise retry_state.outcome.exception()


@register_provider(LLMType.OPENAI)
@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL])
class OpenAILLM(BaseLLM):
    """Check https://platform.openai.com/examples for examples"""

    def __init__(self, config: LLMConfig):
        self.config = config
        self._init_model()
        self._init_client()
        self.auto_max_tokens = False
        self.cost_manager: Optional[CostManager] = None

    def _init_model(self):
        self.model = self.config.model  # Used in _calc_usage & _cons_kwargs

    def _init_client(self):
        """https://github.com/openai/openai-python#async-usage"""
        self.model = self.config.model  # Used in _calc_usage & _cons_kwargs
        self.pricing_plan = self.config.pricing_plan or self.model
        kwargs = self._make_client_kwargs()
        self.aclient = AsyncOpenAI(**kwargs)
@@ -86,22 +79,41 @@ class OpenAILLM(BaseLLM):

        return params

    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
            **self._cons_kwargs(messages, timeout=timeout), stream=True
        )

        usage = None
        collected_messages = []
        async for chunk in response:
            chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
            yield chunk_message
            finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], "finish_reason") else None
            log_llm_stream(chunk_message)
            collected_messages.append(chunk_message)
            if finish_reason:
                if hasattr(chunk, "usage"):
                    # Some services have usage as an attribute of the chunk, such as Fireworks
                    usage = CompletionUsage(**chunk.usage)
                elif hasattr(chunk.choices[0], "usage"):
                    # The usage of some services is an attribute of chunk.choices[0], such as Moonshot
                    usage = CompletionUsage(**chunk.choices[0].usage)

        log_llm_stream("\n")
        full_reply_content = "".join(collected_messages)
        if not usage:
            # Some services do not provide the usage attribute, such as OpenAI or OpenLLM
            usage = self._calc_usage(messages, full_reply_content)

        self._update_costs(usage)
        return full_reply_content

    def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
        kwargs = {
            "messages": messages,
            "max_tokens": self._get_max_tokens(messages),
            "n": 1,
            # "n": 1,  # Some services do not provide this parameter, such as mistral
            # "stop": None,  # default it's None and gpt4-v can't have this one
            "temperature": 0.3,
            "temperature": self.config.temperature,
            "model": self.model,
            "timeout": max(self.config.timeout, timeout),
        }
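The usage-lookup chain in the streaming handler above (usage on the chunk itself, then on `chunk.choices[0]`, then a local token count as a last resort) can be sketched independently of any provider client. `Chunk` and `Choice` below are hypothetical stand-ins, not the real OpenAI types, and the sketch uses `None` checks where the original uses `hasattr`:

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class Choice:
    # Hypothetical stand-in for a streamed choice; Moonshot-style services
    # attach usage here.
    usage: Optional[dict] = None


@dataclass
class Chunk:
    # Hypothetical stand-in for a streamed chunk; Fireworks-style services
    # attach usage here.
    choices: list = field(default_factory=list)
    usage: Optional[dict] = None


def extract_usage(final_chunk: Chunk, fallback_calc: Callable[[], dict]) -> dict:
    """Mirror the lookup order above: chunk usage, then choice usage,
    then a locally computed token count."""
    if final_chunk.usage is not None:
        return final_chunk.usage
    if final_chunk.choices and final_chunk.choices[0].usage is not None:
        return final_chunk.choices[0].usage
    return fallback_calc()
```

The point of the fallback is that OpenAI-compatible services disagree on where (or whether) they report token usage in streamed responses, so the caller always ends up with some usage dict to feed into cost accounting.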
@@ -128,18 +140,7 @@ class OpenAILLM(BaseLLM):
    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
        """when streaming, print each token in place."""
        if stream:
            resp = self._achat_completion_stream(messages, timeout=timeout)

            collected_messages = []
            async for i in resp:
                log_llm_stream(i)
                collected_messages.append(i)
            log_llm_stream("\n")

            full_reply_content = "".join(collected_messages)
            usage = self._calc_usage(messages, full_reply_content)
            self._update_costs(usage)
            return full_reply_content
            return await self._achat_completion_stream(messages, timeout=timeout)

        rsp = await self._achat_completion(messages, timeout=timeout)
        return self.get_choice_text(rsp)
@@ -239,23 +240,13 @@ class OpenAILLM(BaseLLM):
            return usage

        try:
            usage.prompt_tokens = count_message_tokens(messages, self.model)
            usage.completion_tokens = count_string_tokens(rsp, self.model)
            usage.prompt_tokens = count_message_tokens(messages, self.pricing_plan)
            usage.completion_tokens = count_string_tokens(rsp, self.pricing_plan)
        except Exception as e:
            logger.warning(f"usage calculation failed: {e}")

        return usage

    @handle_exception
    def _update_costs(self, usage: CompletionUsage):
        if self.config.calc_usage and usage and self.cost_manager:
            self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)

    def get_costs(self) -> Costs:
        if not self.cost_manager:
            return Costs(0, 0, 0, 0)
        return self.cost_manager.get_costs()

    def _get_max_tokens(self, messages: list[dict]):
        if not self.auto_max_tokens:
            return self.config.max_token
131 metagpt/provider/qianfan_api.py (new file)
@@ -0,0 +1,131 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc   : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
import copy
import os

import qianfan
from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody

from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import (
    QIANFAN_ENDPOINT_TOKEN_COSTS,
    QIANFAN_MODEL_TOKEN_COSTS,
)


@register_provider(LLMType.QIANFAN)
class QianFanLLM(BaseLLM):
    """
    Refs
    Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
    Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
    Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
            https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
    """

    def __init__(self, config: LLMConfig):
        self.config = config
        self.use_system_prompt = False  # only some ERNIE-x related models support system_prompt
        self.__init_qianfan()
        self.cost_manager = CostManager(token_costs=self.token_costs)

    def __init_qianfan(self):
        if self.config.access_key and self.config.secret_key:
            # for system level auth, use access_key and secret_key, recommended by official
            # set environment variable due to official recommendation
            os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
            os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
        elif self.config.api_key and self.config.secret_key:
            # for application level auth, use api_key and secret_key
            # set environment variable due to official recommendation
            os.environ.setdefault("QIANFAN_AK", self.config.api_key)
            os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
        else:
            raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")

        support_system_pairs = [
            ("ERNIE-Bot-4", "completions_pro"),  # (model, corresponding-endpoint)
            ("ERNIE-Bot-8k", "ernie_bot_8k"),
            ("ERNIE-Bot", "completions"),
            ("ERNIE-Bot-turbo", "eb-instant"),
            ("ERNIE-Speed", "ernie_speed"),
            ("EB-turbo-AppBuilder", "ai_apaas"),
        ]
        if self.config.model in [pair[0] for pair in support_system_pairs]:
            # only some ERNIE models support
            self.use_system_prompt = True
        if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
            self.use_system_prompt = True

        assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
        assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"

        self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
        self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)

        # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
        self.calc_usage = self.config.calc_usage and self.config.endpoint is None
        self.aclient: ChatCompletion = qianfan.ChatCompletion()

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        kwargs = {
            "messages": messages,
            "stream": stream,
        }
        if self.config.temperature > 0:
            # different model has default temperature. only set when it's specified.
            kwargs["temperature"] = self.config.temperature
        if self.config.endpoint:
            kwargs["endpoint"] = self.config.endpoint
        elif self.config.model:
            kwargs["model"] = self.config.model

        if self.use_system_prompt:
            # if the model support system prompt, extract and pass it
            if messages[0]["role"] == "system":
                kwargs["messages"] = messages[1:]
                kwargs["system"] = messages[0]["content"]  # set system prompt here
        return kwargs

    def _update_costs(self, usage: dict):
        """update each request's token cost"""
        model_or_endpoint = self.config.model or self.config.endpoint
        local_calc_usage = model_or_endpoint in self.token_costs
        super()._update_costs(usage, model_or_endpoint, local_calc_usage)

    def get_choice_text(self, resp: JsonBody) -> str:
        return resp.get("result", "")

    def completion(self, messages: list[dict]) -> JsonBody:
        resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
        self._update_costs(resp.body.get("usage", {}))
        return resp.body

    async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
        return await self._achat_completion(messages, timeout=timeout)

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
        resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
        collected_content = []
        usage = {}
        async for chunk in resp:
            content = chunk.body.get("result", "")
            usage = chunk.body.get("usage", {})
            log_llm_stream(content)
            collected_content.append(content)
        log_llm_stream("\n")

        self._update_costs(usage)
        full_content = "".join(collected_content)
        return full_content
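The credential-selection branch in `__init_qianfan` above (system-level `access_key`/`secret_key` preferred over application-level `api_key`/`secret_key`, each mapped to the environment variables the qianfan SDK reads) can be sketched as a plain function; the config here is a hypothetical dict rather than `LLMConfig`:

```python
def qianfan_env_vars(config: dict) -> dict:
    """Pick the env vars for the qianfan SDK, mirroring the priority above."""
    if config.get("access_key") and config.get("secret_key"):
        # System-level auth, recommended by the official docs.
        return {
            "QIANFAN_ACCESS_KEY": config["access_key"],
            "QIANFAN_SECRET_KEY": config["secret_key"],
        }
    if config.get("api_key") and config.get("secret_key"):
        # Application-level auth.
        return {"QIANFAN_AK": config["api_key"], "QIANFAN_SK": config["secret_key"]}
    raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
```

Note the original applies these with `os.environ.setdefault`, so credentials already present in the environment win over the config file.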
@@ -31,12 +31,18 @@ class SparkLLM(BaseLLM):
    def get_choice_text(self, rsp: dict) -> str:
        return rsp["payload"]["choices"]["text"][-1]["content"]

    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
        pass

    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
        # Not supported
        # logger.warning("This method cannot run asynchronously; calls through acompletion are not parallel.")
        w = GetMessageFromWeb(messages, self.config)
        return w.run()

    async def _achat_completion(self, messages: list[dict], timeout=3):
        pass

    async def acompletion(self, messages: list[dict], timeout=3):
        # Async is not supported
        w = GetMessageFromWeb(messages, self.config)
@@ -5,21 +5,12 @@
from enum import Enum
from typing import Optional

from requests import ConnectionError
from tenacity import (
    after_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)
from zhipuai.types.chat.chat_completion import Completion

from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
from metagpt.utils.cost_manager import CostManager


@@ -47,22 +38,13 @@ class ZhiPuAILLM(BaseLLM):
        assert self.config.api_key
        self.api_key = self.config.api_key
        self.model = self.config.model  # so far, it supports glm-3-turbo and glm-4
        self.pricing_plan = self.config.pricing_plan or self.model
        self.llm = ZhiPuModelAPI(api_key=self.api_key)

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
        return kwargs

    def _update_costs(self, usage: dict):
        """update each request's token cost"""
        if self.config.calc_usage:
            try:
                prompt_tokens = int(usage.get("prompt_tokens", 0))
                completion_tokens = int(usage.get("completion_tokens", 0))
                self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
            except Exception as e:
                logger.error(f"zhipuai updates costs failed! exp: {e}")

    def completion(self, messages: list[dict], timeout=3) -> dict:
        resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
        usage = resp.usage.model_dump()


@@ -96,17 +78,3 @@ class ZhiPuAILLM(BaseLLM):
        self._update_costs(usage)
        full_content = "".join(collected_content)
        return full_content

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
        """response in async with stream or non-stream mode"""
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)
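The `_update_costs` pattern shared by these providers (coerce `prompt_tokens`/`completion_tokens` from a usage dict, then accumulate cost from a per-model price table) can be reduced to a small pure function. The prices and the per-1K-token billing unit below are illustrative assumptions, not the real `CostManager` table:

```python
def request_cost(usage: dict, prices: dict, model: str) -> float:
    """Cost of one request from a usage dict (hypothetical prices, per 1K tokens)."""
    prompt_tokens = int(usage.get("prompt_tokens", 0))
    completion_tokens = int(usage.get("completion_tokens", 0))
    p = prices[model]
    return prompt_tokens / 1000 * p["prompt"] + completion_tokens / 1000 * p["completion"]
```

The `int(...)` coercion matters because some services return token counts as strings in the usage payload.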
@@ -1,6 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Build a symbols repository from source code.

This script is designed to create a symbols repository from the provided source code.

@Time    : 2023/11/17 17:58
@Author  : alexanderwu
@File    : repo_parser.py
@@ -15,15 +19,26 @@ from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator

from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
from metagpt.utils.common import any_to_str, aread
from metagpt.utils.common import any_to_str, aread, remove_white_spaces
from metagpt.utils.exceptions import handle_exception


class RepoFileInfo(BaseModel):
    """
    Repository data element that represents information about a file.

    Attributes:
        file (str): The name or path of the file.
        classes (List): A list of class names present in the file.
        functions (List): A list of function names present in the file.
        globals (List): A list of global variable names present in the file.
        page_info (List): A list of page-related information associated with the file.
    """

    file: str
    classes: List = Field(default_factory=list)
    functions: List = Field(default_factory=list)
@@ -32,6 +47,17 @@ class RepoFileInfo(BaseModel):


class CodeBlockInfo(BaseModel):
    """
    Repository data element representing information about a code block.

    Attributes:
        lineno (int): The starting line number of the code block.
        end_lineno (int): The ending line number of the code block.
        type_name (str): The type or category of the code block.
        tokens (List): A list of tokens present in the code block.
        properties (Dict): A dictionary containing additional properties associated with the code block.
    """

    lineno: int
    end_lineno: int
    type_name: str
@@ -39,31 +65,395 @@ class CodeBlockInfo(BaseModel):
    properties: Dict = Field(default_factory=dict)


class ClassInfo(BaseModel):
class DotClassAttribute(BaseModel):
    """
    Repository data element representing a class attribute in dot format.

    Attributes:
        name (str): The name of the class attribute.
        type_ (str): The type of the class attribute.
        default_ (str): The default value of the class attribute.
        description (str): A description of the class attribute.
        compositions (List[str]): A list of compositions associated with the class attribute.
    """

    name: str = ""
    type_: str = ""
    default_: str = ""
    description: str
    compositions: List[str] = Field(default_factory=list)

    @classmethod
    def parse(cls, v: str) -> "DotClassAttribute":
        """
        Parses dot format text and returns a DotClassAttribute object.

        Args:
            v (str): Dot format text to be parsed.

        Returns:
            DotClassAttribute: An instance of the DotClassAttribute class representing the parsed data.
        """
        val = ""
        meet_colon = False
        meet_equals = False
        for c in v:
            if c == ":":
                meet_colon = True
            elif c == "=":
                meet_equals = True
                if not meet_colon:
                    val += ":"
                    meet_colon = True
            val += c
        if not meet_colon:
            val += ":"
        if not meet_equals:
            val += "="

        cix = val.find(":")
        eix = val.rfind("=")
        name = val[0:cix].strip()
        type_ = val[cix + 1 : eix]
        default_ = val[eix + 1 :].strip()

        type_ = remove_white_spaces(type_)  # remove white space
        if type_ == "NoneType":
            type_ = ""
        if "Literal[" in type_:
            pre_l, literal, post_l = cls._split_literal(type_)
            composition_val = pre_l + "Literal" + post_l  # replace Literal[...] with Literal
            type_ = pre_l + literal + post_l
        else:
            type_ = re.sub(r"['\"]+", "", type_)  # remove '"
            composition_val = type_

        if default_ == "None":
            default_ = ""
        compositions = cls.parse_compositions(composition_val)
        return cls(name=name, type_=type_, default_=default_, description=v, compositions=compositions)

    @staticmethod
    def parse_compositions(types_part) -> List[str]:
        """
        Parses the type definition code block of source code and returns a list of compositions.

        Args:
            types_part: The type definition code block to be parsed.

        Returns:
            List[str]: A list of compositions extracted from the type definition code block.
        """
        if not types_part:
            return []
        modified_string = re.sub(r"[\[\],\(\)]", "|", types_part)
        types = modified_string.split("|")
        filters = {
            "str",
            "frozenset",
            "set",
            "int",
            "float",
            "complex",
            "bool",
            "dict",
            "list",
            "Union",
            "Dict",
            "Set",
            "Tuple",
            "NoneType",
            "None",
            "Any",
            "Optional",
            "Iterator",
            "Literal",
            "List",
        }
        result = set()
        for t in types:
            t = re.sub(r"['\"]+", "", t.strip())
            if t and t not in filters:
                result.add(t)
        return list(result)

    @staticmethod
    def _split_literal(v):
        """
        Parses the literal definition code block and returns three parts: pre-part, literal-part, and post-part.

        Args:
            v: The literal definition code block to be parsed.

        Returns:
            Tuple[str, str, str]: A tuple containing the pre-part, literal-part, and post-part of the code block.
        """
        tag = "Literal["
        bix = v.find(tag)
        eix = len(v) - 1
        counter = 1
        for i in range(bix + len(tag), len(v) - 1):
            c = v[i]
            if c == "[":
                counter += 1
                continue
            if c == "]":
                counter -= 1
                if counter > 0:
                    continue
                eix = i
                break
        pre_l = v[0:bix]
        post_l = v[eix + 1 :]
        pre_l = re.sub(r"['\"]", "", pre_l)  # remove '"
        pos_l = re.sub(r"['\"]", "", post_l)  # remove '"

        return pre_l, v[bix : eix + 1], pos_l

    @field_validator("compositions", mode="after")
    @classmethod
    def sort(cls, lst: List) -> List:
        """
        Auto-sorts a list attribute after making changes.

        Args:
            lst (List): The list attribute to be sorted.

        Returns:
            List: The sorted list.
        """
        lst.sort()
        return lst
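The colon/equals normalization in `DotClassAttribute.parse` above (append a missing `:` or `=`, then split into name, type, and default) can be reproduced standalone. This is a simplified sketch that keeps only the split logic and skips the `Literal[...]`, quote-stripping, and composition handling:

```python
import re


def split_dot_attribute(v: str) -> tuple[str, str, str]:
    """Split a dot-format attribute into (name, type, default), tolerating
    a missing ':' or '=' the way the original parser does."""
    val, meet_colon, meet_equals = "", False, False
    for c in v:
        if c == ":":
            meet_colon = True
        elif c == "=":
            meet_equals = True
            if not meet_colon:
                # '=' before any ':' means the type part is empty.
                val += ":"
                meet_colon = True
        val += c
    if not meet_colon:
        val += ":"
    if not meet_equals:
        val += "="
    cix, eix = val.find(":"), val.rfind("=")
    name = val[:cix].strip()
    type_ = re.sub(r"\s+", "", val[cix + 1 : eix])  # drop internal whitespace
    default_ = val[eix + 1 :].strip()
    if type_ == "NoneType":
        type_ = ""
    if default_ == "None":
        default_ = ""
    return name, type_, default_
```

Normalizing first and splitting second keeps the three cases (`name`, `name: type`, `name: type = default`) on a single code path.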
class DotClassInfo(BaseModel):
    """
    Repository data element representing information about a class in dot format.

    Attributes:
        name (str): The name of the class.
        package (Optional[str]): The package to which the class belongs (optional).
        attributes (Dict[str, DotClassAttribute]): A dictionary of attributes associated with the class.
        methods (Dict[str, DotClassMethod]): A dictionary of methods associated with the class.
        compositions (List[str]): A list of compositions associated with the class.
        aggregations (List[str]): A list of aggregations associated with the class.
    """

    name: str
    package: Optional[str] = None
    attributes: Dict[str, str] = Field(default_factory=dict)
    methods: Dict[str, str] = Field(default_factory=dict)
    attributes: Dict[str, DotClassAttribute] = Field(default_factory=dict)
    methods: Dict[str, DotClassMethod] = Field(default_factory=dict)
    compositions: List[str] = Field(default_factory=list)
    aggregations: List[str] = Field(default_factory=list)

    @field_validator("compositions", "aggregations", mode="after")
    @classmethod
    def sort(cls, lst: List) -> List:
        """
        Auto-sorts a list attribute after making changes.

        Args:
            lst (List): The list attribute to be sorted.

        Returns:
            List: The sorted list.
        """
        lst.sort()
        return lst
class ClassRelationship(BaseModel):
class DotClassRelationship(BaseModel):
    """
    Repository data element representing a relationship between two classes in dot format.

    Attributes:
        src (str): The source class of the relationship.
        dest (str): The destination class of the relationship.
        relationship (str): The type or nature of the relationship.
        label (Optional[str]): An optional label associated with the relationship.
    """

    src: str = ""
    dest: str = ""
    relationship: str = ""
    label: Optional[str] = None
class DotReturn(BaseModel):
    """
    Repository data element representing a function or method return type in dot format.

    Attributes:
        type_ (str): The type of the return.
        description (str): A description of the return type.
        compositions (List[str]): A list of compositions associated with the return type.
    """

    type_: str = ""
    description: str
    compositions: List[str] = Field(default_factory=list)

    @classmethod
    def parse(cls, v: str) -> "DotReturn" | None:
        """
        Parses the return type part of dot format text and returns a DotReturn object.

        Args:
            v (str): The dot format text containing the return type part to be parsed.

        Returns:
            DotReturn | None: An instance of the DotReturn class representing the parsed return type,
                or None if parsing fails.
        """
        if not v:
            return DotReturn(description=v)
        type_ = remove_white_spaces(v)
        compositions = DotClassAttribute.parse_compositions(type_)
        return cls(type_=type_, description=v, compositions=compositions)

    @field_validator("compositions", mode="after")
    @classmethod
    def sort(cls, lst: List) -> List:
        """
        Auto-sorts a list attribute after making changes.

        Args:
            lst (List): The list attribute to be sorted.

        Returns:
            List: The sorted list.
        """
        lst.sort()
        return lst
class DotClassMethod(BaseModel):
    name: str
    args: List[DotClassAttribute] = Field(default_factory=list)
    return_args: Optional[DotReturn] = None
    description: str
    aggregations: List[str] = Field(default_factory=list)

    @classmethod
    def parse(cls, v: str) -> "DotClassMethod":
        """
        Parses a dot format method text and returns a DotClassMethod object.

        Args:
            v (str): The dot format text containing method information to be parsed.

        Returns:
            DotClassMethod: An instance of the DotClassMethod class representing the parsed method.
        """
        bix = v.find("(")
        eix = v.rfind(")")
        rix = v.rfind(":")
        if rix < 0 or rix < eix:
            rix = eix
        name_part = v[0:bix].strip()
        args_part = v[bix + 1 : eix].strip()
        return_args_part = v[rix + 1 :].strip()

        name = cls._parse_name(name_part)
        args = cls._parse_args(args_part)
        return_args = DotReturn.parse(return_args_part)
        aggregations = set()
        for i in args:
            aggregations.update(set(i.compositions))
        aggregations.update(set(return_args.compositions))

        return cls(name=name, args=args, description=v, return_args=return_args, aggregations=list(aggregations))

    @staticmethod
    def _parse_name(v: str) -> str:
        """
        Parses the dot format method name part and returns the method name.

        Args:
            v (str): The dot format text containing the method name part to be parsed.

        Returns:
            str: The parsed method name.
        """
        tags = [">", "</"]
        if tags[0] in v:
            bix = v.find(tags[0]) + len(tags[0])
            eix = v.rfind(tags[1])
            return v[bix:eix].strip()
        return v.strip()

    @staticmethod
    def _parse_args(v: str) -> List[DotClassAttribute]:
        """
        Parses the dot format method arguments part and returns the parsed arguments.

        Args:
            v (str): The dot format text containing the arguments part to be parsed.

        Returns:
            List[DotClassAttribute]: The parsed method arguments.
        """
        if not v:
            return []
        parts = []
        bix = 0
        counter = 0
        for i in range(0, len(v)):
            c = v[i]
            if c == "[":
                counter += 1
                continue
            elif c == "]":
                counter -= 1
                continue
            elif c == "," and counter == 0:
                parts.append(v[bix:i].strip())
                bix = i + 1
        parts.append(v[bix:].strip())

        attrs = []
        for p in parts:
            if p:
                attr = DotClassAttribute.parse(p)
                attrs.append(attr)
        return attrs
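`_parse_args` above splits the argument list only on commas that sit outside square brackets, so a generic like `Dict[str, int]` stays intact. The bracket-counting core can be isolated as a sketch (not the exact method, which also hands each part to `DotClassAttribute.parse`):

```python
def split_top_level(v: str) -> list[str]:
    """Split on commas at bracket depth 0, keeping generics intact."""
    parts: list[str] = []
    bix = 0  # start of the current part
    depth = 0  # current [ ] nesting depth
    for i, c in enumerate(v):
        if c == "[":
            depth += 1
        elif c == "]":
            depth -= 1
        elif c == "," and depth == 0:
            parts.append(v[bix:i].strip())
            bix = i + 1
    parts.append(v[bix:].strip())
    return [p for p in parts if p]
```

A plain `v.split(",")` would wrongly break `Dict[str, int]` into two halves, which is exactly the failure the depth counter prevents.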
class RepoParser(BaseModel):
    """
    Tool to build a symbols repository from a project directory.

    Attributes:
        base_directory (Path): The base directory of the project.
    """

    base_directory: Path = Field(default=None)

    @classmethod
    @handle_exception(exception_type=Exception, default_return=[])
    def _parse_file(cls, file_path: Path) -> list:
        """Parse a Python file in the repository."""
        """
        Parses a Python file in the repository.

        Args:
            file_path (Path): The path to the Python file to be parsed.

        Returns:
            list: A list containing the parsed symbols from the file.
        """
        return ast.parse(file_path.read_text()).body

    def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
        """Extract class, function, and global variable information from the AST."""
        """
        Extracts class, function, and global variable information from the Abstract Syntax Tree (AST).

        Args:
            tree: The Abstract Syntax Tree (AST) of the Python file.
            file_path: The path to the Python file.

        Returns:
            RepoFileInfo: A RepoFileInfo object containing the extracted information.
        """
        file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
        for node in tree:
            info = RepoParser.node_to_str(node)
@ -81,11 +471,17 @@ class RepoParser(BaseModel):
|
|||
return file_info
|
||||
|
||||
def generate_symbols(self) -> List[RepoFileInfo]:
|
||||
"""
|
||||
Builds a symbol repository from '.py' and '.js' files in the project directory.
|
||||
|
||||
Returns:
|
||||
List[RepoFileInfo]: A list of RepoFileInfo objects containing the extracted information.
|
||||
"""
|
||||
files_classes = []
|
||||
directory = self.base_directory
|
||||
|
||||
matching_files = []
|
||||
extensions = ["*.py", "*.js"]
|
||||
extensions = ["*.py"]
|
||||
for ext in extensions:
|
||||
matching_files += directory.rglob(ext)
|
||||
for path in matching_files:
|
||||
|
|
@ -95,19 +491,38 @@ class RepoParser(BaseModel):
|
|||
|
||||
return files_classes
|
||||
|
||||
def generate_json_structure(self, output_path):
|
||||
"""Generate a JSON file documenting the repository structure."""
|
||||
def generate_json_structure(self, output_path: Path):
|
||||
"""
|
||||
Generates a JSON file documenting the repository structure.
|
||||
|
||||
Args:
|
||||
output_path (Path): The path to the JSON file to be generated.
|
||||
"""
|
||||
files_classes = [i.model_dump() for i in self.generate_symbols()]
|
||||
output_path.write_text(json.dumps(files_classes, indent=4))
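`generate_json_structure` is a thin wrapper around `model_dump` plus `json.dumps`; sketched here with plain dicts standing in for the `RepoFileInfo.model_dump()` output (the file path and class names are made-up examples):

```python
import json
import tempfile
from pathlib import Path

# plain dicts stand in for RepoFileInfo.model_dump() results
files_classes = [{"file": "metagpt/repo_parser.py", "classes": ["RepoParser"]}]

output_path = Path(tempfile.mkdtemp()) / "structure.json"
output_path.write_text(json.dumps(files_classes, indent=4))

print(json.loads(output_path.read_text())[0]["file"])
```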

-    def generate_dataframe_structure(self, output_path):
-        """Generate a DataFrame documenting the repository structure and save as CSV."""
+    def generate_dataframe_structure(self, output_path: Path):
+        """
+        Generates a DataFrame documenting the repository structure and saves it as a CSV file.
+
+        Args:
+            output_path (Path): The path to the CSV file to be generated.
+        """
        files_classes = [i.model_dump() for i in self.generate_symbols()]
        df = pd.DataFrame(files_classes)
        df.to_csv(output_path, index=False)

-    def generate_structure(self, output_path=None, mode="json") -> Path:
-        """Generate the structure of the repository as a specified format."""
+    def generate_structure(self, output_path: str | Path = None, mode="json") -> Path:
+        """
+        Generates the structure of the repository in a specified format.
+
+        Args:
+            output_path (str | Path): The path to the output file or directory. Default is None.
+            mode (str): The output format mode. Options: "json" (default), "csv", etc.
+
+        Returns:
+            Path: The path to the generated output file or directory.
+        """
        output_file = self.base_directory / f"{self.base_directory.name}-structure.{mode}"
        output_path = Path(output_path) if output_path else output_file

@@ -119,6 +534,16 @@ class RepoParser(BaseModel):

    @staticmethod
    def node_to_str(node) -> CodeBlockInfo | None:
+        """
+        Parses and converts an Abstract Syntax Tree (AST) node to a CodeBlockInfo object.
+
+        Args:
+            node: The AST node to be converted.
+
+        Returns:
+            CodeBlockInfo | None: A CodeBlockInfo object representing the parsed AST node,
+                or None if the conversion fails.
+        """
        if isinstance(node, ast.Try):
            return None
        if any_to_str(node) == any_to_str(ast.Expr):

@@ -159,9 +584,19 @@ class RepoParser(BaseModel):

    @staticmethod
    def _parse_expr(node) -> List:
+        """
+        Parses an expression Abstract Syntax Tree (AST) node.
+
+        Args:
+            node: The AST node representing an expression.
+
+        Returns:
+            List: A list containing the parsed information from the expression node.
+        """
        funcs = {
            any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
            any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
            any_to_str(ast.Tuple): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
        }
        func = funcs.get(any_to_str(node.value))
        if func:

@@ -170,12 +605,30 @@ class RepoParser(BaseModel):

    @staticmethod
    def _parse_name(n):
+        """
+        Gets the 'name' value of an Abstract Syntax Tree (AST) node.
+
+        Args:
+            n: The AST node.
+
+        Returns:
+            The 'name' value of the AST node.
+        """
        if n.asname:
            return f"{n.name} as {n.asname}"
        return n.name

    @staticmethod
    def _parse_if(n):
+        """
+        Parses an 'if' statement Abstract Syntax Tree (AST) node.
+
+        Args:
+            n: The AST node representing an 'if' statement.
+
+        Returns:
+            None or Parsed information from the 'if' statement node.
+        """
        tokens = []
        try:
            if isinstance(n.test, ast.BoolOp):

@@ -187,10 +640,14 @@ class RepoParser(BaseModel):
            v = RepoParser._parse_variable(n.test.left)
            if v:
                tokens.append(v)
-            for item in n.test.comparators:
-                v = RepoParser._parse_variable(item)
-                if v:
-                    tokens.append(v)
+            if isinstance(n.test, ast.Name):
+                v = RepoParser._parse_variable(n.test)
+                tokens.append(v)
+            if hasattr(n.test, "comparators"):
+                for item in n.test.comparators:
+                    v = RepoParser._parse_variable(item)
+                    if v:
+                        tokens.append(v)
            return tokens
        except Exception as e:
            logger.warning(f"Unsupported if: {n}, err:{e}")
@@ -198,6 +655,15 @@ class RepoParser(BaseModel):

    @staticmethod
    def _parse_if_compare(n):
+        """
+        Parses an 'if' condition Abstract Syntax Tree (AST) node.
+
+        Args:
+            n: The AST node representing an 'if' condition.
+
+        Returns:
+            None or Parsed information from the 'if' condition node.
+        """
        if hasattr(n, "left"):
            return RepoParser._parse_variable(n.left)
        else:

@@ -205,6 +671,15 @@ class RepoParser(BaseModel):

    @staticmethod
    def _parse_variable(node):
+        """
+        Parses a variable Abstract Syntax Tree (AST) node.
+
+        Args:
+            node: The AST node representing a variable.
+
+        Returns:
+            None or Parsed information from the variable node.
+        """
        try:
            funcs = {
                any_to_str(ast.Constant): lambda x: x.value,

@@ -213,7 +688,7 @@ class RepoParser(BaseModel):
                if hasattr(x.value, "id")
                else f"{x.attr}",
                any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
-                any_to_str(ast.Tuple): lambda x: "",
+                any_to_str(ast.Tuple): lambda x: [d.value for d in x.dims],
            }
            func = funcs.get(any_to_str(node))
            if not func:

@@ -224,22 +699,42 @@ class RepoParser(BaseModel):

    @staticmethod
    def _parse_assign(node):
+        """
+        Parses an assignment Abstract Syntax Tree (AST) node.
+
+        Args:
+            node: The AST node representing an assignment.
+
+        Returns:
+            None or Parsed information from the assignment node.
+        """
        return [RepoParser._parse_variable(t) for t in node.targets]

    async def rebuild_class_views(self, path: str | Path = None):
+        """
+        Executes `pylint` to reconstruct the dot format class view repository file.
+
+        Args:
+            path (str | Path): The path to the target directory or file. Default is None.
+        """
        if not path:
            path = self.base_directory
        path = Path(path)
        if not path.exists():
            return
+        init_file = path / "__init__.py"
+        if not init_file.exists():
+            raise ValueError("Failed to import module __init__ with error:No module named __init__.")
        command = f"pyreverse {str(path)} -o dot"
-        result = subprocess.run(command, shell=True, check=True, cwd=str(path))
+        output_dir = path / "__dot__"
+        output_dir.mkdir(parents=True, exist_ok=True)
+        result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir))
        if result.returncode != 0:
            raise ValueError(f"{result}")
-        class_view_pathname = path / "classes.dot"
+        class_view_pathname = output_dir / "classes.dot"
        class_views = await self._parse_classes(class_view_pathname)
        relationship_views = await self._parse_class_relationships(class_view_pathname)
-        packages_pathname = path / "packages.dot"
+        packages_pathname = output_dir / "packages.dot"
        class_views, relationship_views, package_root = RepoParser._repair_namespaces(
            class_views=class_views, relationship_views=relationship_views, path=path
        )
@@ -247,7 +742,17 @@ class RepoParser(BaseModel):
        packages_pathname.unlink(missing_ok=True)
        return class_views, relationship_views, package_root

-    async def _parse_classes(self, class_view_pathname):
+    @staticmethod
+    async def _parse_classes(class_view_pathname: Path) -> List[DotClassInfo]:
+        """
+        Parses a dot format class view repository file.
+
+        Args:
+            class_view_pathname (Path): The path to the dot format class view repository file.
+
+        Returns:
+            List[DotClassInfo]: A list of DotClassInfo objects representing the parsed classes.
+        """
        class_views = []
        if not class_view_pathname.exists():
            return class_views

@@ -258,22 +763,38 @@ class RepoParser(BaseModel):
            if not package_name:
                continue
            class_name, members, functions = re.split(r"(?<!\\)\|", info)
-            class_info = ClassInfo(name=class_name)
+            class_info = DotClassInfo(name=class_name)
            class_info.package = package_name
            for m in members.split("\n"):
                if not m:
                    continue
-                member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
-                class_info.attributes[member_name] = m
+                attr = DotClassAttribute.parse(m)
+                class_info.attributes[attr.name] = attr
+                for i in attr.compositions:
+                    if i not in class_info.compositions:
+                        class_info.compositions.append(i)
            for f in functions.split("\n"):
                if not f:
                    continue
-                function_name, _ = f.split("(", 1)
-                class_info.methods[function_name] = f
+                method = DotClassMethod.parse(f)
+                class_info.methods[method.name] = method
+                for i in method.aggregations:
+                    if i not in class_info.compositions and i not in class_info.aggregations:
+                        class_info.aggregations.append(i)
            class_views.append(class_info)
        return class_views

-    async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
+    @staticmethod
+    async def _parse_class_relationships(class_view_pathname: Path) -> List[DotClassRelationship]:
+        """
+        Parses a dot format class view repository file.
+
+        Args:
+            class_view_pathname (Path): The path to the dot format class view repository file.
+
+        Returns:
+            List[DotClassRelationship]: A list of DotClassRelationship objects representing the parsed
+                class relationships.
+        """
        relationship_views = []
        if not class_view_pathname.exists():
            return relationship_views

@@ -287,7 +808,16 @@ class RepoParser(BaseModel):
        return relationship_views

    @staticmethod
-    def _split_class_line(line):
+    def _split_class_line(line: str) -> (str, str):
+        """
+        Parses a dot format line about class info and returns the class name part and class members part.
+
+        Args:
+            line (str): The dot format line containing class information.
+
+        Returns:
+            Tuple[str, str]: A tuple containing the class name part and class members part.
+        """
        part_splitor = '" ['
        if part_splitor not in line:
            return None, None

@@ -305,14 +835,25 @@ class RepoParser(BaseModel):
        return class_name, info

    @staticmethod
-    def _split_relationship_line(line):
+    def _split_relationship_line(line: str) -> DotClassRelationship:
+        """
+        Parses a dot format line about the relationship of two classes and returns 'Generalize', 'Composite',
+        or 'Aggregate'.
+
+        Args:
+            line (str): The dot format line containing relationship information.
+
+        Returns:
+            DotClassRelationship: The object of relationship representing either 'Generalize', 'Composite',
+                or 'Aggregate' relationship.
+        """
        splitters = [" -> ", " [", "];"]
        idxs = []
        for tag in splitters:
            if tag not in line:
                return None
            idxs.append(line.find(tag))
-        ret = ClassRelationship()
+        ret = DotClassRelationship()
        ret.src = line[0 : idxs[0]].strip('"')
        ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
        properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
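The slicing above carves a pyreverse-style dot edge line into source, destination, and properties using the three fixed delimiters. A standalone sketch of that step (the `split_relationship` name is hypothetical, and it returns a tuple instead of a `DotClassRelationship`):

```python
def split_relationship(line: str):
    # Locate the three fixed delimiters, then slice src / dest / properties out.
    splitters = [" -> ", " [", "];"]
    idxs = []
    for tag in splitters:
        if tag not in line:
            return None
        idxs.append(line.find(tag))
    src = line[: idxs[0]].strip('"')
    dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
    properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
    return src, dest, properties


# a pyreverse-style edge line
print(split_relationship('"A" -> "B" [arrowhead="diamond"];'))
```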

@@ -330,7 +871,16 @@ class RepoParser(BaseModel):
        return ret

    @staticmethod
-    def _get_label(line):
+    def _get_label(line: str) -> str:
+        """
+        Parses a dot format line and returns the label information.
+
+        Args:
+            line (str): The dot format line containing label information.
+
+        Returns:
+            str: The label information parsed from the line.
+        """
        tag = 'label="'
        if tag not in line:
            return ""

@@ -340,6 +890,15 @@ class RepoParser(BaseModel):

    @staticmethod
    def _create_path_mapping(path: str | Path) -> Dict[str, str]:
+        """
+        Creates a mapping table between source code files' paths and module names.
+
+        Args:
+            path (str | Path): The path to the source code files or directory.
+
+        Returns:
+            Dict[str, str]: A dictionary mapping source code file paths to their corresponding module names.
+        """
        mappings = {
            str(path).replace("/", "."): str(path),
        }

@@ -363,8 +922,21 @@ class RepoParser(BaseModel):

    @staticmethod
    def _repair_namespaces(
-        class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
-    ) -> (List[ClassInfo], List[ClassRelationship], str):
+        class_views: List[DotClassInfo], relationship_views: List[DotClassRelationship], path: str | Path
+    ) -> (List[DotClassInfo], List[DotClassRelationship], str):
+        """
+        Augments namespaces to the path-prefixed classes and relationships.
+
+        Args:
+            class_views (List[DotClassInfo]): List of DotClassInfo objects representing class views.
+            relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects representing
+                relationships.
+            path (str | Path): The path to the source code files or directory.
+
+        Returns:
+            Tuple[List[DotClassInfo], List[DotClassRelationship], str]: A tuple containing the augmented class views,
+                relationships, and the root path of the package.
+        """
        if not class_views:
            return [], [], ""
        c = class_views[0]

@@ -383,28 +955,49 @@ class RepoParser(BaseModel):

        for c in class_views:
            c.package = RepoParser._repair_ns(c.package, new_mappings)
-        for i in range(len(relationship_views)):
-            v = relationship_views[i]
+        for _, v in enumerate(relationship_views):
            v.src = RepoParser._repair_ns(v.src, new_mappings)
            v.dest = RepoParser._repair_ns(v.dest, new_mappings)
-            relationship_views[i] = v
-        return class_views, relationship_views, root_path
+        return class_views, relationship_views, str(path)[: len(root_path)]

    @staticmethod
-    def _repair_ns(package, mappings):
+    def _repair_ns(package: str, mappings: Dict[str, str]) -> str:
+        """
+        Replaces the package-prefix with the namespace-prefix.
+
+        Args:
+            package (str): The package to be repaired.
+            mappings (Dict[str, str]): A dictionary mapping source code file paths to their corresponding packages.
+
+        Returns:
+            str: The repaired namespace.
+        """
        file_ns = package
        ix = 0
        while file_ns != "":
            if file_ns not in mappings:
                ix = file_ns.rfind(".")
                file_ns = file_ns[0:ix]
                continue
            break
        if file_ns == "":
            return ""
        internal_ns = package[ix + 1 :]
        ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
        return ns
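The repair loop above trims the dotted package name from the right until a known file mapping matches, then rewrites the remaining inner path as a `:`-separated namespace. A standalone sketch with a hypothetical mapping entry:

```python
def repair_ns(package: str, mappings: dict) -> str:
    # Trim the dotted package from the right until a mapped file prefix matches,
    # then rewrite the remainder as a ':'-separated namespace.
    file_ns, ix = package, 0
    while file_ns != "":
        if file_ns not in mappings:
            ix = file_ns.rfind(".")
            file_ns = file_ns[0:ix]
            continue
        break
    if file_ns == "":
        return ""
    internal_ns = package[ix + 1 :]
    return mappings[file_ns] + ":" + internal_ns.replace(".", ":")


mappings = {"metagpt.repo_parser": "metagpt/repo_parser.py"}  # hypothetical entry
print(repair_ns("metagpt.repo_parser.RepoParser", mappings))
```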

    @staticmethod
-    def _find_root(full_key, package) -> str:
+    def _find_root(full_key: str, package: str) -> str:
+        """
+        Returns the package root path based on the key, which is the full path, and the package information.
+
+        Args:
+            full_key (str): The full key representing the full path.
+            package (str): The package information.
+
+        Returns:
+            str: The package root path.
+        """
        left = full_key
        while left != "":
            if left in package:

@@ -417,5 +1010,14 @@ class RepoParser(BaseModel):
        return "." + full_key[0:ix]


-def is_func(node):
+def is_func(node) -> bool:
+    """
+    Returns True if the given node represents a function.
+
+    Args:
+        node: The Abstract Syntax Tree (AST) node.
+
+    Returns:
+        bool: True if the node represents a function, False otherwise.
+    """
    return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
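`is_func` treats both synchronous and async defs as functions; a quick check against the stdlib `ast` module:

```python
import ast


def is_func(node) -> bool:
    # Same predicate as in repo_parser: plain and async defs both count.
    return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))


nodes = ast.parse("def f(): pass\nasync def g(): pass\nx = 1\n").body
print([is_func(n) for n in nodes])  # → [True, True, False]
```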

@@ -65,7 +65,7 @@ class Assistant(Role):
        prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
        prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
        prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
-        rsp = await self.llm.aask(prompt, ["You are an action classifier"])
+        rsp = await self.llm.aask(prompt, ["You are an action classifier"], stream=False)
        logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
        return await self._plan(rsp, last_talk=last_talk)


@@ -5,9 +5,9 @@ from typing import Literal, Union

from pydantic import Field, model_validator

-from metagpt.actions.mi.ask_review import ReviewConst
-from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
-from metagpt.actions.mi.write_analysis_code import CheckData, WriteCodeWithTools
+from metagpt.actions.di.ask_review import ReviewConst
+from metagpt.actions.di.execute_nb_code import ExecuteNbCode
+from metagpt.actions.di.write_analysis_code import CheckData, WriteCodeWithTools
from metagpt.logs import logger
from metagpt.prompts.mi.write_analysis_code import DATA_INFO
from metagpt.roles import Role

@@ -32,9 +32,9 @@ Output a json following the format:
"""


-class Interpreter(Role):
-    name: str = "Ivy"
-    profile: str = "Interpreter"
+class DataInterpreter(Role):
+    name: str = "David"
+    profile: str = "DataInterpreter"
    auto_run: bool = True
    use_plan: bool = True
    use_reflection: bool = False
@@ -20,7 +20,6 @@
from __future__ import annotations

import json
-import os
from collections import defaultdict
from pathlib import Path
from typing import Set

@@ -32,7 +31,6 @@ from metagpt.actions.summarize_code import SummarizeCode
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
from metagpt.const import (
    CODE_PLAN_AND_CHANGE_FILE_REPO,
-    CODE_PLAN_AND_CHANGE_FILENAME,
    REQUIREMENT_FILENAME,
    SYSTEM_DESIGN_FILE_REPO,
    TASK_FILE_REPO,

@@ -119,10 +117,10 @@ class Engineer(Role):

        dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path}
        if self.config.inc:
-            dependencies.add(os.path.join(CODE_PLAN_AND_CHANGE_FILE_REPO, CODE_PLAN_AND_CHANGE_FILENAME))
+            dependencies.add(coding_context.code_plan_and_change_doc.root_relative_path)
        await self.project_repo.srcs.save(
            filename=coding_context.filename,
-            dependencies=dependencies,
+            dependencies=list(dependencies),
            content=coding_context.code_doc.content,
        )
        msg = Message(

@@ -206,7 +204,6 @@ class Engineer(Role):

    async def _act_code_plan_and_change(self):
        """Write code plan and change that guides subsequent WriteCode and WriteCodeReview"""
-        logger.info("Writing code plan and change..")
        node = await self.rc.todo.run()
        code_plan_and_change = node.instruct_content.model_dump_json()
        dependencies = {

@@ -215,11 +212,12 @@ class Engineer(Role):
            self.rc.todo.i_context.design_filename,
            self.rc.todo.i_context.task_filename,
        }
+        code_plan_and_change_filepath = Path(self.rc.todo.i_context.design_filename)
        await self.project_repo.docs.code_plan_and_change.save(
-            filename=self.rc.todo.i_context.filename, content=code_plan_and_change, dependencies=dependencies
+            filename=code_plan_and_change_filepath.name, content=code_plan_and_change, dependencies=dependencies
        )
        await self.project_repo.resources.code_plan_and_change.save(
-            filename=Path(self.rc.todo.i_context.filename).with_suffix(".md").name,
+            filename=code_plan_and_change_filepath.with_suffix(".md").name,
            content=node.content,
            dependencies=dependencies,
        )

@@ -269,15 +267,24 @@ class Engineer(Role):
        dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
        task_doc = None
        design_doc = None
+        code_plan_and_change_doc = None
        for i in dependencies:
            if str(i.parent) == TASK_FILE_REPO:
                task_doc = await self.project_repo.docs.task.get(i.name)
            elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
                design_doc = await self.project_repo.docs.system_design.get(i.name)
+            elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO:
+                code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name)
        if not task_doc or not design_doc:
-            logger.error(f'Detected source code "(unknown)" from an unknown origin.')
-        context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
+            raise ValueError(f'Detected source code "(unknown)" from an unknown origin.')
+        context = CodingContext(
+            filename=filename,
+            design_doc=design_doc,
+            task_doc=task_doc,
+            code_doc=old_code_doc,
+            code_plan_and_change_doc=code_plan_and_change_doc,
+        )
        return context

    async def _new_coding_doc(self, filename, dependency):

@@ -296,6 +303,7 @@ class Engineer(Role):
        for filename in changed_task_files:
            design_doc = await self.project_repo.docs.system_design.get(filename)
            task_doc = await self.project_repo.docs.task.get(filename)
+            code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename)
            task_list = self._parse_tasks(task_doc)
            for task_filename in task_list:
                old_code_doc = await self.project_repo.srcs.get(task_filename)

@@ -303,9 +311,18 @@ class Engineer(Role):
                old_code_doc = Document(
                    root_path=str(self.project_repo.src_relative_path), filename=task_filename, content=""
                )
-                context = CodingContext(
-                    filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
-                )
+                if not code_plan_and_change_doc:
+                    context = CodingContext(
+                        filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
+                    )
+                else:
+                    context = CodingContext(
+                        filename=task_filename,
+                        design_doc=design_doc,
+                        task_doc=task_doc,
+                        code_doc=old_code_doc,
+                        code_plan_and_change_doc=code_plan_and_change_doc,
+                    )
                coding_doc = Document(
                    root_path=str(self.project_repo.src_relative_path),
                    filename=task_filename,

@@ -342,9 +359,17 @@ class Engineer(Role):
                summarizations[ctx].append(filename)
        for ctx, filenames in summarizations.items():
            ctx.codes_filenames = filenames
-            self.summarize_todos.append(SummarizeCode(i_context=ctx, context=self.context, llm=self.llm))
+            new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm)
+            for i, act in enumerate(self.summarize_todos):
+                if act.i_context.task_filename == new_summarize.i_context.task_filename:
+                    self.summarize_todos[i] = new_summarize
+                    new_summarize = None
+                    break
+            if new_summarize:
+                self.summarize_todos.append(new_summarize)
+        if self.summarize_todos:
            self.set_todo(self.summarize_todos[0])
            self.summarize_todos.pop(0)

    async def _new_code_plan_and_change_action(self):
        """Create a WriteCodePlanAndChange action for subsequent to-do actions."""
@@ -281,7 +281,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
            i = action
            self._init_action(i)
            self.actions.append(i)
-            self.states.append(f"{len(self.actions)}. {action}")
+            self.states.append(f"{len(self.actions) - 1}. {action}")

    def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True):
        """Set strategy of the Role reacting to observed Message. Variation lies in how
@@ -37,7 +37,6 @@ from pydantic import (
)

from metagpt.const import (
-    CODE_PLAN_AND_CHANGE_FILENAME,
    MESSAGE_ROUTE_CAUSE_BY,
    MESSAGE_ROUTE_FROM,
    MESSAGE_ROUTE_TO,

@@ -47,6 +46,7 @@ from metagpt.const import (
    TASK_FILE_REPO,
)
from metagpt.logs import logger
+from metagpt.repo_parser import DotClassInfo
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import (

@@ -613,6 +613,7 @@ class CodingContext(BaseContext):
    design_doc: Optional[Document] = None
    task_doc: Optional[Document] = None
    code_doc: Optional[Document] = None
+    code_plan_and_change_doc: Optional[Document] = None


class TestingContext(BaseContext):

@@ -667,7 +668,6 @@ class BugFixContext(BaseContext):


class CodePlanAndChangeContext(BaseModel):
-    filename: str = CODE_PLAN_AND_CHANGE_FILENAME
    requirement: str = ""
    prd_filename: str = ""
    design_filename: str = ""
@@ -691,54 +691,64 @@ class CodePlanAndChangeContext(BaseModel):


# mermaid class view
-class ClassMeta(BaseModel):
+class UMLClassMeta(BaseModel):
    name: str = ""
    abstraction: bool = False
    static: bool = False
    visibility: str = ""

    @staticmethod
    def name_to_visibility(name: str) -> str:
        if name == "__init__":
            return "+"
        if name.startswith("__"):
            return "-"
        elif name.startswith("_"):
            return "#"
        return "+"
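The visibility mapping follows Python naming conventions: dunder-prefixed names map to private (`-`), single-underscore names to protected (`#`), and everything else, including `__init__`, to public (`+`):

```python
def name_to_visibility(name: str) -> str:
    # Same mapping as UMLClassMeta.name_to_visibility above.
    if name == "__init__":
        return "+"
    if name.startswith("__"):
        return "-"
    elif name.startswith("_"):
        return "#"
    return "+"


print([name_to_visibility(n) for n in ["__init__", "__secret", "_helper", "run"]])
# → ['+', '-', '#', '+']
```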

-class ClassAttribute(ClassMeta):
+class UMLClassAttribute(UMLClassMeta):
    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        content = "".join(["\t" for i in range(align)]) + self.visibility
        if self.value_type:
-            content += self.value_type + " "
-        content += self.name
+            content += self.value_type.replace(" ", "") + " "
+        name = self.name.split(":", 1)[1] if ":" in self.name else self.name
+        content += name
        if self.default_value:
            content += "="
            if self.value_type not in ["str", "string", "String"]:
                content += self.default_value
            else:
                content += '"' + self.default_value.replace('"', "") + '"'
-        if self.abstraction:
-            content += "*"
-        if self.static:
-            content += "$"
+        # if self.abstraction:
+        #     content += "*"
+        # if self.static:
+        #     content += "$"
        return content


-class ClassMethod(ClassMeta):
-    args: List[ClassAttribute] = Field(default_factory=list)
+class UMLClassMethod(UMLClassMeta):
+    args: List[UMLClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        content = "".join(["\t" for i in range(align)]) + self.visibility
-        content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
+        name = self.name.split(":", 1)[1] if ":" in self.name else self.name
+        content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
        if self.return_type:
-            content += ":" + self.return_type
-        if self.abstraction:
-            content += "*"
-        if self.static:
-            content += "$"
+            content += " " + self.return_type.replace(" ", "")
+        # if self.abstraction:
+        #     content += "*"
+        # if self.static:
+        #     content += "$"
        return content


-class ClassView(ClassMeta):
-    attributes: List[ClassAttribute] = Field(default_factory=list)
-    methods: List[ClassMethod] = Field(default_factory=list)
+class UMLClassView(UMLClassMeta):
+    attributes: List[UMLClassAttribute] = Field(default_factory=list)
+    methods: List[UMLClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"

@@ -748,3 +758,21 @@ class ClassView(ClassMeta):
            content += v.get_mermaid(align=align + 1) + "\n"
        content += "".join(["\t" for i in range(align)]) + "}\n"
        return content
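The mermaid rendering above wraps each member line in a tab-indented `class Name{ … }` block. A simplified sketch with plain strings standing in for the `UMLClassAttribute`/`UMLClassMethod` objects (names and members are made-up examples):

```python
def class_mermaid(name: str, members: list, align: int = 1) -> str:
    # Simplified version of UMLClassView.get_mermaid: members are pre-rendered strings.
    pad = "\t" * align
    content = pad + "class " + name + "{\n"
    for m in members:
        content += pad + "\t" + m + "\n"
    content += pad + "}\n"
    return content


print(class_mermaid("RepoParser", ["+Path base_directory", "+generate_symbols() List"]))
```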
+
+    @classmethod
+    def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
+        visibility = UMLClassView.name_to_visibility(dot_class_info.name)
+        class_view = cls(name=dot_class_info.name, visibility=visibility)
+        for i in dot_class_info.attributes.values():
+            visibility = UMLClassAttribute.name_to_visibility(i.name)
+            attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
+            class_view.attributes.append(attr)
+        for i in dot_class_info.methods.values():
+            visibility = UMLClassMethod.name_to_visibility(i.name)
+            method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
+            for j in i.args:
+                arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
+                method.args.append(arg)
+            method.return_type = i.return_args.type_
+            class_view.methods.append(method)
+        return class_view

@@ -2,14 +2,11 @@
# -*- coding: utf-8 -*-

import asyncio
import shutil
from pathlib import Path

import typer

-from metagpt.config2 import config
-from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
-from metagpt.context import Context
+from metagpt.const import CONFIG_ROOT
from metagpt.utils.project_repo import ProjectRepo

app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)

@@ -30,6 +27,8 @@ def generate_repo(
    recover_path=None,
) -> ProjectRepo:
    """Run the startup logic. Can be called from CLI or other Python scripts."""
+    from metagpt.config2 import config
+    from metagpt.context import Context
    from metagpt.roles import (
        Architect,
        Engineer,

@@ -122,7 +121,17 @@ def startup(
    )


-def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
+DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
+# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
+llm:
+  api_type: "openai"  # or azure / ollama / open_llm etc. Check LLMType for more options
+  model: "gpt-4-turbo-preview"  # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
+  base_url: "https://api.openai.com/v1"  # or forward url / other llm url
+  api_key: "YOUR_API_KEY"
+"""
+
+
+def copy_config_to():
    """Initialize the configuration file for MetaGPT."""
    target_path = CONFIG_ROOT / "config2.yaml"

@@ -136,7 +145,7 @@ def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
    print(f"Existing configuration file backed up at {backup_path}")

    # Copy the file
-    shutil.copy(str(config_path), target_path)
+    target_path.write_text(DEFAULT_CONFIG, encoding="utf-8")
    print(f"Configuration file initialized at {target_path}")
@@ -4,8 +4,8 @@ import json
 
 from pydantic import BaseModel, Field
 
-from metagpt.actions.mi.ask_review import AskReview, ReviewConst
-from metagpt.actions.mi.write_plan import (
+from metagpt.actions.di.ask_review import AskReview, ReviewConst
+from metagpt.actions.di.write_plan import (
     WritePlan,
     precheck_update_plan_from_rsp,
     update_plan_from_rsp,
|
|||
|
|
@ -49,8 +49,8 @@ class TOTSolver(BaseSolver):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class InterpreterSolver(BaseSolver):
|
||||
"""InterpreterSolver: Write&Run code in the graph"""
|
||||
class DataInterpreterSolver(BaseSolver):
|
||||
"""DataInterpreterSolver: Write&Run code in the graph"""
|
||||
|
||||
async def solve(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -23,10 +23,10 @@ import platform
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
from typing import Any, Callable, List, Literal, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
|
|
@@ -423,23 +423,109 @@ def is_send_to(message: "Message", addresses: set):
 def any_to_name(val):
     """
     Convert a value to its name by extracting the last part of the dotted path.
 
     :param val: The value to convert.
 
     :return: The name of the value.
     """
     return any_to_str(val).split(".")[-1]
 
 
-def concat_namespace(*args) -> str:
-    return ":".join(str(value) for value in args)
+def concat_namespace(*args, delimiter: str = ":") -> str:
+    """Concatenate fields to create a unique namespace prefix.
+
+    Example:
+        >>> concat_namespace('prefix', 'field1', 'field2', delimiter=":")
+        'prefix:field1:field2'
+    """
+    return delimiter.join(str(value) for value in args)
 
 
-def split_namespace(ns_class_name: str) -> List[str]:
-    return ns_class_name.split(":")
+def split_namespace(ns_class_name: str, delimiter: str = ":", maxsplit: int = 1) -> List[str]:
+    """Split a namespace-prefixed name into its namespace-prefix and name parts.
+
+    Example:
+        >>> split_namespace('prefix:classname')
+        ['prefix', 'classname']
+
+        >>> split_namespace('prefix:module:class', delimiter=":", maxsplit=2)
+        ['prefix', 'module', 'class']
+    """
+    return ns_class_name.split(delimiter, maxsplit=maxsplit)
-def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
+def auto_namespace(name: str, delimiter: str = ":") -> str:
+    """Automatically handle namespace-prefixed names.
+
+    If the input name is empty, returns a default namespace prefix and name.
+    If the input name is not namespace-prefixed, adds a default namespace prefix.
+    Otherwise, returns the input name unchanged.
+
+    Example:
+        >>> auto_namespace('classname')
+        '?:classname'
+
+        >>> auto_namespace('prefix:classname')
+        'prefix:classname'
+
+        >>> auto_namespace('')
+        '?:?'
+
+        >>> auto_namespace('?:custom')
+        '?:custom'
+    """
+    if not name:
+        return f"?{delimiter}?"
+    v = split_namespace(name, delimiter=delimiter)
+    if len(v) < 2:
+        return f"?{delimiter}{name}"
+    return name
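A self-contained sketch of the three namespace helpers added above (re-implemented here without MetaGPT imports so the behavior can be exercised directly; bodies copied from the diff):

```python
from typing import List

def concat_namespace(*args, delimiter: str = ":") -> str:
    # Join fields into a namespace-prefixed name.
    return delimiter.join(str(value) for value in args)

def split_namespace(ns_class_name: str, delimiter: str = ":", maxsplit: int = 1) -> List[str]:
    # By default only the first delimiter splits, keeping the remainder intact.
    return ns_class_name.split(delimiter, maxsplit=maxsplit)

def auto_namespace(name: str, delimiter: str = ":") -> str:
    # Ensure a name carries a namespace prefix, using "?" as the unknown prefix.
    if not name:
        return f"?{delimiter}?"
    if len(split_namespace(name, delimiter=delimiter)) < 2:
        return f"?{delimiter}{name}"
    return name

print(concat_namespace("prefix", "field1", "field2"))  # prefix:field1:field2
print(split_namespace("prefix:module:class"))          # ['prefix', 'module:class']
print(auto_namespace("classname"))                     # ?:classname
```

Note the default `maxsplit=1`: only the leading namespace is split off, so colons inside the remainder survive.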
+
+
+def add_affix(text: str, affix: Literal["brace", "url", "none"] = "brace"):
+    """Add affix to encapsulate data.
+
+    Example:
+        >>> add_affix("data", affix="brace")
+        '{data}'
+
+        >>> add_affix("example.com", affix="url")
+        '%7Bexample.com%7D'
+
+        >>> add_affix("text", affix="none")
+        'text'
+    """
+    mappings = {
+        "brace": lambda x: "{" + x + "}",
+        "url": lambda x: quote("{" + x + "}"),
+    }
+    encoder = mappings.get(affix, lambda x: x)
+    return encoder(text)
+
+
+def remove_affix(text, affix: Literal["brace", "url", "none"] = "brace"):
+    """Remove affix to extract encapsulated data.
+
+    Args:
+        text (str): The input text with affix to be removed.
+        affix (str, optional): The type of affix used. Defaults to "brace".
+            Supported affix types: "brace" for removing curly braces, "url" for URL decoding within curly braces.
+
+    Returns:
+        str: The text with affix removed.
+
+    Example:
+        >>> remove_affix('{data}', affix="brace")
+        'data'
+
+        >>> remove_affix('%7Bexample.com%7D', affix="url")
+        'example.com'
+
+        >>> remove_affix('text', affix="none")
+        'text'
+    """
+    mappings = {"brace": lambda x: x[1:-1], "url": lambda x: unquote(x)[1:-1]}
+    decoder = mappings.get(affix, lambda x: x)
+    return decoder(text)
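The two affix helpers are inverse operations; a standalone sketch (only `urllib.parse` needed, bodies mirroring the diff) demonstrates the round trip:

```python
from urllib.parse import quote, unquote

def add_affix(text: str, affix: str = "brace") -> str:
    # Wrap data in braces, optionally percent-encoding for URL transport.
    mappings = {
        "brace": lambda x: "{" + x + "}",
        "url": lambda x: quote("{" + x + "}"),
    }
    return mappings.get(affix, lambda x: x)(text)

def remove_affix(text: str, affix: str = "brace") -> str:
    # Undo add_affix: strip the braces, decoding first for the "url" affix.
    mappings = {"brace": lambda x: x[1:-1], "url": lambda x: unquote(x)[1:-1]}
    return mappings.get(affix, lambda x: x)(text)

for affix in ("brace", "url", "none"):
    wrapped = add_affix("example.com", affix=affix)
    assert remove_affix(wrapped, affix=affix) == "example.com"
print(add_affix("example.com", affix="url"))  # %7Bexample.com%7D
```

Any unrecognized affix name falls through to the identity function, which is how the `"none"` case works.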
+
+
+def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> Callable[["RetryCallState"], None]:
     """
     Generates a logging function to be used after a call is retried.
@@ -626,6 +712,54 @@ def list_files(root: str | Path) -> List[Path]:
     return files
 
 
+def parse_json_code_block(markdown_text: str) -> List[str]:
+    json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
+    return [v.strip() for v in json_blocks]
+
+
+def remove_white_spaces(v: str) -> str:
+    return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)
+
+
+async def aread_bin(filename: str | Path) -> bytes:
+    """Read binary file asynchronously.
+
+    Args:
+        filename (Union[str, Path]): The name or path of the file to be read.
+
+    Returns:
+        bytes: The content of the file as bytes.
+
+    Example:
+        >>> content = await aread_bin('example.txt')
+        b'This is the content of the file.'
+
+        >>> content = await aread_bin(Path('example.txt'))
+        b'This is the content of the file.'
+    """
+    async with aiofiles.open(str(filename), mode="rb") as reader:
+        content = await reader.read()
+    return content
+
+
+async def awrite_bin(filename: str | Path, data: bytes):
+    """Write binary file asynchronously.
+
+    Args:
+        filename (Union[str, Path]): The name or path of the file to be written.
+        data (bytes): The binary data to be written to the file.
+
+    Example:
+        >>> await awrite_bin('output.bin', b'This is binary data.')
+
+        >>> await awrite_bin(Path('output.bin'), b'Another set of binary data.')
+    """
+    pathname = Path(filename)
+    pathname.parent.mkdir(parents=True, exist_ok=True)
+    async with aiofiles.open(str(pathname), mode="wb") as writer:
+        await writer.write(data)
 
 
 def is_coroutine_func(func: Callable) -> bool:
     return inspect.iscoroutinefunction(func)
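The two new regex helpers are terse; this standalone sketch shows what they actually match (the fence delimiter is assembled from single backticks so the example can live inside a fenced block itself). Note that the two alternatives in `remove_white_spaces` jointly cover every whitespace character, whether quoted or not:

```python
import re
from typing import List

FENCE = "`" * 3  # triple-backtick fence delimiter

def parse_json_code_block(markdown_text: str) -> List[str]:
    # Capture everything between a json-tagged fence opener and the closing fence.
    pattern = FENCE + r"json(.*?)" + FENCE
    return [v.strip() for v in re.findall(pattern, markdown_text, re.DOTALL)]

def remove_white_spaces(v: str) -> str:
    # Whitespace not preceded by a quote, OR whitespace preceded by a quote:
    # together the alternatives match all whitespace.
    return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)

md = "Plan:\n" + FENCE + "json\n{\"tasks\": [1, 2]}\n" + FENCE + "\nDone."
print(parse_json_code_block(md))         # ['{"tasks": [1, 2]}']
print(remove_white_spaces('a b "c d"'))  # ab"cd"
```

`re.DOTALL` is what lets the captured group span multiple lines inside the fence.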

@@ -689,3 +823,14 @@ def process_message(messages: Union[str, Message, list[dict], list[Message], lis
     else:
         raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
     return processed_messages
+
+
+def log_and_reraise(retry_state: RetryCallState):
+    logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
+    logger.warning(
+        """
+Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
+See FAQ 5.8
+"""
+    )
+    raise retry_state.outcome.exception()
@@ -6,12 +6,13 @@
 @Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
 """
 
+import re
 from typing import NamedTuple
 
 from pydantic import BaseModel
 
 from metagpt.logs import logger
-from metagpt.utils.token_counter import TOKEN_COSTS
+from metagpt.utils.token_counter import FIREWORKS_GRADE_TOKEN_COSTS, TOKEN_COSTS
 
 
 class Costs(NamedTuple):

@@ -29,6 +30,7 @@ class CostManager(BaseModel):
     total_budget: float = 0
     max_budget: float = 10.0
     total_cost: float = 0
+    token_costs: dict[str, dict[str, float]] = TOKEN_COSTS  # different model's token cost
 
     def update_cost(self, prompt_tokens, completion_tokens, model):
         """
@@ -39,14 +41,17 @@ class CostManager(BaseModel):
             completion_tokens (int): The number of tokens used in the completion.
             model (str): The model used for the API call.
         """
         if prompt_tokens + completion_tokens == 0 or not model:
             return
         self.total_prompt_tokens += prompt_tokens
         self.total_completion_tokens += completion_tokens
-        if model not in TOKEN_COSTS:
+        if model not in self.token_costs:
             logger.warning(f"Model {model} not found in TOKEN_COSTS.")
             return
 
         cost = (
-            prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
+            prompt_tokens * self.token_costs[model]["prompt"]
+            + completion_tokens * self.token_costs[model]["completion"]
         ) / 1000
         self.total_cost += cost
         logger.info(
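The cost formula in the hunk above is per-1K-token pricing; a minimal sketch of the arithmetic in isolation (the price table here is a hypothetical stand-in — the real values live in `metagpt.utils.token_counter.TOKEN_COSTS`):

```python
# Hypothetical per-1K-token prices; TOKEN_COSTS in MetaGPT holds the real table.
token_costs = {"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}}

def call_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    prices = token_costs[model]
    # Prices are quoted per 1000 tokens, hence the final division.
    return (prompt_tokens * prices["prompt"] + completion_tokens * prices["completion"]) / 1000

print(round(call_cost(1500, 500, "gpt-4-turbo-preview"), 6))  # 0.03
```

Keeping the table on the instance (`self.token_costs`) rather than the module-level `TOKEN_COSTS` is what lets subclasses swap in their own pricing.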
@@ -101,3 +106,44 @@ class TokenCostManager(CostManager):
         self.total_prompt_tokens += prompt_tokens
         self.total_completion_tokens += completion_tokens
         logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
+
+
+class FireworksCostManager(CostManager):
+    def model_grade_token_costs(self, model: str) -> dict[str, float]:
+        def _get_model_size(model: str) -> float:
+            size = re.findall(".*-([0-9.]+)b", model)
+            size = float(size[0]) if len(size) > 0 else -1
+            return size
+
+        if "mixtral-8x7b" in model:
+            token_costs = FIREWORKS_GRADE_TOKEN_COSTS["mixtral-8x7b"]
+        else:
+            model_size = _get_model_size(model)
+            if 0 < model_size <= 16:
+                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["16"]
+            elif 16 < model_size <= 80:
+                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["80"]
+            else:
+                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["-1"]
+        return token_costs
+
+    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
+        """
+        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
+        Update the total cost, prompt tokens, and completion tokens.
+
+        Args:
+            prompt_tokens (int): The number of tokens used in the prompt.
+            completion_tokens (int): The number of tokens used in the completion.
+            model (str): The model used for the API call.
+        """
+        self.total_prompt_tokens += prompt_tokens
+        self.total_completion_tokens += completion_tokens
+
+        token_costs = self.model_grade_token_costs(model)
+        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
+        self.total_cost += cost
+        logger.info(
+            f"Total running cost: ${self.total_cost:.4f}"
+            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
+        )
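Fireworks bills per million tokens by model-size grade, so `FireworksCostManager` parses the parameter count out of the model name. A standalone sketch of that grading logic (regex and tier boundaries copied from the diff; the price values are hypothetical stand-ins for `FIREWORKS_GRADE_TOKEN_COSTS`):

```python
import re

# Hypothetical per-1M-token prices keyed by size grade.
GRADE_COSTS = {
    "16": {"prompt": 0.2, "completion": 0.2},
    "80": {"prompt": 0.7, "completion": 0.7},
    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
    "-1": {"prompt": 0.0, "completion": 0.0},  # size could not be determined
}

def model_grade(model: str) -> str:
    # Mixtral is priced specially; otherwise parse "...-13b"-style size suffixes.
    if "mixtral-8x7b" in model:
        return "mixtral-8x7b"
    size = re.findall(".*-([0-9.]+)b", model)
    model_size = float(size[0]) if size else -1
    if 0 < model_size <= 16:
        return "16"
    if 16 < model_size <= 80:
        return "80"
    return "-1"

def fireworks_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    prices = GRADE_COSTS[model_grade(model)]
    # Fireworks quotes prices per 1M tokens, hence the divisor.
    return (prompt_tokens * prices["prompt"] + completion_tokens * prices["completion"]) / 1_000_000

print(model_grade("accounts/fireworks/models/llama-v2-13b-chat"))  # 16
```

The greedy `.*` in the regex means the last `-<digits>b` group in the name wins, which is what makes suffixes like `-13b-chat` resolve to 13.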

@@ -4,7 +4,9 @@
 @Time : 2023/12/19
 @Author : mashenquan
 @File : di_graph_repository.py
-@Desc : Graph repository based on DiGraph
+@Desc : Graph repository based on DiGraph.
+    This script defines a graph repository class based on a directed graph (DiGraph), providing functionalities
+    specific to handling directed relationships between entities.
 """
 from __future__ import annotations
 

@@ -19,20 +21,41 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
 
 
 class DiGraphRepository(GraphRepository):
+    """Graph repository based on DiGraph."""
+
     def __init__(self, name: str, **kwargs):
         super().__init__(name=name, **kwargs)
         self._repo = networkx.DiGraph()
 
     async def insert(self, subject: str, predicate: str, object_: str):
+        """Insert a new triple into the directed graph repository.
+
+        Args:
+            subject (str): The subject of the triple.
+            predicate (str): The predicate describing the relationship.
+            object_ (str): The object of the triple.
+
+        Example:
+            await my_di_graph_repo.insert(subject="Node1", predicate="connects_to", object_="Node2")
+            # Adds a directed relationship: Node1 connects_to Node2
+        """
         self._repo.add_edge(subject, object_, predicate=predicate)
 
     async def upsert(self, subject: str, predicate: str, object_: str):
         pass
 
     async def update(self, subject: str, predicate: str, object_: str):
         pass
 
     async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
+        """Retrieve triples from the directed graph repository based on specified criteria.
+
+        Args:
+            subject (str, optional): The subject of the triple to filter by.
+            predicate (str, optional): The predicate describing the relationship to filter by.
+            object_ (str, optional): The object of the triple to filter by.
+
+        Returns:
+            List[SPO]: A list of SPO objects representing the selected triples.
+
+        Example:
+            selected_triples = await my_di_graph_repo.select(subject="Node1", predicate="connects_to")
+            # Retrieves directed relationships where Node1 is the subject and the predicate is 'connects_to'.
+        """
         result = []
         for s, o, p in self._repo.edges(data="predicate"):
             if subject and subject != s:
@@ -44,12 +67,41 @@ class DiGraphRepository(GraphRepository):
             result.append(SPO(subject=s, predicate=p, object_=o))
         return result
 
     async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
+        """Delete triples from the directed graph repository based on specified criteria.
+
+        Args:
+            subject (str, optional): The subject of the triple to filter by.
+            predicate (str, optional): The predicate describing the relationship to filter by.
+            object_ (str, optional): The object of the triple to filter by.
+
+        Returns:
+            int: The number of triples deleted from the repository.
+
+        Example:
+            deleted_count = await my_di_graph_repo.delete(subject="Node1", predicate="connects_to")
+            # Deletes directed relationships where Node1 is the subject and the predicate is 'connects_to'.
+        """
         rows = await self.select(subject=subject, predicate=predicate, object_=object_)
         if not rows:
             return 0
         for r in rows:
             self._repo.remove_edge(r.subject, r.object_)
         return len(rows)
 
     def json(self) -> str:
+        """Convert the directed graph repository to a JSON-formatted string."""
         m = networkx.node_link_data(self._repo)
         data = json.dumps(m)
         return data
 
     async def save(self, path: str | Path = None):
+        """Save the directed graph repository to a JSON file.
+
+        Args:
+            path (Union[str, Path], optional): The directory path where the JSON file will be saved.
+                If not provided, the default path is taken from the 'root' key in the keyword arguments.
+        """
         data = self.json()
         path = path or self._kwargs.get("root")
         if not path.exists():
@@ -58,12 +110,21 @@ class DiGraphRepository(GraphRepository):
         await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
 
     async def load(self, pathname: str | Path):
+        """Load a directed graph repository from a JSON file."""
         data = await aread(filename=pathname, encoding="utf-8")
         m = json.loads(data)
         self._repo = networkx.node_link_graph(m)
 
     @staticmethod
     async def load_from(pathname: str | Path) -> GraphRepository:
+        """Create and load a directed graph repository from a JSON file.
+
+        Args:
+            pathname (Union[str, Path]): The path to the JSON file to be loaded.
+
+        Returns:
+            GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
+        """
         pathname = Path(pathname)
         name = pathname.with_suffix("").name
         root = pathname.parent

@@ -74,9 +135,16 @@ class DiGraphRepository(GraphRepository):
 
     @property
     def root(self) -> str:
+        """Return the root directory path for the graph repository files."""
         return self._kwargs.get("root")
 
+    @property
+    def pathname(self) -> Path:
+        """Return the path and filename to the graph repository file."""
+        p = Path(self.root) / self.name
+        return p.with_suffix(".json")
+
     @property
     def repo(self):
+        """Get the underlying directed graph repository."""
         return self._repo
@@ -4,21 +4,28 @@
 @Time : 2023/12/19
 @Author : mashenquan
 @File : graph_repository.py
-@Desc : Superclass for graph repository.
+@Desc : Superclass for graph repository. This script defines a superclass for a graph repository, providing a
+    foundation for specific implementations.
 """
 
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from pathlib import Path
 from typing import List
 
 from pydantic import BaseModel
 
 from metagpt.logs import logger
-from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
-from metagpt.utils.common import concat_namespace
+from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
+from metagpt.utils.common import concat_namespace, split_namespace
 
 
 class GraphKeyword:
+    """Basic words for a Graph database.
+
+    This class defines a set of basic words commonly used in the context of a Graph database.
+    """
+
     IS = "is"
     OF = "Of"
     ON = "On"
@@ -28,51 +35,149 @@ class GraphKeyword:
     SOURCE_CODE = "source_code"
     NULL = "<null>"
     GLOBAL_VARIABLE = "global_variable"
-    CLASS_FUNCTION = "class_function"
+    CLASS_METHOD = "class_method"
     CLASS_PROPERTY = "class_property"
-    HAS_CLASS_FUNCTION = "has_class_function"
+    HAS_CLASS_METHOD = "has_class_method"
     HAS_CLASS_PROPERTY = "has_class_property"
     HAS_CLASS = "has_class"
     HAS_DETAIL = "has_detail"
     HAS_PAGE_INFO = "has_page_info"
     HAS_CLASS_VIEW = "has_class_view"
     HAS_SEQUENCE_VIEW = "has_sequence_view"
     HAS_ARGS_DESC = "has_args_desc"
     HAS_TYPE_DESC = "has_type_desc"
+    HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
+    HAS_CLASS_USE_CASE = "has_class_use_case"
+    IS_COMPOSITE_OF = "is_composite_of"
+    IS_AGGREGATE_OF = "is_aggregate_of"
+    HAS_PARTICIPANT = "has_participant"
 
 
 class SPO(BaseModel):
+    """Graph repository record type.
+
+    This class represents a record in a graph repository with three components:
+    - Subject: The subject of the triple.
+    - Predicate: The predicate describing the relationship between the subject and the object.
+    - Object: The object of the triple.
+
+    Attributes:
+        subject (str): The subject of the triple.
+        predicate (str): The predicate describing the relationship.
+        object_ (str): The object of the triple.
+
+    Example:
+        spo_record = SPO(subject="Node1", predicate="connects_to", object_="Node2")
+        # Represents a triple: Node1 connects_to Node2
+    """
+
     subject: str
     predicate: str
     object_: str
 class GraphRepository(ABC):
+    """Abstract base class for a Graph Repository.
+
+    This class defines the interface for a graph repository, providing methods for inserting, selecting,
+    deleting, and saving graph data. Concrete implementations of this class must provide functionality
+    for these operations.
+    """
+
     def __init__(self, name: str, **kwargs):
         self._repo_name = name
         self._kwargs = kwargs
 
     @abstractmethod
     async def insert(self, subject: str, predicate: str, object_: str):
-        pass
+        """Insert a new triple into the graph repository.
 
-    @abstractmethod
-    async def upsert(self, subject: str, predicate: str, object_: str):
-        pass
+        Args:
+            subject (str): The subject of the triple.
+            predicate (str): The predicate describing the relationship.
+            object_ (str): The object of the triple.
 
-    @abstractmethod
-    async def update(self, subject: str, predicate: str, object_: str):
+        Example:
+            await my_repository.insert(subject="Node1", predicate="connects_to", object_="Node2")
+            # Inserts a triple: Node1 connects_to Node2 into the graph repository.
+        """
         pass
 
     @abstractmethod
     async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
+        """Retrieve triples from the graph repository based on specified criteria.
+
+        Args:
+            subject (str, optional): The subject of the triple to filter by.
+            predicate (str, optional): The predicate describing the relationship to filter by.
+            object_ (str, optional): The object of the triple to filter by.
+
+        Returns:
+            List[SPO]: A list of SPO objects representing the selected triples.
+
+        Example:
+            selected_triples = await my_repository.select(subject="Node1", predicate="connects_to")
+            # Retrieves triples where Node1 is the subject and the predicate is 'connects_to'.
+        """
         pass
 
     @abstractmethod
     async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
+        """Delete triples from the graph repository based on specified criteria.
+
+        Args:
+            subject (str, optional): The subject of the triple to filter by.
+            predicate (str, optional): The predicate describing the relationship to filter by.
+            object_ (str, optional): The object of the triple to filter by.
+
+        Returns:
+            int: The number of triples deleted from the repository.
+
+        Example:
+            deleted_count = await my_repository.delete(subject="Node1", predicate="connects_to")
+            # Deletes triples where Node1 is the subject and the predicate is 'connects_to'.
+        """
         pass
 
     @abstractmethod
     async def save(self):
+        """Save any changes made to the graph repository.
+
+        Example:
+            await my_repository.save()
+            # Persists any changes made to the graph repository.
+        """
         pass
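The abstract interface above is essentially a filterable triple store. A minimal in-memory sketch (synchronous here for brevity, whereas the real interface is `async`; class and variable names are illustrative, not MetaGPT's):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class SPO:
    subject: str
    predicate: str
    object_: str

class MemoryGraphRepository:
    # In-memory stand-in for a GraphRepository implementation.
    def __init__(self, name: str):
        self.name = name
        self._rows: List[SPO] = []

    def insert(self, subject: str, predicate: str, object_: str):
        self._rows.append(SPO(subject, predicate, object_))

    def select(self, subject: Optional[str] = None, predicate: Optional[str] = None,
               object_: Optional[str] = None) -> List[SPO]:
        # A None filter matches everything, mirroring the optional args above.
        return [
            r for r in self._rows
            if (subject is None or r.subject == subject)
            and (predicate is None or r.predicate == predicate)
            and (object_ is None or r.object_ == object_)
        ]

    def delete(self, subject=None, predicate=None, object_=None) -> int:
        rows = self.select(subject, predicate, object_)
        for r in rows:
            self._rows.remove(r)
        return len(rows)

repo = MemoryGraphRepository("demo")
repo.insert("Node1", "connects_to", "Node2")
repo.insert("Node1", "connects_to", "Node3")
print(len(repo.select(subject="Node1")))  # 2
print(repo.delete(object_="Node3"))       # 1
```

Expressing `delete` in terms of `select`, as `DiGraphRepository` also does, keeps the filter semantics in one place.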
 
     @property
     def name(self) -> str:
+        """Get the name of the graph repository."""
         return self._repo_name
 
     @staticmethod
     async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
+        """Insert information of RepoFileInfo into the specified graph repository.
+
+        This function updates the provided graph repository with information from the given RepoFileInfo object.
+        The function inserts triples related to various dimensions such as file type, class, class method, function,
+        global variable, and page info.
+
+        Triple Patterns:
+            - (?, is, [file type])
+            - (?, has class, ?)
+            - (?, is, [class])
+            - (?, has class method, ?)
+            - (?, has function, ?)
+            - (?, is, [function])
+            - (?, is, global variable)
+            - (?, has page info, ?)
+
+        Args:
+            graph_db (GraphRepository): The graph repository object to be updated.
+            file_info (RepoFileInfo): The RepoFileInfo object containing information to be inserted.
+
+        Example:
+            await update_graph_db_with_file_info(my_graph_repo, my_file_info)
+            # Updates 'my_graph_repo' with information from 'my_file_info'.
+        """
         await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
         file_types = {".py": "python", ".js": "javascript"}
         file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
|
@ -95,13 +200,13 @@ class GraphRepository(ABC):
|
|||
for fn in methods:
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name),
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
predicate=GraphKeyword.HAS_CLASS_METHOD,
|
||||
object_=concat_namespace(file_info.file, class_name, fn),
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
object_=GraphKeyword.CLASS_METHOD,
|
||||
)
|
||||
for f in file_info.functions:
|
||||
# file -> function
|
||||
|
|
@@ -133,7 +238,34 @@ class GraphRepository(ABC):
         )
 
     @staticmethod
-    async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
+    async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
+        """Insert dot format class information into the specified graph repository.
+
+        This function updates the provided graph repository with class information from the given list of DotClassInfo objects.
+        The function inserts triples related to various aspects of class views, including source code, file type, class,
+        class property, class detail, method, composition, and aggregation.
+
+        Triple Patterns:
+            - (?, is, source code)
+            - (?, is, file type)
+            - (?, has class, ?)
+            - (?, is, class)
+            - (?, has class property, ?)
+            - (?, is, class property)
+            - (?, has detail, ?)
+            - (?, has method, ?)
+            - (?, is composite of, ?)
+            - (?, is aggregate of, ?)
+
+        Args:
+            graph_db (GraphRepository): The graph repository object to be updated.
+            class_views (List[DotClassInfo]): List of DotClassInfo objects containing class information to be inserted.
+
+        Example:
+            await update_graph_db_with_class_views(my_graph_repo, [class_info1, class_info2])
+            # Updates 'my_graph_repo' with class information from the provided list of DotClassInfo objects.
+        """
         for c in class_views:
             filename, _ = c.package.split(":", 1)
             await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)

@@ -146,6 +278,7 @@ class GraphRepository(ABC):
                 predicate=GraphKeyword.IS,
                 object_=GraphKeyword.CLASS,
             )
+            await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
             for vn, vt in c.attributes.items():
                 # class -> property
                 await graph_db.insert(
@@ -160,33 +293,61 @@ class GraphRepository(ABC):
                     object_=GraphKeyword.CLASS_PROPERTY,
                 )
                 await graph_db.insert(
-                    subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
+                    subject=concat_namespace(c.package, vn),
+                    predicate=GraphKeyword.HAS_DETAIL,
+                    object_=vt.model_dump_json(),
                 )
-            for fn, desc in c.methods.items():
-                if "</I>" in desc and "<I>" not in desc:
-                    logger.error(desc)
+            for fn, ft in c.methods.items():
                 # class -> function
                 await graph_db.insert(
                     subject=c.package,
-                    predicate=GraphKeyword.HAS_CLASS_FUNCTION,
+                    predicate=GraphKeyword.HAS_CLASS_METHOD,
                     object_=concat_namespace(c.package, fn),
                 )
                 # function detail
                 await graph_db.insert(
                     subject=concat_namespace(c.package, fn),
                     predicate=GraphKeyword.IS,
-                    object_=GraphKeyword.CLASS_FUNCTION,
+                    object_=GraphKeyword.CLASS_METHOD,
                 )
                 await graph_db.insert(
                     subject=concat_namespace(c.package, fn),
-                    predicate=GraphKeyword.HAS_ARGS_DESC,
-                    object_=desc,
+                    predicate=GraphKeyword.HAS_DETAIL,
+                    object_=ft.model_dump_json(),
                 )
+            for i in c.compositions:
+                await graph_db.insert(
+                    subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
+                )
+            for i in c.aggregations:
+                await graph_db.insert(
+                    subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
+                )
 
     @staticmethod
     async def update_graph_db_with_class_relationship_views(
-        graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
+        graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
     ):
+        """Insert class relationships and labels into the specified graph repository.
+
+        This function updates the provided graph repository with class relationship information from the given list
+        of DotClassRelationship objects. The function inserts triples representing relationships and labels between
+        classes.
+
+        Triple Patterns:
+            - (?, is relationship of, ?)
+            - (?, is relationship on, ?)
+
+        Args:
+            graph_db (GraphRepository): The graph repository object to be updated.
+            relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects containing
+                class relationship information to be inserted.
+
+        Example:
+            await update_graph_db_with_class_relationship_views(my_graph_repo, [relationship1, relationship2])
+            # Updates 'my_graph_repo' with class relationship information from the provided list of DotClassRelationship objects.
+        """
         for r in relationship_views:
             await graph_db.insert(
                 subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
@@ -198,3 +359,32 @@ class GraphRepository(ABC):
                 predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
                 object_=concat_namespace(r.dest, r.label),
             )
+
+    @staticmethod
+    async def rebuild_composition_relationship(graph_db: "GraphRepository"):
+        """Append namespace-prefixed information to relationship SPO (Subject-Predicate-Object) objects in the graph
+        repository.
+
+        This function updates the provided graph repository by appending namespace-prefixed information to existing
+        relationship SPO objects.
+
+        Args:
+            graph_db (GraphRepository): The graph repository object to be updated.
+        """
+        classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
+        mapping = defaultdict(list)
+        for c in classes:
+            name = split_namespace(c.subject)[-1]
+            mapping[name].append(c.subject)
+
+        rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
+        for r in rows:
+            ns, class_ = split_namespace(r.object_)
+            if ns != "?":
+                continue
+            val = mapping[class_]
+            if len(val) != 1:
+                continue
+            ns_name = val[0]
+            await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
+            await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)
|
||||
|
|
|
|||
|
|
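The composition-rebuild pass above only rewrites a `?`-namespace target when the bare class name maps to exactly one known subject. A minimal standalone sketch of that resolution rule (the `:` namespace separator and the sample subjects are assumptions mirroring `concat_namespace`/`split_namespace`, not MetaGPT API):

```python
from collections import defaultdict


def resolve_unknown_namespaces(classes, compositions):
    """Resolve '?'-namespace composition targets to full subjects: a target
    is rewritten only when exactly one known class matches its bare name."""
    mapping = defaultdict(list)
    for subject in classes:
        # bare class name is the part after the namespace separator
        mapping[subject.split(":")[-1]].append(subject)

    resolved = []
    for src, target in compositions:
        ns, name = target.split(":", 1)
        candidates = mapping[name]
        if ns == "?" and len(candidates) == 1:
            target = candidates[0]  # unambiguous: substitute the real namespace
        resolved.append((src, target))
    return resolved


classes = ["game.py:Game", "ui.py:UI"]
compositions = [("main.py:Main", "?:Game"), ("main.py:Main", "?:Widget")]
print(resolve_unknown_namespaces(classes, compositions))
# → [('main.py:Main', 'game.py:Game'), ('main.py:Main', '?:Widget')]
```

Note that an ambiguous name (two classes with the same bare name in different modules) is deliberately left untouched, exactly as the `len(val) != 1: continue` branch does.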
@@ -33,6 +33,7 @@ from metagpt.const import (
    TASK_PDF_FILE_REPO,
    TEST_CODES_FILE_REPO,
    TEST_OUTPUTS_FILE_REPO,
+    VISUAL_GRAPH_REPO_FILE_REPO,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
    code_summary: FileRepository
    sd_output: FileRepository
    code_plan_and_change: FileRepository
+    graph_repo: FileRepository

    def __init__(self, git_repo):
        super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
@@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
        self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
        self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
        self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
+        self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)


class ProjectRepo(FileRepository):
@@ -133,6 +136,7 @@ class ProjectRepo(FileRepository):
        code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
        if not code_files:
            return False
        return bool(code_files)

    def with_src_path(self, path: str | Path) -> ProjectRepo:
        try:
@@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
        logger.info(f"repair_json_format: {'}]'}")
    elif output.startswith("{") and output.endswith("]"):
        output = output[:-1] + "}"

    # remove comments in output json string, after json value content, maybe start with #, maybe start with //
    arr = output.split("\n")
    new_arr = []
@@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
        elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
            # problem, `"""` or `'''` without `,`
            new_line = f",{line}"
        elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
            # backslash problem like \" in the output
            char = rline[col_no - 1]
            nearest_char_idx = rline[col_no:].find(char)
            new_line = (
                rline[: col_no - 1]
                + "\\"
                + rline[col_no - 1 : col_no + nearest_char_idx]
                + "\\"
                + rline[col_no + nearest_char_idx :]
            )
        elif '",' not in line and "," not in line and '"' not in line:
            new_line = f'{line}",'
        elif not line.endswith(","):
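The new quote-escaping branch above can be exercised in isolation. This sketch lifts it into a hypothetical helper (`escape_unescaped_quote` is not a MetaGPT function): given the reversed-or-raw line and the parser's error column, it backslash-escapes the quote just before the column and its nearest matching twin to the right.

```python
def escape_unescaped_quote(rline: str, col_no: int) -> str:
    """Escape a stray quote pair around an error column: the quote at
    col_no - 1 and its nearest twin to the right both get a backslash."""
    char = rline[col_no - 1]
    nearest_char_idx = rline[col_no:].find(char)
    return (
        rline[: col_no - 1]
        + "\\"
        + rline[col_no - 1 : col_no + nearest_char_idx]
        + "\\"
        + rline[col_no + nearest_char_idx :]
    )


# The inner quotes around `b` break the JSON string; escaping repairs them.
print(escape_unescaped_quote('"a "b" c"', 4))  # → "a \"b\" c"
```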
@@ -35,9 +35,111 @@ TOKEN_COSTS = {
    "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007},  # 128k version, prompt + completion tokens=0.005¥/k-tokens
    "glm-4": {"prompt": 0.014, "completion": 0.014},  # 128k version, prompt + completion tokens=0.1¥/k-tokens
    "gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
    "moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012},  # prompt + completion tokens=0.012¥/k-tokens
    "moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
    "moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
    "open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025},
    "open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007},
    "mistral-small-latest": {"prompt": 0.002, "completion": 0.006},
    "mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081},
    "mistral-large-latest": {"prompt": 0.008, "completion": 0.024},
    "claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024},
    "claude-2.0": {"prompt": 0.008, "completion": 0.024},
    "claude-2.1": {"prompt": 0.008, "completion": 0.024},
    "claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
    "claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
}


"""
QianFan token price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Since QianFan offers multiple pricing strategies, we standardize on `Tokens post-payment` as the accounting method.
"""
QIANFAN_MODEL_TOKEN_COSTS = {
    "ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
    "ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
    "ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
    "ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
    "EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
    "ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
    "BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
    "Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
    "Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
    "Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
    "ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
    "AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
    "Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
    "SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
    "CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
    "XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
    "Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
    "Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
    "Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
    "ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
    "Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
}

QIANFAN_ENDPOINT_TOKEN_COSTS = {
    "completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
    "ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
    "completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
    "eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
    "ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
    "ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
    "bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
    "llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
    "llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
    "llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
    "chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
    "aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
    "mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
    "sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
    "codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
    "xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
    "qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
    "qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
    "qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
    "chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
    "yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}

"""
DashScope token price: https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Each model has its own pricing page. Note that some models are free for a limited time.
"""
DASHSCOPE_TOKEN_COSTS = {
    "qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
    "qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
    "qwen-max": {"prompt": 0.0, "completion": 0.0},
    "qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
    "qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
    "llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
    "llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
    "qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
    "qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
    "qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
    "qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
    "baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
    "baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
    "baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
    "chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
    "chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
    "ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0},  # no price page, judge it as free
    "dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
    "belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
    "moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
    "chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
    "billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
}


FIREWORKS_GRADE_TOKEN_COSTS = {
    "-1": {"prompt": 0.0, "completion": 0.0},  # abnormal condition
    "16": {"prompt": 0.2, "completion": 0.8},  # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
    "80": {"prompt": 0.7, "completion": 2.8},  # 80 means 16B < model size <= 80B
    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}

TOKEN_MAX = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-0301": 4096,
@@ -61,6 +163,19 @@ TOKEN_MAX = {
    "glm-3-turbo": 128000,
    "glm-4": 128000,
    "gemini-pro": 32768,
    "moonshot-v1-8k": 8192,
    "moonshot-v1-32k": 32768,
    "moonshot-v1-128k": 128000,
    "open-mistral-7b": 8192,
    "open-mixtral-8x7b": 32768,
    "mistral-small-latest": 32768,
    "mistral-medium-latest": 32768,
    "mistral-large-latest": 32768,
    "claude-instant-1.2": 100000,
    "claude-2.0": 100000,
    "claude-2.1": 200000,
    "claude-3-sonnet-20240229": 200000,
    "claude-3-opus-20240229": 200000,
}
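Tables like `TOKEN_COSTS` hold prices per 1K tokens, so a usage cost is the token counts scaled by 1/1000. A minimal sketch of how such a table is typically consumed (`calc_usage_cost` is a hypothetical helper, not a MetaGPT API; only two entries are copied here for brevity):

```python
# Two entries copied from the cost table above; prices are USD per 1K tokens.
TOKEN_COSTS = {
    "moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012},
    "claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
}


def calc_usage_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Cost of one request: per-1K rates applied to each token bucket."""
    rate = TOKEN_COSTS[model]
    return prompt_tokens / 1000 * rate["prompt"] + completion_tokens / 1000 * rate["completion"]


print(round(calc_usage_cost("claude-3-opus-20240229", 2000, 1000), 6))  # → 0.105
```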
162
metagpt/utils/visual_graph_repo.py
Normal file
@@ -0,0 +1,162 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/12/19
@Author  : mashenquan
@File    : visualize_graph.py
@Desc    : Visualization tool to visualize the class diagrams or sequence diagrams of the graph repository.
"""
from __future__ import annotations

import re
from abc import ABC
from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel, Field

from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.schema import UMLClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository


class _VisualClassView(BaseModel):
    """Protected class used by VisualGraphRepo internally.

    Attributes:
        package (str): The package associated with the class.
        uml (Optional[UMLClassView]): Optional UMLClassView associated with the class.
        generalizations (List[str]): List of generalizations for the class.
        compositions (List[str]): List of compositions for the class.
        aggregations (List[str]): List of aggregations for the class.
    """

    package: str
    uml: Optional[UMLClassView] = None
    generalizations: List[str] = Field(default_factory=list)
    compositions: List[str] = Field(default_factory=list)
    aggregations: List[str] = Field(default_factory=list)

    def get_mermaid(self, align: int = 1) -> str:
        """Creates a Markdown Mermaid class diagram text.

        Args:
            align (int): Indent count used for alignment.

        Returns:
            str: The Markdown text representing the Mermaid class diagram.
        """
        if not self.uml:
            return ""
        prefix = "\t" * align

        mermaid_txt = self.uml.get_mermaid(align=align)
        for i in self.generalizations:
            mermaid_txt += f"{prefix}{i} <|-- {self.name}\n"
        for i in self.compositions:
            mermaid_txt += f"{prefix}{i} *-- {self.name}\n"
        for i in self.aggregations:
            mermaid_txt += f"{prefix}{i} o-- {self.name}\n"
        return mermaid_txt

    @property
    def name(self) -> str:
        """Returns the class name without the namespace prefix."""
        return split_namespace(self.package)[-1]


class VisualGraphRepo(ABC):
    """Abstract base class for VisualGraphRepo."""

    graph_db: GraphRepository

    def __init__(self, graph_db):
        self.graph_db = graph_db


class VisualDiGraphRepo(VisualGraphRepo):
    """Implementation of VisualGraphRepo for DiGraph graph repository.

    This class extends VisualGraphRepo to provide specific functionality for a graph repository using DiGraph.
    """

    @classmethod
    async def load_from(cls, filename: str | Path):
        """Load a VisualDiGraphRepo instance from a file."""
        graph_db = await DiGraphRepository.load_from(str(filename))
        return cls(graph_db=graph_db)

    async def get_mermaid_class_view(self) -> str:
        """
        Returns a Markdown Mermaid class diagram code block object.
        """
        rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
        mermaid_txt = "classDiagram\n"
        for r in rows:
            v = await self._get_class_view(ns_class_name=r.subject)
            mermaid_txt += v.get_mermaid()
        return mermaid_txt

    async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
        """Returns the Markdown Mermaid class diagram code block object for the specified class."""
        rows = await self.graph_db.select(subject=ns_class_name)
        class_view = _VisualClassView(package=ns_class_name)
        for r in rows:
            if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
                class_view.uml = UMLClassView.model_validate_json(r.object_)
            elif r.predicate == GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF:
                name = split_namespace(r.object_)[-1]
                name = self._refine_name(name)
                if name:
                    class_view.generalizations.append(name)
            elif r.predicate == GraphKeyword.IS + COMPOSITION + GraphKeyword.OF:
                name = split_namespace(r.object_)[-1]
                name = self._refine_name(name)
                if name:
                    class_view.compositions.append(name)
            elif r.predicate == GraphKeyword.IS + AGGREGATION + GraphKeyword.OF:
                name = split_namespace(r.object_)[-1]
                name = self._refine_name(name)
                if name:
                    class_view.aggregations.append(name)
        return class_view

    async def get_mermaid_sequence_views(self) -> List[(str, str)]:
        """Returns all Markdown sequence diagrams with their corresponding graph repository keys."""
        sequence_views = []
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
        for r in rows:
            sequence_views.append((r.subject, r.object_))
        return sequence_views

    @staticmethod
    def _refine_name(name: str) -> str:
        """Removes impurity content from the given name.

        Example:
            >>> _refine_name("int")
            ""

            >>> _refine_name('"Class1"')
            'Class1'

            >>> _refine_name("pkg.Class1")
            "Class1"
        """
        name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
        if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
            return ""
        if "." in name:
            name = name.split(".")[-1]

        return name

    async def get_mermaid_sequence_view_versions(self) -> List[(str, str)]:
        """Returns all versioned Markdown sequence diagrams with their corresponding graph repository keys."""
        sequence_views = []
        rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
        for r in rows:
            sequence_views.append((r.subject, r.object_))
        return sequence_views
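The relationship edges that `_VisualClassView.get_mermaid` appends use the standard Mermaid arrows: `<|--` for generalization, `*--` for composition, `o--` for aggregation. A self-contained sketch of just that assembly step (the function and the sample class names are illustrative, not MetaGPT API):

```python
def class_relationships_to_mermaid(name, generalizations=(), compositions=(), aggregations=(), align=1):
    """Emit one Mermaid edge per relationship, indented for nesting."""
    prefix = "\t" * align
    txt = ""
    for g in generalizations:
        txt += f"{prefix}{g} <|-- {name}\n"  # generalization: parent <|-- child
    for c in compositions:
        txt += f"{prefix}{c} *-- {name}\n"   # composition
    for a in aggregations:
        txt += f"{prefix}{a} o-- {name}\n"   # aggregation
    return txt


diagram = "classDiagram\n" + class_relationships_to_mermaid(
    "Game", generalizations=["Entity"], compositions=["Board"]
)
print(diagram)
# classDiagram
# 	Entity <|-- Game
# 	Board *-- Game
```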
@@ -11,10 +11,11 @@ typer==0.9.0
# godot==0.1.1
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.4.0
-langchain==0.0.352
+langchain==0.1.8
+sqlalchemy==2.0.0 # along with langchain
loguru==0.6.0
meilisearch==0.21.0
-numpy>=1.24.3
+numpy>=1.24.3,<1.25.0
openai==1.6.0
openpyxl
beautifulsoup4==4.12.2
@@ -27,13 +28,13 @@ python_docx==0.8.11
PyYAML==6.0.1
# sentence_transformers==2.2.2
setuptools==65.6.3
-tenacity==8.2.2
+tenacity==8.2.3
tiktoken==0.5.2
tqdm==4.65.0
#unstructured[local-inference]
# selenium>4
# webdriver_manager<3.9
-anthropic==0.8.1
+anthropic==0.18.1
typing-inspect==0.8.0
libcst==1.0.1
qdrant-client==1.7.0
@@ -54,7 +55,7 @@ rich==13.6.0
nbclient==0.9.0
nbformat==5.9.2
ipython==8.17.2
-ipykernel==6.27.0
+ipykernel==6.27.1
scikit_learn==1.3.2
typing-extensions==4.9.0
socksio~=1.0.0
@@ -68,3 +69,5 @@ anytree
ipywidgets==8.1.1
Pillow
imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
+qianfan==0.3.2
+dashscope==1.14.1
2
setup.py
@@ -57,7 +57,7 @@ extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pr

setup(
    name="metagpt",
-    version="0.7.0",
+    version="0.7.4",
    description="The Multi-Agent Framework",
    long_description=long_description,
    long_description_content_type="text/markdown",
@@ -12,6 +12,7 @@ import logging
import os
import re
+import uuid
from pathlib import Path
from typing import Callable

import aiohttp.web
@@ -270,3 +271,11 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache):
    aiohttp_mocker.rsp_cache = mermaid_rsp_cache
    aiohttp_mocker.check_funcs = check_funcs
    yield check_funcs
+
+
+@pytest.fixture
+def git_dir():
+    """Fixture to get the unittest directory."""
+    git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
+    git_dir.mkdir(parents=True, exist_ok=True)
+    return git_dir
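The `git_dir` fixture's key property is that every call derives a fresh directory name from `uuid4`, so parallel tests never collide on a shared workspace. A standalone sketch of that uniqueness guarantee (paths and the helper name here are illustrative):

```python
import tempfile
import uuid
from pathlib import Path


def make_unique_dir(base: Path) -> Path:
    """Create a fresh subdirectory under `base`, uuid4-named as in the fixture."""
    d = base / f"unittest/{uuid.uuid4().hex}"
    d.mkdir(parents=True, exist_ok=True)
    return d


base = Path(tempfile.mkdtemp())
a, b = make_unique_dir(base), make_unique_dir(base)
print(a != b)  # → True: two calls yield distinct directories
```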
1
tests/data/graph_db/networkx.class_view.json
Normal file
File diff suppressed because one or more lines are too long
1
tests/data/graph_db/networkx.sequence_view.json
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
|
|
@@ -149,7 +149,7 @@ sequenceDiagram

The requirement analysis suggests the need for a clean and intuitive interface. Since we are using a command-line interface, we need to ensure that the text-based UI is as user-friendly as possible. Further clarification on whether a graphical user interface (GUI) is expected in the future would be helpful for planning the extendability of the game."""

-TASKS_SAMPLE = """
+TASK_SAMPLE = """
## Required Python packages

- random==2.2.1
@@ -345,7 +345,7 @@ REFINED_DESIGN_JSON = {
    "Anything UNCLEAR": "",
}

-REFINED_TASKS_JSON = {
+REFINED_TASK_JSON = {
    "Required Python packages": ["random==2.2.1", "Tkinter==8.6"],
    "Required Other language third-party packages": ["No third-party dependencies required"],
    "Refined Logic Analysis": [
@@ -373,7 +373,14 @@ REFINED_TASKS_JSON = {
}

CODE_PLAN_AND_CHANGE_SAMPLE = {
    "Code Plan And Change": '\n1. Plan for gui.py: Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.\n```python\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```\n\n2. Plan for main.py: Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.\n```python\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n',
    "Development Plan": [
        "Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.",
        "Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.",
    ],
    "Incremental Change": [
        """```diff\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```""",
        """```diff\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n""",
    ],
}

REFINED_CODE_INPUT_SAMPLE = """
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
|
|
@@ -1,6 +1,6 @@
import pytest

-from metagpt.actions.mi.ask_review import AskReview
+from metagpt.actions.di.ask_review import AskReview


@pytest.mark.asyncio
Some files were not shown because too many files have changed in this diff