Merge pull request #991 from garylin2099/code_interpreter

Solve confilcts
This commit is contained in:
garylin2099 2024-03-12 15:16:37 +08:00 committed by GitHub
commit 25fbab1cd0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
146 changed files with 4466 additions and 1375 deletions

View file

@ -54,7 +54,6 @@ jobs:
export ALLOW_OPENAI_API_CALL=0
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -31,7 +31,7 @@ jobs:
- name: Test with pytest
run: |
export ALLOW_OPENAI_API_CALL=0
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -26,7 +26,9 @@ # MetaGPT: The Multi-Agent Framework
</p>
## News
🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/mi/README.md), a powerful agent capable of solving a wide range of real-world problems.
🚀 March. 01, 2024: Our Data Interpreter paper is on arxiv. Find all design and benchmark details [here](https://arxiv.org/abs/2402.18679)!
🚀 Feb. 08, 2024: [v0.7.0](https://github.com/geekan/MetaGPT/releases/tag/v0.7.0) released, supporting assigning different LLMs to different Roles. We also introduced [Data Interpreter](https://github.com/geekan/MetaGPT/blob/main/examples/di/README.md), a powerful agent capable of solving a wide range of real-world problems.
🚀 Jan. 16, 2024: Our paper [MetaGPT: Meta Programming for A Multi-Agent Collaborative Framework
](https://arxiv.org/abs/2308.00352) accepted for oral presentation **(top 1.2%)** at ICLR 2024, **ranking #1** in the LLM-based Agent category.
@ -97,6 +99,45 @@ ### Usage
detail installation please refer to [cli_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-stable-version)
or [docker_install](https://docs.deepwisdom.ai/main/en/guide/get_started/installation.html#install-with-docker)
### Docker installation
<details><summary><strong>⏬ Step 1: Download metagpt image and prepare config2.yaml </strong><i>:: click to expand ::</i></summary>
<div>
```bash
docker pull metagpt/metagpt:latest
mkdir -p /opt/metagpt/{config,workspace}
docker run --rm metagpt/metagpt:latest cat /app/metagpt/config/config2.yaml > /opt/metagpt/config/config2.yaml
vim /opt/metagpt/config/config2.yaml # Change the config
```
</div>
</details>
<details><summary><strong>⏬ Step 2: Run metagpt container </strong><i>:: click to expand ::</i></summary>
<div>
```bash
docker run --name metagpt -d \
--privileged \
-v /opt/metagpt/config/config2.yaml:/app/metagpt/config/config2.yaml \
-v /opt/metagpt/workspace:/app/metagpt/workspace \
metagpt/metagpt:latest
```
</div>
</details>
<details><summary><strong>⏬ Step 3: Use metagpt </strong><i>:: click to expand ::</i></summary>
<div>
```bash
docker exec -it metagpt /bin/bash
$ metagpt "Create a 2048 game" # this will create a repo in ./workspace
```
</div>
</details>
### QuickStart & Demo Video
- Try it on [MetaGPT Huggingface Space](https://huggingface.co/spaces/deepwisdom/MetaGPT)
- [Matthew Berman: How To Install MetaGPT - Build A Startup With One Prompt!!](https://youtu.be/uT75J_KG_aY)

14
SECURITY.md Normal file
View file

@ -0,0 +1,14 @@
# Security Policy
## Supported Versions
| Version | Supported |
|---------|--------------------|
| 7.x | :x: |
| 6.x | :x: |
| < 6.x | :x: |
## Reporting a Vulnerability
If you have any vulnerability reports, please contact alexanderwu@deepwisdom.ai .

View file

@ -3,8 +3,16 @@ llm:
base_url: "YOUR_BASE_URL"
api_key: "YOUR_API_KEY"
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
repair_llm_output: true # when the output is not a valid json, try to repair it
proxy: "YOUR_PROXY" # for LLM API requests
pricing_plan: "" # Optional. If invalid, it will be automatically filled in with the value of the `model`.
# Azure-exclusive pricing plan mappings
# - gpt-3.5-turbo 4k: "gpt-3.5-turbo-1106"
# - gpt-4-turbo: "gpt-4-turbo-preview"
# - gpt-4-turbo-vision: "gpt-4-vision-preview"
# - gpt-4 8k: "gpt-4"
# See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
repair_llm_output: true # when the output is not a valid json, try to repair it
proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc.

View file

@ -5,6 +5,7 @@ Author: garylin2099
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `send_to`
value of the `Message` object; modify the argument type of `get_by_actions`.
"""
import asyncio
import platform
from typing import Any
@ -105,4 +106,4 @@ def main(idea: str, investment: float = 3.0, n_round: int = 10):
if __name__ == "__main__":
fire.Fire(main)
fire.Fire(main) # run as python debate.py --idea="TOPIC" --investment=3.0 --n_round=5

View file

@ -8,14 +8,17 @@
import asyncio
from metagpt.actions import Action
from metagpt.config2 import Config
from metagpt.environment import Environment
from metagpt.roles import Role
from metagpt.team import Team
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action1.llm.model = "gpt-4-1106-preview"
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
action2.llm.model = "gpt-3.5-turbo-1106"
gpt35 = Config.default()
gpt35.llm.model = "gpt-3.5-turbo-1106"
gpt4 = Config.default()
gpt4.llm.model = "gpt-4-1106-preview"
action1 = Action(config=gpt4, name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action2 = Action(config=gpt35, name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
env = Environment(desc="US election live broadcast")

18
examples/di/README.md Normal file
View file

@ -0,0 +1,18 @@
# Data Interpreter (DI)
## What is Data Interpreter
Data Interpreter is an agent who solves problems through codes. It understands user requirements, makes plans, writes codes for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios, please check out the examples below.
## Example List
- Data visualization
- Machine learning modeling
- Image background removal
- Solve math problems
- Receipt OCR
- Tool usage: web page imitation
- Tool usage: web crawling
- Tool usage: text2image
- Tool usage: email summarization and response
- More on the way!
Please see [here](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html) for detailed explanation.

View file

@ -5,15 +5,15 @@
@File : crawl_webpage.py
"""
from metagpt.roles.mi.interpreter import Interpreter
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main():
prompt = """Get data from `paperlist` table in https://papercopilot.com/statistics/iclr-statistics/iclr-2024-statistics/,
and save it to a csv file. paper title must include `multiagent` or `large language model`. *notice: print key variables*"""
mi = Interpreter(use_tools=True)
di = DataInterpreter(use_tools=True)
await mi.run(prompt)
await di.run(prompt)
if __name__ == "__main__":

View file

@ -1,11 +1,11 @@
import asyncio
from metagpt.roles.mi.interpreter import Interpreter
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(requirement: str = ""):
mi = Interpreter(use_tools=False)
await mi.run(requirement)
di = DataInterpreter(use_tools=False)
await di.run(requirement)
if __name__ == "__main__":

View file

@ -6,7 +6,7 @@
"""
import os
from metagpt.roles.mi.interpreter import Interpreter
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main():
@ -22,9 +22,9 @@ async def main():
Firstly, Please help me fetch the latest 5 senders and full letter contents.
Then, summarize each of the 5 emails into one sentence (you can do this by yourself, no need to import other models to do this) and output them in a markdown format."""
mi = Interpreter(use_tools=True)
di = DataInterpreter(use_tools=True)
await mi.run(prompt)
await di.run(prompt)
if __name__ == "__main__":

View file

@ -5,7 +5,7 @@
@Author : mannaandpoem
@File : imitate_webpage.py
"""
from metagpt.roles.mi.interpreter import Interpreter
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main():
@ -15,9 +15,9 @@ Firstly, utilize Selenium and WebDriver for rendering.
Secondly, convert image to a webpage including HTML, CSS and JS in one go.
Finally, save webpage in a text file.
Note: All required dependencies and environments have been fully installed and configured."""
mi = Interpreter(use_tools=True)
di = DataInterpreter(use_tools=True)
await mi.run(prompt)
await di.run(prompt)
if __name__ == "__main__":

View file

@ -0,0 +1,13 @@
import fire
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(auto_run: bool = True):
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
di = DataInterpreter(auto_run=auto_run)
await di.run(requirement)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -1,6 +1,6 @@
import asyncio
from metagpt.roles.mi.ml_engineer import MLEngineer
from metagpt.roles.di.ml_engineer import MLEngineer
async def main(requirement: str):

View file

@ -0,0 +1,21 @@
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main():
# Notice: pip install metagpt[ocr] before using this example
image_path = "image.jpg"
language = "English"
requirement = f"""This is a {language} receipt image.
Your goal is to perform OCR on images using PaddleOCR, output text content from the OCR results and discard
coordinates and confidence levels, then recognize the total amount from ocr text content, and finally save as table.
Image path: {image_path}.
NOTE: The environments for Paddle and PaddleOCR are all ready and has been fully installed."""
di = DataInterpreter()
await di.run(requirement)
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -1,11 +1,11 @@
import asyncio
from metagpt.roles.mi.interpreter import Interpreter
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(requirement: str = ""):
mi = Interpreter(use_tools=False)
await mi.run(requirement)
di = DataInterpreter(use_tools=False)
await di.run(requirement)
if __name__ == "__main__":

View file

@ -4,12 +4,12 @@
# @Desc :
import asyncio
from metagpt.roles.mi.interpreter import Interpreter
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(requirement: str = ""):
mi = Interpreter(use_tools=True, goal=requirement)
await mi.run(requirement)
di = DataInterpreter(use_tools=True, goal=requirement)
await di.run(requirement)
if __name__ == "__main__":

View file

@ -1,11 +1,11 @@
import asyncio
from metagpt.roles.mi.interpreter import Interpreter
from metagpt.roles.di.data_interpreter import DataInterpreter
async def main(requirement: str = ""):
mi = Interpreter(use_tools=False)
await mi.run(requirement)
di = DataInterpreter(use_tools=False)
await di.run(requirement)
if __name__ == "__main__":

View file

@ -13,7 +13,18 @@ from metagpt.logs import logger
async def main():
llm = LLM()
logger.info(await llm.aask("hello world"))
# llm type check
question = "what's your name"
logger.info(f"{question}: ")
logger.info(await llm.aask(question))
logger.info("\n\n")
logger.info(
await llm.aask(
"who are you", system_msgs=["act as a robot, just answer 'I'am robot' if the question is 'who are you'"]
)
)
logger.info(await llm.aask_batch(["hi", "write python hello world."]))
hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]

View file

@ -1,18 +0,0 @@
# MetaGPT Interpreter (MI)
## What is Interpreter
Interpreter is an agent who solves problems through codes. It understands user requirements, makes plans, writes codes for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios, please check out the examples below.
## Example List
- Data visualization
- Machine learning modeling
- Image background removal
- Solve math problems
- Receipt OCR
- Tool usage: web page imitation
- Tool usage: web crawling
- Tool usage: text2image
- Tool usage: email summarization and response
- More on the way!
Please see [here](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/mi_intro.html) for detailed explanation.

View file

@ -1,23 +0,0 @@
import fire
from metagpt.roles.mi.interpreter import Interpreter
WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
DATA_DIR = "path/to/your/data"
# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data
SALES_FORECAST_REQ = f"""Train a model to predict sales for each department in every store (split the last 40 weeks records as validation dataset, the others is train dataset), include plot total sales trends, print metric and plot scatter plots of
groud truth and predictions on validation data. Dataset is {DATA_DIR}/train.csv, the metric is weighted mean absolute error (WMAE) for test data. Notice: *print* key variables to get more information for next task step.
"""
REQUIREMENTS = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ}
async def main(auto_run: bool = True, use_case: str = "wine"):
mi = Interpreter(auto_run=auto_run)
requirement = REQUIREMENTS[use_case]
await mi.run(requirement)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -1,19 +0,0 @@
from metagpt.roles.mi.interpreter import Interpreter
async def main():
# Notice: pip install metagpt[ocr] before using this example
image_path = "image.jpg"
language = "English"
requirement = f"""This is a {language} receipt image.
Your goal is to perform OCR on images using PaddleOCR, then extract the total amount from ocr text results, and finally save as table. Image path: {image_path}.
NOTE: The environments for Paddle and PaddleOCR are all ready and has been fully installed."""
mi = Interpreter()
await mi.run(requirement)
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import shutil
from pathlib import Path
import typer
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.context import Context
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@app.command("", help="Python project reverse engineering.")
def startup(
project_root: str = typer.Argument(
default="",
help="Specify the root directory of the existing project for reverse engineering.",
),
output_dir: str = typer.Option(default="", help="Specify the output directory path for reverse engineering."),
):
package_root = Path(project_root)
if not package_root.exists():
raise FileNotFoundError(f"{project_root} not exists")
if not _is_python_package_root(package_root):
raise FileNotFoundError(f'There are no "*.py" files under "{project_root}".')
init_file = package_root / "__init__.py" # used by pyreverse
init_file_exists = init_file.exists()
if not init_file_exists:
init_file.touch()
if not output_dir:
output_dir = package_root / "../reverse_engineering_output"
logger.info(f"output dir:{output_dir}")
try:
asyncio.run(reverse_engineering(package_root, Path(output_dir)))
finally:
if not init_file_exists:
init_file.unlink(missing_ok=True)
tmp_dir = package_root / "__dot__"
if tmp_dir.exists():
shutil.rmtree(tmp_dir, ignore_errors=True)
def _is_python_package_root(package_root: Path) -> bool:
for file_path in package_root.iterdir():
if file_path.is_file():
if file_path.suffix == ".py":
return True
return False
async def reverse_engineering(package_root: Path, output_dir: Path):
ctx = Context()
ctx.git_repo = GitRepository(output_dir)
ctx.repo = ProjectRepo(ctx.git_repo)
action = RebuildClassView(name="ReverseEngineering", i_context=str(package_root), llm=LLM(), context=ctx)
await action.run()
action = RebuildSequenceView(name="ReverseEngineering", llm=LLM(), context=ctx)
await action.run()
if __name__ == "__main__":
app()

View file

@ -4,21 +4,17 @@
"""
import asyncio
from metagpt.config2 import Config
from metagpt.roles import Searcher
from metagpt.tools.search_engine import SearchEngine, SearchEngineType
from metagpt.tools.search_engine import SearchEngine
async def main():
question = "What are the most interesting human facts?"
kwargs = {"api_key": "", "cse_id": "", "proxy": None}
# Serper API
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPER_GOOGLE, **kwargs)).run(question)
# SerpAPI
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPAPI_GOOGLE, **kwargs)).run(question)
# Google API
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DIRECT_GOOGLE, **kwargs)).run(question)
# DDG API
await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO, **kwargs)).run(question)
search = Config.default().search
kwargs = {"api_key": search.api_key, "cse_id": search.cse_id, "proxy": None}
await Searcher(search_engine=SearchEngine(engine=search.api_type, **kwargs)).run(question)
if __name__ == "__main__":

View file

@ -14,6 +14,22 @@ from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
class Chapter(BaseModel):
name: str = Field(default="Chapter 1", description="The name of the chapter.")
content: str = Field(default="...", description="The content of the chapter. No more than 1000 words.")
class Chapters(BaseModel):
chapters: List[Chapter] = Field(
default=[
{"name": "Chapter 1", "content": "..."},
{"name": "Chapter 2", "content": "..."},
{"name": "Chapter 3", "content": "..."},
],
description="The chapters of the novel.",
)
class Novel(BaseModel):
name: str = Field(default="The Lord of the Rings", description="The name of the novel.")
user_group: str = Field(default="...", description="The user group of the novel.")
@ -28,22 +44,17 @@ class Novel(BaseModel):
ending: str = Field(default="...", description="The ending of the novel.")
class Chapter(BaseModel):
name: str = Field(default="Chapter 1", description="The name of the chapter.")
content: str = Field(default="...", description="The content of the chapter. No more than 1000 words.")
async def generate_novel():
instruction = (
"Write a novel named 'Harry Potter in The Lord of the Rings'. "
"Write a novel named 'Reborn in Skyrim'. "
"Fill the empty nodes with your own ideas. Be creative! Use your own words!"
"I will tip you $100,000 if you write a good novel."
)
novel_node = await ActionNode.from_pydantic(Novel).fill(context=instruction, llm=LLM())
chap_node = await ActionNode.from_pydantic(Chapter).fill(
chap_node = await ActionNode.from_pydantic(Chapters).fill(
context=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM()
)
print(chap_node.content)
print(chap_node.instruct_content)
asyncio.run(generate_novel())

View file

@ -22,9 +22,9 @@ from metagpt.actions.write_code_review import WriteCodeReview
from metagpt.actions.write_prd import WritePRD
from metagpt.actions.write_prd_review import WritePRDReview
from metagpt.actions.write_test import WriteTest
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
from metagpt.actions.mi.write_analysis_code import WriteCodeWithTools
from metagpt.actions.mi.write_plan import WritePlan
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
from metagpt.actions.di.write_analysis_code import WriteCodeWithTools
from metagpt.actions.di.write_plan import WritePlan
class ActionType(Enum):

View file

@ -8,7 +8,6 @@
from typing import List
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
from metagpt.utils.mermaid import MMC1, MMC2
IMPLEMENTATION_APPROACH = ActionNode(
@ -109,14 +108,3 @@ REFINED_NODES = [
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
REFINED_DESIGN_NODE = ActionNode.from_children("RefinedDesignAPI", REFINED_NODES)
def main():
prompt = DESIGN_API_NODE.compile(context="")
logger.info(prompt)
prompt = REFINED_DESIGN_NODE.compile(context="")
logger.info(prompt)
if __name__ == "__main__":
main()

View file

@ -9,7 +9,7 @@ from __future__ import annotations
import json
from metagpt.actions import Action
from metagpt.prompts.mi.write_analysis_code import (
from metagpt.prompts.di.write_analysis_code import (
CHECK_DATA_PROMPT,
DEBUG_REFLECTION_EXAMPLE,
INTERPRETER_SYSTEM_MSG,

View file

@ -8,7 +8,6 @@
from typing import List
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
REQUIRED_PYTHON_PACKAGES = ActionNode(
key="Required Python packages",
@ -119,14 +118,3 @@ REFINED_NODES = [
PM_NODE = ActionNode.from_children("PM_NODE", NODES)
REFINED_PM_NODE = ActionNode.from_children("REFINED_PM_NODE", REFINED_NODES)
def main():
prompt = PM_NODE.compile(context="")
logger.info(prompt)
prompt = REFINED_PM_NODE.compile(context="")
logger.info(prompt)
if __name__ == "__main__":
main()

View file

@ -4,10 +4,12 @@
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view.py
@Desc : Rebuild class view info
@Desc : Reconstructs class diagram from a source code project.
Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
"""
import re
from pathlib import Path
from typing import Optional, Set, Tuple
import aiofiles
@ -21,86 +23,144 @@ from metagpt.const import (
GRAPH_REPO_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
from metagpt.utils.common import split_namespace
from metagpt.repo_parser import DotClassInfo, RepoParser
from metagpt.schema import UMLClassView
from metagpt.utils.common import concat_namespace, split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
"""
Reconstructs a graph repository about class diagram from a source code project.
Attributes:
graph_db (Optional[GraphRepository]): The optional graph repository.
"""
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
"""
Implementation of `Action`'s `run` method.
Args:
with_messages (Optional[Type]): An optional argument specifying messages to react to.
format (str): The format for the prompt schema.
"""
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.i_context))
# use pylint
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
await GraphRepository.rebuild_composition_relationship(self.graph_db)
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
for file_info in symbols:
# Align to the same root directory in accordance with `class_views`.
file_info.file = self._align_root(file_info.file, direction, diff_path)
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_views(graph_db=graph_db)
await graph_db.save()
await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info)
await self._create_mermaid_class_views()
await self.graph_db.save()
async def _create_mermaid_class_views(self, graph_db):
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
async def _create_mermaid_class_views(self) -> str:
"""Creates a Mermaid class diagram using data from the `graph_db` graph repository.
This method utilizes information stored in the graph repository to generate a Mermaid class diagram.
Returns:
mermaid class diagram file name.
"""
path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
filename = str(pathname.with_suffix(".class_diagram.mmd"))
async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)
await writer.write(content)
# class names
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
class_distinct = set()
relationship_distinct = set()
for r in rows:
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
content = await self._create_mermaid_class(r.subject)
if content:
await writer.write(content)
class_distinct.add(r.subject)
for r in rows:
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
content, distinct = await self._create_mermaid_relationship(r.subject)
if content:
logger.debug(content)
await writer.write(content)
relationship_distinct.update(distinct)
logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")
@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
if self.i_context:
r_filename = Path(filename).relative_to(self.context.git_repo.workdir)
await self.graph_db.insert(
subject=self.i_context, predicate="hasMermaidClassDiagramFile", object_=str(r_filename)
)
logger.info(f"{self.i_context} hasMermaidClassDiagramFile {filename}")
return filename
async def _create_mermaid_class(self, ns_class_name) -> str:
"""Generates a Mermaid class diagram for a specific class using data from the `graph_db` graph repository.
Args:
ns_class_name (str): The namespace-prefixed name of the class for which the Mermaid class diagram is to be created.
Returns:
str: A Mermaid code block object in markdown representing the class diagram.
"""
fields = split_namespace(ns_class_name)
if len(fields) > 2:
# Ignore sub-class
return
return ""
class_view = ClassView(name=fields[1])
rows = await graph_db.select(subject=ns_class_name)
for r in rows:
name = split_namespace(r.object_)[-1]
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
attribute = ClassAttribute(
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
)
class_view.attributes.append(attribute)
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
class_view.methods.append(method)
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
if not rows:
return ""
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
class_view = UMLClassView.load_dot_class_info(dot_class_info)
# update graph db
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
# update uml view
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
# update uml isCompositeOf
for c in dot_class_info.compositions:
await self.graph_db.insert(
subject=ns_class_name,
predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF,
object_=concat_namespace("?", c),
)
# update uml isAggregateOf
for a in dot_class_info.aggregations:
await self.graph_db.insert(
subject=ns_class_name,
predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF,
object_=concat_namespace("?", a),
)
content = class_view.get_mermaid(align=1)
logger.debug(content)
await file_writer.write(content)
distinct.add(ns_class_name)
return content
@staticmethod
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
async def _create_mermaid_relationship(self, ns_class_name: str) -> Tuple[Optional[str], Optional[Set]]:
"""Generates a Mermaid class relationship diagram for a specific class using data from the `graph_db` graph repository.
Args:
ns_class_name (str): The namespace-prefixed class name for which the Mermaid relationship diagram is to be created.
Returns:
Tuple[str, Set]: A tuple containing the relationship diagram as a string and a set of deduplication.
"""
s_fields = split_namespace(ns_class_name)
if len(s_fields) > 2:
# Ignore sub-class
return
return None, None
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
mappings = {
@ -109,8 +169,9 @@ class RebuildClassView(Action):
AGGREGATION: " o-- ",
}
content = ""
distinct = set()
for p, v in predicates.items():
rows = await graph_db.select(subject=ns_class_name, predicate=p)
rows = await self.graph_db.select(subject=ns_class_name, predicate=p)
for r in rows:
o_fields = split_namespace(r.object_)
if len(o_fields) > 2:
@ -121,86 +182,26 @@ class RebuildClassView(Action):
distinct.add(link)
content += f"\t{link}\n"
if content:
logger.debug(content)
await file_writer.write(content)
@staticmethod
def _parse_name(name: str, language="python"):
pattern = re.compile(r"<I>(.*?)<\/I>")
result = re.search(pattern, name)
abstraction = ""
if result:
name = result.group(1)
abstraction = "*"
if name.startswith("__"):
visibility = "-"
elif name.startswith("_"):
visibility = "#"
else:
visibility = "+"
return name, visibility, abstraction
@staticmethod
async def _parse_variable_type(ns_name, graph_db) -> str:
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
if not rows:
return ""
vals = rows[0].object_.replace("'", "").split(":")
if len(vals) == 1:
return ""
val = vals[-1].strip()
return "" if val == "NoneType" else val + " "
@staticmethod
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
if not rows:
return
info = rows[0].object_.replace("'", "")
fs_tag = "("
ix = info.find(fs_tag)
fe_tag = "):"
eix = info.rfind(fe_tag)
if eix < 0:
fe_tag = ")"
eix = info.rfind(fe_tag)
args_info = info[ix + len(fs_tag) : eix].strip()
method.return_type = info[eix + len(fe_tag) :].strip()
if method.return_type == "None":
method.return_type = ""
if "(" in method.return_type:
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
# parse args
if not args_info:
return
splitter_ixs = []
cost = 0
for i in range(len(args_info)):
if args_info[i] == "[":
cost += 1
elif args_info[i] == "]":
cost -= 1
if args_info[i] == "," and cost == 0:
splitter_ixs.append(i)
splitter_ixs.append(len(args_info))
args = []
ix = 0
for eix in splitter_ixs:
args.append(args_info[ix:eix])
ix = eix + 1
for arg in args:
parts = arg.strip().split(":")
if len(parts) == 1:
method.args.append(ClassAttribute(name=parts[0].strip()))
continue
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
return content, distinct
@staticmethod
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
"""Returns the difference between the root path and the path information represented in the package name.
Args:
path_root (Path): The root path.
package_root (Path): The package root path.
Returns:
Tuple[str, str]: A tuple containing the representation of the difference ("+", "-", "=") and the path detail of the differing part.
Example:
>>> _diff_path(path_root=Path("/Users/x/github/MetaGPT"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
"-", "metagpt"
>>> _diff_path(path_root=Path("/Users/x/github/MetaGPT/metagpt"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
"=", "."
"""
if len(str(path_root)) > len(str(package_root)):
return "+", str(path_root.relative_to(package_root))
if len(str(path_root)) < len(str(package_root)):
@ -208,7 +209,24 @@ class RebuildClassView(Action):
return "=", "."
@staticmethod
def _align_root(path: str, direction: str, diff_path: str):
def _align_root(path: str, direction: str, diff_path: str) -> str:
"""Aligns the path to the same root represented by `diff_path`.
Args:
path (str): The path to be aligned.
direction (str): The direction of alignment ('+', '-', '=').
diff_path (str): The path representing the difference.
Returns:
str: The aligned path.
Example:
>>> _align_root(path="metagpt/software_company.py", direction="+", diff_path="MetaGPT")
"MetaGPT/metagpt/software_company.py"
>>> _align_root(path="metagpt/software_company.py", direction="-", diff_path="metagpt")
"software_company.py"
"""
if direction == "=":
return path
if direction == "+":

View file

@ -4,34 +4,214 @@
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view.py
@Desc : Rebuild sequence view info
@Desc : Reconstruct sequence view information through reverse engineering.
Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
"""
from __future__ import annotations
import re
from datetime import datetime
from pathlib import Path
from typing import List
from typing import List, Optional, Set
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
from metagpt.schema import UMLClassView
from metagpt.utils.common import (
add_affix,
aread,
auto_namespace,
concat_namespace,
general_after_log,
list_files,
parse_json_code_block,
read_file_block,
split_namespace,
)
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository
class ReverseUseCase(BaseModel):
"""
Represents a reverse engineered use case.
Attributes:
description (str): A description of the reverse use case.
inputs (List[str]): List of inputs for the reverse use case.
outputs (List[str]): List of outputs for the reverse use case.
actors (List[str]): List of actors involved in the reverse use case.
steps (List[str]): List of steps for the reverse use case.
reason (str): The reason behind the reverse use case.
"""
description: str
inputs: List[str]
outputs: List[str]
actors: List[str]
steps: List[str]
reason: str
class ReverseUseCaseDetails(BaseModel):
"""
Represents details of a reverse engineered use case.
Attributes:
description (str): A description of the reverse use case details.
use_cases (List[ReverseUseCase]): List of reverse use cases.
relationship (List[str]): List of relationships associated with the reverse use case details.
"""
description: str
use_cases: List[ReverseUseCase]
relationship: List[str]
class RebuildSequenceView(Action):
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await RebuildSequenceView._search_main_entry(graph_db)
for entry in entries:
await self._rebuild_sequence_view(entry, graph_db)
await graph_db.save()
"""
Represents an action to reconstruct sequence view through reverse engineering.
@staticmethod
async def _search_main_entry(graph_db) -> List:
rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
Attributes:
graph_db (Optional[GraphRepository]): An optional instance of GraphRepository for graph database operations.
"""
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
"""
Implementation of `Action`'s `run` method.
Args:
with_messages (Optional[Type]): An optional argument specifying messages to react to.
format (str): The format for the prompt schema.
"""
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
if not self.i_context:
entries = await self._search_main_entry()
else:
entries = [SPO(subject=self.i_context, predicate="", object_="")]
for entry in entries:
await self._rebuild_main_sequence_view(entry)
while await self._merge_sequence_view(entry):
pass
await self.graph_db.save()
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_main_sequence_view(self, entry: SPO):
"""
Reconstruct the sequence diagram for the __main__ entry of the source code through reverse engineering.
Args:
entry (SPO): The SPO (Subject, Predicate, Object) object in the graph database that is related to the
subject `__name__:__main__`.
"""
filename = entry.subject.split(":", 1)[0]
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
classes = []
prefix = filename + ":"
for r in rows:
if prefix in r.subject:
classes.append(r)
await self._rebuild_use_case(r.subject)
participants = await self._search_participants(split_namespace(entry.subject)[0])
class_details = []
class_views = []
for c in classes:
detail = await self._get_class_detail(c.subject)
if not detail:
continue
class_details.append(detail)
view = await self._get_uml_class_view(c.subject)
if view:
class_views.append(view)
actors = await self._get_participants(c.subject)
participants.update(set(actors))
use_case_blocks = []
for c in classes:
use_cases = await self._get_class_use_cases(c.subject)
use_case_blocks.append(use_cases)
prompt_blocks = ["## Use Cases\n" + "\n".join(use_case_blocks)]
block = "## Participants\n"
for p in participants:
block += f"- {p}\n"
prompt_blocks.append(block)
block = "## Mermaid Class Views\n```mermaid\n"
block += "\n\n".join([c.get_mermaid() for c in class_views])
block += "\n```\n"
prompt_blocks.append(block)
block = "## Source Code\n```python\n"
block += await self._get_source_code(filename)
block += "\n```\n"
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
"You are a python code to Mermaid Sequence Diagram translator in function detail.",
"Translate the given markdown text to a Mermaid Sequence Diagram.",
"Return the merged Mermaid sequence diagram in a markdown code block format.",
],
stream=False,
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.insert(
subject=entry.subject,
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
)
for c in classes:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
)
await self._save_sequence_view(subject=entry.subject, content=sequence_view)
async def _merge_sequence_view(self, entry: SPO) -> bool:
"""
Augments additional information to the provided SPO (Subject, Predicate, Object) entry in the sequence diagram.
Args:
entry (SPO): The SPO object representing the relationship in the graph database.
Returns:
bool: True if additional information has been augmented, otherwise False.
"""
new_participant = await self._search_new_participant(entry)
if not new_participant:
return False
await self._merge_participant(entry, new_participant)
return True
async def _search_main_entry(self) -> List:
"""
Asynchronously searches for the SPO object that is related to `__name__:__main__`.
Returns:
List: A list containing information about the main entry in the sequence diagram.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
tag = "__name__:__main__"
entries = []
for r in rows:
@ -39,24 +219,395 @@ class RebuildSequenceView(Action):
entries.append(r)
return entries
async def _rebuild_sequence_view(self, entry, graph_db):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_use_case(self, ns_class_name: str):
"""
Asynchronously reconstructs the use case for the provided namespace-prefixed class name.
Args:
ns_class_name (str): The namespace-prefixed class name for which the use case is to be reconstructed.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
if rows:
return
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
data = await self.llm.aask(
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
detail = await self._get_class_detail(ns_class_name)
if not detail:
return
participants = set()
participants.update(set(detail.compositions))
participants.update(set(detail.aggregations))
class_view = await self._get_uml_class_view(ns_class_name)
source_code = await self._get_source_code(ns_class_name)
# prompt_blocks = [
# "## Instruction\n"
# "You are a python code to UML 2.0 Use Case translator.\n"
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
# 'conflict with the information in "Mermaid Class Views".\n'
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
# "system interactions with the internal system.\n"
# ]
prompt_blocks = []
block = "## Participants\n"
for p in participants:
block += f"- {p}\n"
prompt_blocks.append(block)
block = "## Mermaid Class Views\n```mermaid\n"
block += class_view.get_mermaid()
block += "\n```\n"
prompt_blocks.append(block)
block = "## Source Code\n```python\n"
block += source_code
block += "\n```\n"
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
"You are a python code to UML 2.0 Use Case translator.",
'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".',
"The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
'conflict with the information in "Mermaid Class Views".',
'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
"system interactions with the internal system.",
"Return a markdown JSON object with:\n"
'- a "description" key to explain what the whole source code want to do;\n'
'- a "use_cases" key list all use cases, each use case in the list should including a `description` '
"key describes about what the use case to do, a `inputs` key lists the input names of the use case "
"from external sources, a `outputs` key lists the output names of the use case to external sources, "
"a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how "
"the use case works step by step, a `reason` key explaining under what circumstances would the "
"external system execute this use case.\n"
'- a "relationship" key lists all the descriptions of relationship among these use cases.\n',
],
stream=False,
)
code_blocks = parse_json_code_block(rsp)
for block in code_blocks:
detail = ReverseUseCaseDetails.model_validate_json(block)
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_sequence_view(self, ns_class_name: str):
"""
Asynchronously reconstructs the sequence diagram for the provided namespace-prefixed class name.
Args:
ns_class_name (str): The namespace-prefixed class name for which the sequence diagram is to be reconstructed.
"""
await self._rebuild_use_case(ns_class_name)
prompts_blocks = []
use_case_markdown = await self._get_class_use_cases(ns_class_name)
if not use_case_markdown: # external class
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
return
block = f"## Use Cases\n{use_case_markdown}"
prompts_blocks.append(block)
participants = await self._get_participants(ns_class_name)
block = "## Participants\n" + "\n".join([f"- {s}" for s in participants])
prompts_blocks.append(block)
view = await self._get_uml_class_view(ns_class_name)
block = "## Mermaid Class Views\n```mermaid\n"
block += view.get_mermaid()
block += "\n```\n"
prompts_blocks.append(block)
block = "## Source Code\n```python\n"
block += await self._get_source_code(ns_class_name)
block += "\n```\n"
prompts_blocks.append(block)
prompt = "\n---\n".join(prompts_blocks)
rsp = await self.llm.aask(
prompt,
system_msgs=[
"You are a Mermaid Sequence Diagram translator in function detail.",
"Translate the markdown text to a Mermaid Sequence Diagram.",
"Return a markdown mermaid code block.",
],
stream=False,
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
async def _get_participants(self, ns_class_name: str) -> List[str]:
"""
Asynchronously returns the participants list of the sequence diagram for the provided namespace-prefixed SPO
object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve the participants list.
Returns:
List[str]: A list of participants in the sequence diagram.
"""
participants = set()
detail = await self._get_class_detail(ns_class_name)
if not detail:
return []
participants.update(set(detail.compositions))
participants.update(set(detail.aggregations))
return list(participants)
async def _get_class_use_cases(self, ns_class_name: str) -> str:
"""
Asynchronously assembles the context about the use case information of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve use case information.
Returns:
str: A string containing the assembled context about the use case information.
"""
block = ""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
for i, r in enumerate(rows):
detail = ReverseUseCaseDetails.model_validate_json(r.object_)
block += f"\n### {i + 1}. {detail.description}"
for j, use_case in enumerate(detail.use_cases):
block += f"\n#### {i + 1}.{j + 1}. {use_case.description}\n"
block += "\n##### Inputs\n" + "\n".join([f"- {s}" for s in use_case.inputs])
block += "\n##### Outputs\n" + "\n".join([f"- {s}" for s in use_case.outputs])
block += "\n##### Actors\n" + "\n".join([f"- {s}" for s in use_case.actors])
block += "\n##### Steps\n" + "\n".join([f"- {s}" for s in use_case.steps])
block += "\n#### Use Case Relationship\n" + "\n".join([f"- {s}" for s in detail.relationship])
return block + "\n"
async def _get_class_detail(self, ns_class_name: str) -> DotClassInfo | None:
"""
Asynchronously retrieves the dot format class details of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve class details.
Returns:
Union[DotClassInfo, None]: A DotClassInfo object representing the dot format class details,
or None if the details are not available.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
if not rows:
return None
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
return dot_class_info
async def _get_uml_class_view(self, ns_class_name: str) -> UMLClassView | None:
"""
Asynchronously retrieves the UML 2.0 format class details of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve UML class details.
Returns:
Union[UMLClassView, None]: A UMLClassView object representing the UML 2.0 format class details,
or None if the details are not available.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not rows:
return None
class_view = UMLClassView.model_validate_json(rows[0].object_)
return class_view
async def _get_source_code(self, ns_class_name: str) -> str:
"""
Asynchronously retrieves the source code of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve the source code.
Returns:
str: A string containing the source code of the specified namespace-prefixed class.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
filename = split_namespace(ns_class_name=ns_class_name)[0]
if not rows:
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
return ""
return await aread(filename=src_filename, encoding="utf-8")
code_block_info = CodeBlockInfo.model_validate_json(rows[0].object_)
return await read_file_block(
filename=filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno
)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
logger.info(data)
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
"""
Convert package name to the full path of the module.
Args:
root (Union[str, Path]): The root path or string representing the package.
pathname (Union[str, Path]): The pathname or string representing the module.
Returns:
Union[Path, None]: The full path of the module, or None if the path cannot be determined.
Examples:
If `root`(workdir) is "/User/xxx/github/MetaGPT/metagpt", and the `pathname` is
"metagpt/management/skill_manager.py", then the returned value will be
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
"""
if re.match(r"^/.+", pathname):
return pathname
files = list_files(root=root)
postfix = "/" + str(pathname)
for i in files:
if str(i).endswith(postfix):
return i
return None
@staticmethod
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
"""
Parses the provided Mermaid sequence diagram and returns the list of participants.
Args:
mermaid_sequence_diagram (str): The Mermaid sequence diagram string to be parsed.
Returns:
List[str]: A list of participants extracted from the sequence diagram.
"""
pattern = r"participant ([a-zA-Z\.0-9_]+)"
matches = re.findall(pattern, mermaid_sequence_diagram)
matches = [re.sub(r"[\\/'\"]+", "", i) for i in matches]
return matches
async def _search_new_participant(self, entry: SPO) -> str | None:
"""
Asynchronously retrieves a participant whose sequence diagram has not been augmented.
Args:
entry (SPO): The SPO object representing the relationship in the graph database.
Returns:
Union[str, None]: A participant whose sequence diagram has not been augmented, or None if not found.
"""
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
if not rows:
return None
sequence_view = rows[0].object_
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
merged_participants = []
for r in rows:
name = split_namespace(r.object_)[-1]
merged_participants.append(name)
participants = self.parse_participant(sequence_view)
for p in participants:
if p in merged_participants:
continue
return p
return None
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _merge_participant(self, entry: SPO, class_name: str):
"""
Augments the sequence diagram of `class_name` to the sequence diagram of `entry`.
Args:
entry (SPO): The SPO object representing the base sequence diagram.
class_name (str): The class name whose sequence diagram is to be augmented.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
participants = []
for r in rows:
name = split_namespace(r.subject)[-1]
if name == class_name:
participants.append(r)
if len(participants) == 0: # external participants
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
)
return
participant = participants[0]
await self._rebuild_sequence_view(participant.subject)
sequence_views = await self.graph_db.select(
subject=participant.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW
)
if not sequence_views: # external class
return
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
prompt = f"```mermaid\n{sequence_views[0].object_}\n```\n---\n```mermaid\n{rows[0].object_}\n```"
rsp = await self.llm.aask(
prompt,
system_msgs=[
"You are a tool to merge sequence diagrams into one.",
"Participants with the same name are considered identical.",
"Return the merged Mermaid sequence diagram in a markdown code block format.",
],
stream=False,
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.insert(
subject=entry.subject,
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
)
await self._save_sequence_view(subject=entry.subject, content=sequence_view)
async def _save_sequence_view(self, subject: str, content: str):
pattern = re.compile(r"[^a-zA-Z0-9]")
name = re.sub(pattern, "_", subject)
filename = Path(name).with_suffix(".sequence_diagram.mmd")
await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content)
async def _search_participants(self, filename: str) -> Set:
content = await self._get_source_code(filename)
rsp = await self.llm.aask(
msg=content,
system_msgs=[
"You are a tool for listing all class names used in a source file.",
"Return a markdown JSON object with: "
'- a "class_names" key containing the list of class names used in the file; '
'- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
],
)
class _Data(BaseModel):
class_names: List[str]
reasons: List
json_blocks = parse_json_code_block(rsp)
data = _Data.model_validate_json(json_blocks[0])
return set(data.class_names)

View file

@ -1,16 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view_an.py
"""
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC2
CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
key="Program call flow",
expected_type=str,
instruction='Translate the "context" content into "format example" format.',
example=MMC2,
)

View file

@ -50,6 +50,7 @@ class ArgumentsParingAction(Action):
rsp = await self.llm.aask(
msg=prompt,
system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
stream=False,
)
logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)

View file

@ -92,7 +92,7 @@ class TalkAction(Action):
async def run(self, with_message=None, **kwargs) -> Message:
msg, format_msgs, system_msgs = self.aask_args
rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs)
rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs, stream=False)
self.rsp = Message(content=rsp, role="assistant", cause_by=self)
return self.rsp

View file

@ -23,11 +23,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE
from metagpt.const import (
BUGFIX_FILENAME,
CODE_PLAN_AND_CHANGE_FILENAME,
REQUIREMENT_FILENAME,
)
from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.schema import CodingContext, Document, RunCodeResult
from metagpt.utils.common import CodeParser
@ -98,8 +94,6 @@ class WriteCode(Action):
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
coding_context = CodingContext.loads(self.i_context.content)
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
code_plan_and_change = code_plan_and_change_doc.content if code_plan_and_change_doc else ""
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
summary_doc = None
if coding_context.design_doc and coding_context.design_doc.filename:
@ -111,7 +105,7 @@ class WriteCode(Action):
if bug_feedback:
code_context = coding_context.code_doc.content
elif code_plan_and_change:
elif self.config.inc:
code_context = await self.get_codes(
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
)
@ -122,10 +116,10 @@ class WriteCode(Action):
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
if code_plan_and_change:
if self.config.inc:
prompt = REFINED_TEMPLATE.format(
user_requirement=requirement_doc.content if requirement_doc else "",
code_plan_and_change=code_plan_and_change,
code_plan_and_change=str(coding_context.code_plan_and_change_doc),
design=coding_context.design_doc.content if coding_context.design_doc else "",
task=coding_context.task_doc.content if coding_context.task_doc else "",
code=code_context,

View file

@ -6,30 +6,44 @@
@File : write_code_plan_and_change_an.py
"""
import os
from typing import List
from pydantic import Field
from metagpt.actions.action import Action
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
from metagpt.schema import CodePlanAndChangeContext
CODE_PLAN_AND_CHANGE = ActionNode(
key="Code Plan And Change",
expected_type=str,
instruction="Developing comprehensive and step-by-step incremental development plan, and write Incremental "
"Change by making a code draft that how to implement incremental development including detailed steps based on the "
"context. Note: Track incremental changes using mark of '+' or '-' for add/modify/delete code, and conforms to the "
"output format of git diff",
example="""
1. Plan for calculator.py: Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, multiplication, and division. Additionally, implement robust error handling for the division operation to mitigate potential issues related to division by zero.
```python
DEVELOPMENT_PLAN = ActionNode(
key="Development Plan",
expected_type=List[str],
instruction="Develop a comprehensive and step-by-step incremental development plan, providing the detail "
"changes to be implemented at each step based on the order of 'Task List'",
example=[
"Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, ...",
"Update the existing codebase in main.py to incorporate new API endpoints for subtraction, ...",
],
)
INCREMENTAL_CHANGE = ActionNode(
key="Incremental Change",
expected_type=List[str],
instruction="Write Incremental Change by making a code draft that how to implement incremental development "
"including detailed steps based on the context. Note: Track incremental changes using the marks `+` and `-` to "
"indicate additions and deletions, and ensure compliance with the output format of `git diff`",
example=[
'''```diff
--- Old/calculator.py
+++ New/calculator.py
class Calculator:
self.result = number1 + number2
return self.result
- def sub(self, number1, number2) -> float:
+ def subtract(self, number1: float, number2: float) -> float:
+ '''
+ """
+ Subtracts the second number from the first and returns the result.
+
+ Args:
@ -38,13 +52,13 @@ class Calculator:
+
+ Returns:
+ float: The difference of number1 and number2.
+ '''
+ """
+ self.result = number1 - number2
+ return self.result
+
def multiply(self, number1: float, number2: float) -> float:
- pass
+ '''
+ """
+ Multiplies two numbers and returns the result.
+
+ Args:
@ -53,15 +67,15 @@ class Calculator:
+
+ Returns:
+ float: The product of number1 and number2.
+ '''
+ """
+ self.result = number1 * number2
+ return self.result
+
def divide(self, number1: float, number2: float) -> float:
- pass
+ '''
+ """
+ ValueError: If the second number is zero.
+ '''
+ """
+ if number2 == 0:
+ raise ValueError('Cannot divide by zero')
+ self.result = number1 / number2
@ -75,10 +89,11 @@ class Calculator:
+ print("Result is already zero, no need to clear.")
+
self.result = 0.0
```
```''',
"""```diff
--- Old/main.py
+++ New/main.py
2. Plan for main.py: Integrate new API endpoints for subtraction, multiplication, and division into the existing codebase of `main.py`. Then, ensure seamless integration with the overall application architecture and maintain consistency with coding standards.
```python
def add_numbers():
result = calculator.add_numbers(num1, num2)
return jsonify({'result': result}), 200
@ -106,6 +121,7 @@ def add_numbers():
if __name__ == '__main__':
app.run()
```""",
],
)
CODE_PLAN_AND_CHANGE_CONTEXT = """
@ -172,14 +188,16 @@ Role: You are a professional engineer; The main goal is to complete incremental
2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.
3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import.
4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
5. Follow Code Plan And Change: If there is any Incremental Change that is marked by the git diff format using '+' and '-' for add/modify/delete code, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the plan.
5. Follow Code Plan And Change: If there is any "Incremental Change" that is marked by the git diff format with '+' and '-' symbols, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the "Development Plan".
6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
7. Before using a external variable/module, make sure you import it first.
8. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.
9. Attention: Retain details that are not related to incremental development but are important for maintaining the consistency and clarity of the old code.
"""
WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", [CODE_PLAN_AND_CHANGE])
CODE_PLAN_AND_CHANGE = [DEVELOPMENT_PLAN, INCREMENTAL_CHANGE]
WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", CODE_PLAN_AND_CHANGE)
class WriteCodePlanAndChange(Action):
@ -192,14 +210,14 @@ class WriteCodePlanAndChange(Action):
prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename)
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
code_text = await self.get_old_codes()
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
requirement=self.i_context.requirement,
prd=prd_doc.content,
design=design_doc.content,
task=task_doc.content,
code=code_text,
code=await self.get_old_codes(),
)
logger.info("Writing code plan and change..")
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
async def get_old_codes(self) -> str:

View file

@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import WriteCode
from metagpt.actions.action import Action
from metagpt.const import CODE_PLAN_AND_CHANGE_FILENAME, REQUIREMENT_FILENAME
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser
@ -149,29 +149,21 @@ class WriteCodeReview(Action):
use_inc=self.config.inc,
)
if not self.config.inc:
context = "\n".join(
[
"## System Design\n" + str(self.i_context.design_doc) + "\n",
"## Task\n" + task_content + "\n",
"## Code Files\n" + code_context + "\n",
]
)
else:
ctx_list = [
"## System Design\n" + str(self.i_context.design_doc) + "\n",
"## Task\n" + task_content + "\n",
"## Code Files\n" + code_context + "\n",
]
if self.config.inc:
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
code_plan_and_change_doc = await self.repo.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
context = "\n".join(
[
"## User New Requirements\n" + str(requirement_doc) + "\n",
"## Code Plan And Change\n" + str(code_plan_and_change_doc) + "\n",
"## System Design\n" + str(self.i_context.design_doc) + "\n",
"## Task\n" + task_content + "\n",
"## Code Files\n" + code_context + "\n",
]
)
insert_ctx_list = [
"## User New Requirements\n" + str(requirement_doc) + "\n",
"## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n",
]
ctx_list = insert_ctx_list + ctx_list
context_prompt = PROMPT_TEMPLATE.format(
context=context,
context="\n".join(ctx_list),
code=iterative_code,
filename=self.i_context.code_doc.filename,
)

View file

@ -56,7 +56,7 @@ REFINED_PRODUCT_GOALS = ActionNode(
key="Refined Product Goals",
expected_type=List[str],
instruction="Update and expand the original product goals to reflect the evolving needs due to incremental "
"development.Ensure that the refined goals align with the current project direction and contribute to its success.",
"development. Ensure that the refined goals align with the current project direction and contribute to its success.",
example=[
"Enhance user engagement through new features",
"Optimize performance for scalability",

View file

@ -16,6 +16,7 @@ from metagpt.utils.yaml_model import YamlModel
class LLMType(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
CLAUDE = "claude" # alias name of anthropic
SPARK = "spark"
ZHIPUAI = "zhipuai"
FIREWORKS = "fireworks"
@ -24,6 +25,10 @@ class LLMType(Enum):
METAGPT = "metagpt"
AZURE = "azure"
OLLAMA = "ollama"
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
MOONSHOT = "moonshot"
MISTRAL = "mistral"
def __missing__(self, key):
return self.OPENAI
@ -36,12 +41,18 @@ class LLMConfig(YamlModel):
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
"""
api_key: str
api_key: str = "sk-"
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
pricing_plan: Optional[str] = None # Cost Settlement Plan Parameters.
# For Cloud Service Provider like Baidu/ Alibaba
access_key: Optional[str] = None
secret_key: Optional[str] = None
endpoint: Optional[str] = None # for self-deployed model on the cloud
# For Spark(Xunfei), maybe remove later
app_id: Optional[str] = None

View file

@ -84,7 +84,6 @@ MESSAGE_ROUTE_TO_NONE = "<none>"
REQUIREMENT_FILENAME = "requirement.txt"
BUGFIX_FILENAME = "bugfix.txt"
PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt"
CODE_PLAN_AND_CHANGE_FILENAME = "code_plan_and_change.json"
DOCS_FILE_REPO = "docs"
PRDS_FILE_REPO = "docs/prd"
@ -105,6 +104,7 @@ CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/sd_output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db"
CLASS_VIEW_FILE_REPO = "docs/class_view"
YAPI_URL = "http://yapi.deepwisdomai.com/"

View file

@ -12,10 +12,14 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import (
CostManager,
FireworksCostManager,
TokenCostManager,
)
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
@ -80,12 +84,21 @@ class Context(BaseModel):
# self._llm = None
# return self._llm
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
"""Return a CostManager instance"""
if llm_config.api_type == LLMType.FIREWORKS:
return FireworksCostManager()
elif llm_config.api_type == LLMType.OPEN_LLM:
return TokenCostManager()
else:
return self.cost_manager
def llm(self) -> BaseLLM:
"""Return a LLM instance, fixme: support cache"""
# if self._llm is None:
self._llm = create_llm_instance(self.config.llm)
if self._llm.cost_manager is None:
self._llm.cost_manager = self.cost_manager
self._llm.cost_manager = self._select_costmanager(self.config.llm)
return self._llm
def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
@ -93,5 +106,5 @@ class Context(BaseModel):
# if self._llm is None:
llm = create_llm_instance(llm_config)
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
llm.cost_manager = self._select_costmanager(llm_config)
return llm

View file

@ -11,15 +11,16 @@ from pathlib import Path
from typing import Optional, Union
import pandas as pd
from langchain.document_loaders import (
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import (
TextLoader,
UnstructuredPDFLoader,
UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
@ -130,9 +131,12 @@ class IndexableDocument(Document):
if isinstance(data, pd.DataFrame):
validate_cols(content_col, data)
return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col)
else:
try:
content = data_path.read_text()
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
except Exception as e:
logger.debug(f"Load {str(data_path)} error: {e}")
content = ""
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
def _get_docs_and_metadatas_by_df(self) -> (list, list):
df = self.data

View file

@ -186,7 +186,7 @@ class BrainMemory(BaseModel):
summaries = [summary, command]
msg = "\n".join(summaries)
logger.debug(f"title ask:{msg}")
response = await llm.aask(msg=msg, system_msgs=[])
response = await llm.aask(msg=msg, system_msgs=[], stream=False)
logger.debug(f"title rsp: {response}")
return response
@ -201,11 +201,15 @@ class BrainMemory(BaseModel):
@staticmethod
async def _openai_is_related(text1, text2, llm, **kwargs):
command = (
f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there "
"any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
context = f"## Paragraph 1\n{text2}\n---\n## Paragraph 2\n{text1}\n"
rsp = await llm.aask(
msg=context,
system_msgs=[
"You are a tool capable of determining whether two paragraphs are semantically related."
'Return "TRUE" if "Paragraph 1" is semantically relevant to "Paragraph 2", otherwise return "FALSE".'
],
stream=False,
)
rsp = await llm.aask(msg=command, system_msgs=[])
result = True if "TRUE" in rsp else False
p2 = text2.replace("\n", "")
p1 = text1.replace("\n", "")
@ -223,12 +227,17 @@ class BrainMemory(BaseModel):
@staticmethod
async def _openai_rewrite(sentence: str, context: str, llm):
command = (
f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly "
f"supplement or rewrite the following text in brief and clear:\n{sentence}"
prompt = f"## Context\n{context}\n---\n## Sentence\n{sentence}\n"
rsp = await llm.aask(
msg=prompt,
system_msgs=[
'You are a tool augmenting the "Sentence" with information from the "Context".',
"Do not supplement the context with information that is not present, especially regarding the subject and object.",
"Return the augmented sentence.",
],
stream=False,
)
rsp = await llm.aask(msg=command, system_msgs=[])
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
logger.info(f"REWRITE:\nCommand: {prompt}\nRESULT: {rsp}\n")
return rsp
@staticmethod
@ -293,14 +302,14 @@ class BrainMemory(BaseModel):
"""Generate text summary"""
if len(text) < max_words:
return text
system_msgs = [
"You are a tool for summarizing and abstracting text.",
f"Return the summarized text to less than {max_words} words.",
]
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.debug(f"summary ask:{msg}")
response = await self.llm.aask(msg=msg, system_msgs=[])
logger.debug(f"summary rsp: {response}")
system_msgs.append("The generated summary should be in the same language as the original text.")
response = await self.llm.aask(msg=text, system_msgs=system_msgs, stream=False)
logger.debug(f"{text}\nsummary rsp: {response}")
return response
@staticmethod

View file

@ -7,7 +7,6 @@
from pathlib import Path
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.embedding import get_embedding
from metagpt.utils.serialize import deserialize_message, serialize_message
@ -30,7 +30,7 @@ class MemoryStorage(FaissStore):
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.embedding = embedding or get_embedding()
self.store: FAISS = None # Faiss engine
@property

View file

@ -6,21 +6,20 @@
@File : __init__.py
"""
from metagpt.provider.fireworks_api import FireworksLLM
from metagpt.provider.google_gemini_api import GeminiLLM
from metagpt.provider.ollama_api import OllamaLLM
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
__all__ = [
"FireworksLLM",
"GeminiLLM",
"OpenLLM",
"OpenAILLM",
"ZhiPuAILLM",
"AzureOpenAILLM",
@ -28,4 +27,7 @@ __all__ = [
"OllamaLLM",
"HumanProvider",
"SparkLLM",
"QianFanLLM",
"DashScopeLLM",
"AnthropicLLM",
]

View file

@ -1,37 +1,71 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/21 11:15
@Author : Leo Xiao
@File : anthropic_api.py
"""
import anthropic
from anthropic import Anthropic, AsyncAnthropic
from anthropic import AsyncAnthropic
from anthropic.types import Message, Usage
from metagpt.configs.llm_config import LLMConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
class Claude2:
@register_provider([LLMType.ANTHROPIC, LLMType.CLAUDE])
class AnthropicLLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self.__init_anthropic()
def ask(self, prompt: str) -> str:
client = Anthropic(api_key=self.config.api_key)
def __init_anthropic(self):
self.model = self.config.model
self.aclient: AsyncAnthropic = AsyncAnthropic(api_key=self.config.api_key, base_url=self.config.base_url)
res = client.completions.create(
model="claude-2",
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
max_tokens_to_sample=1000,
)
return res.completion
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
"model": self.model,
"messages": messages,
"max_tokens": self.config.max_token,
"stream": stream,
}
if self.use_system_prompt:
# if the model support system prompt, extract and pass it
if messages[0]["role"] == "system":
kwargs["messages"] = messages[1:]
kwargs["system"] = messages[0]["content"] # set system prompt here
return kwargs
async def aask(self, prompt: str) -> str:
aclient = AsyncAnthropic(api_key=self.config.api_key)
def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
usage = {"prompt_tokens": usage.input_tokens, "completion_tokens": usage.output_tokens}
super()._update_costs(usage, model)
res = await aclient.completions.create(
model="claude-2",
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
max_tokens_to_sample=1000,
)
return res.completion
def get_choice_text(self, resp: Message) -> str:
return resp.content[0].text
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
self._update_costs(resp.usage, self.model)
return resp
async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = Usage(input_tokens=0, output_tokens=0)
async for event in stream:
event_type = event.type
if event_type == "message_start":
usage.input_tokens = event.message.usage.input_tokens
usage.output_tokens = event.message.usage.output_tokens
elif event_type == "content_block_delta":
content = event.delta.text
log_llm_stream(content)
collected_content.append(content)
elif event_type == "message_delta":
usage.output_tokens = event.usage.output_tokens # update final output_tokens
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content

View file

@ -6,8 +6,6 @@
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper
@ -27,6 +25,7 @@ class AzureOpenAILLM(OpenAILLM):
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
self.pricing_plan = self.config.pricing_plan
def _make_client_kwargs(self) -> dict:
kwargs = dict(

View file

@ -6,16 +6,29 @@
@File : base_llm.py
@Desc : mashenquan, 2023/8/22. + try catch
"""
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Dict, Optional, Union
from openai import AsyncOpenAI
from openai.types import CompletionUsage
from pydantic import BaseModel
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.exceptions import handle_exception
class BaseLLM(ABC):
@ -29,6 +42,7 @@ class BaseLLM(ABC):
aclient: Optional[Union[AsyncOpenAI]] = None
cost_manager: Optional[CostManager] = None
model: Optional[str] = None
pricing_plan: Optional[str] = None
@abstractmethod
def __init__(self, config: LLMConfig):
@ -67,6 +81,28 @@ class BaseLLM(ABC):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
"""update each request's token cost
Args:
model (str): model name or in some scenarios called endpoint
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
"""
calc_usage = self.config.calc_usage and local_calc_usage
model = model or self.model
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
if calc_usage and self.cost_manager:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
except Exception as e:
logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
async def aask(
self,
msg: Union[str, list[dict[str, str]]],
@ -108,6 +144,10 @@ class BaseLLM(ABC):
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3, **kwargs) -> dict:
raise NotImplementedError
@abstractmethod
async def _achat_completion(self, messages: list[dict], timeout=3):
"""_achat_completion implemented by inherited class"""
@abstractmethod
async def acompletion(self, messages: list[dict], timeout=3):
"""Asynchronous version of completion
@ -120,8 +160,22 @@ class BaseLLM(ABC):
"""
@abstractmethod
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
"""_achat_completion_stream implemented by inherited class"""
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
"""Asynchronous version of completion. Return str. Support stream-print"""
if stream:
return await self._achat_completion_stream(messages, timeout=timeout)
resp = await self._achat_completion(messages, timeout=timeout)
return self.get_choice_text(resp)
def get_choice_text(self, rsp: dict) -> str:
"""Required to provide the first text of choice"""
@ -171,6 +225,20 @@ class BaseLLM(ABC):
"""
return json.loads(self.get_choice_function(rsp)["arguments"], strict=False)
@handle_exception
def _update_costs(self, usage: CompletionUsage | Dict):
"""
Updates the costs based on the provided usage information.
"""
if self.config.calc_usage and usage and self.cost_manager:
if isinstance(usage, Dict):
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
else:
prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])

View file

@ -0,0 +1,227 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import json
from http import HTTPStatus
from typing import Any, AsyncGenerator, Dict, List, Union
import dashscope
from dashscope.aigc.generation import Generation
from dashscope.api_entities.aiohttp_request import AioHttpRequest
from dashscope.api_entities.api_request_data import ApiRequestData
from dashscope.api_entities.api_request_factory import _get_protocol_params
from dashscope.api_entities.dashscope_response import (
GenerationOutput,
GenerationResponse,
Message,
)
from dashscope.client.base_api import BaseAioApi
from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol
from dashscope.common.error import (
InputDataRequired,
InputRequired,
ModelRequired,
UnsupportedApiProtocol,
)
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM, LLMConfig
from metagpt.provider.llm_provider_registry import LLMType, register_provider
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS
def build_api_arequest(
model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs
):
(
api_protocol,
ws_stream_mode,
is_binary_input,
http_method,
stream,
async_request,
query,
headers,
request_timeout,
form,
resources,
) = _get_protocol_params(kwargs)
task_id = kwargs.pop("task_id", None)
if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
if not dashscope.base_http_api_url.endswith("/"):
http_url = dashscope.base_http_api_url + "/"
else:
http_url = dashscope.base_http_api_url
if is_service:
http_url = http_url + SERVICE_API_PATH + "/"
if task_group:
http_url += "%s/" % task_group
if task:
http_url += "%s/" % task
if function:
http_url += function
request = AioHttpRequest(
url=http_url,
api_key=api_key,
http_method=http_method,
stream=stream,
async_request=async_request,
query=query,
timeout=request_timeout,
task_id=task_id,
)
else:
raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol)
if headers is not None:
request.add_headers(headers=headers)
if input is None and form is None:
raise InputDataRequired("There is no input data and form data")
request_data = ApiRequestData(
model,
task_group=task_group,
task=task,
function=function,
input=input,
form=form,
is_binary_input=is_binary_input,
api_protocol=api_protocol,
)
request_data.add_resources(resources)
request_data.add_parameters(**kwargs)
request.data = request_data
return request
class AGeneration(Generation, BaseAioApi):
@classmethod
async def acall(
cls,
model: str,
prompt: Any = None,
history: list = None,
api_key: str = None,
messages: List[Message] = None,
plugins: Union[str, Dict[str, Any]] = None,
**kwargs,
) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
if (prompt is None or not prompt) and (messages is None or not messages):
raise InputRequired("prompt or messages is required!")
if model is None or not model:
raise ModelRequired("Model is required!")
task_group, function = "aigc", "generation" # fixed value
if plugins is not None:
headers = kwargs.pop("headers", {})
if isinstance(plugins, str):
headers["X-DashScope-Plugin"] = plugins
else:
headers["X-DashScope-Plugin"] = json.dumps(plugins)
kwargs["headers"] = headers
input, parameters = cls._build_input_parameters(model, prompt, history, messages, **kwargs)
api_key, model = BaseAioApi._validate_params(api_key, model)
request = build_api_arequest(
model=model,
input=input,
task_group=task_group,
task=Generation.task,
function=function,
api_key=api_key,
**kwargs,
)
response = await request.aio_call()
is_stream = kwargs.get("stream", False)
if is_stream:
async def aresp_iterator(response):
async for resp in response:
yield GenerationResponse.from_api_response(resp)
return aresp_iterator(response)
else:
return GenerationResponse.from_api_response(response)
@register_provider(LLMType.DASHSCOPE)
class DashScopeLLM(BaseLLM):
def __init__(self, llm_config: LLMConfig):
self.config = llm_config
self.use_system_prompt = False # only some models support system_prompt
self.__init_dashscope()
self.cost_manager = CostManager(token_costs=self.token_costs)
def __init_dashscope(self):
self.model = self.config.model
self.api_key = self.config.api_key
self.token_costs = DASHSCOPE_TOKEN_COSTS
self.aclient: AGeneration = AGeneration
# check support system_message models
support_system_models = [
"qwen-", # all support
"llama2-", # all support
"baichuan2-7b-chat-v1",
"chatglm3-6b",
]
for support_model in support_system_models:
if support_model in self.model:
self.use_system_prompt = True
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
"api_key": self.api_key,
"model": self.model,
"messages": messages,
"stream": stream,
"result_format": "message",
}
if self.config.temperature > 0:
# different model has default temperature. only set when it"s specified.
kwargs["temperature"] = self.config.temperature
if stream:
kwargs["incremental_output"] = True
return kwargs
def _check_response(self, resp: GenerationResponse):
if resp.status_code != HTTPStatus.OK:
raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}")
def get_choice_text(self, output: GenerationOutput) -> str:
return output.get("choices", [{}])[0].get("message", {}).get("content", "")
def completion(self, messages: list[dict]) -> GenerationOutput:
resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False))
self._check_response(resp)
self._update_costs(dict(resp.usage))
return resp.output
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
self._check_response(resp)
self._update_costs(dict(resp.usage))
return resp.output
async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
async for chunk in resp:
self._check_response(chunk)
content = chunk.output.choices[0]["message"]["content"]
usage = dict(chunk.usage) # each chunk has usage
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content

View file

@ -1,131 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : fireworks.ai's api
import re
from openai import APIConnectionError, AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
class FireworksCostManager(CostManager):
def model_grade_token_costs(self, model: str) -> dict[str, float]:
def _get_model_size(model: str) -> float:
size = re.findall(".*-([0-9.]+)b", model)
size = float(size[0]) if len(size) > 0 else -1
return size
if "mixtral-8x7b" in model:
token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
else:
model_size = _get_model_size(model)
if 0 < model_size <= 16:
token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
elif 16 < model_size <= 80:
token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
else:
token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
return token_costs
def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
"""
Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
token_costs = self.model_grade_token_costs(model)
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.4f}"
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
@register_provider(LLMType.FIREWORKS)
class FireworksLLM(OpenAILLM):
def __init__(self, config: LLMConfig):
super().__init__(config=config)
self.auto_max_tokens = False
self.cost_manager = FireworksCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)
collected_content = []
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
# iterate through the stream of events
async for chunk in response:
if chunk.choices:
choice = chunk.choices[0]
choice_delta = choice.delta
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
if choice_delta.content:
collected_content.append(choice_delta.content)
print(choice_delta.content, end="")
if finish_reason:
# fireworks api return usage when finish_reason is not None
usage = CompletionUsage(**chunk.usage)
full_content = "".join(collected_content)
self._update_costs(usage)
return full_content
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)

View file

@ -13,19 +13,11 @@ from google.generativeai.types.generation_types import (
GenerateContentResponse,
GenerationConfig,
)
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
class GeminiGenerativeModel(GenerativeModel):
@ -55,6 +47,7 @@ class GeminiLLM(BaseLLM):
self.__init_gemini(config)
self.config = config
self.model = "gemini-pro" # so far only one model
self.pricing_plan = self.config.pricing_plan or self.model
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: LLMConfig):
@ -72,16 +65,6 @@ class GeminiLLM(BaseLLM):
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text
@ -105,16 +88,16 @@ class GeminiLLM(BaseLLM):
self._update_costs(usage)
return resp
async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
usage = await self.aget_usage(messages, resp.text)
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages)
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
**self._const_kwargs(messages, stream=True)
)
@ -129,17 +112,3 @@ class GeminiLLM(BaseLLM):
usage = await self.aget_usage(messages, full_content)
self._update_costs(usage)
return full_content
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)

View file

@ -35,10 +35,16 @@ class HumanProvider(BaseLLM):
) -> str:
return self.ask(msg, timeout=timeout)
async def _achat_completion(self, messages: list[dict], timeout=3):
pass
async def acompletion(self, messages: list[dict], timeout=3):
"""dummy implementation of abstract method in base"""
return []
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
pass
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""dummy implementation of abstract method in base"""
return ""

View file

@ -21,11 +21,15 @@ class LLMProviderRegistry:
return self.providers[enum]
def register_provider(key):
def register_provider(keys):
"""register provider to registry"""
def decorator(cls):
LLM_REGISTRY.register(key, cls)
if isinstance(keys, list):
for key in keys:
LLM_REGISTRY.register(key, cls)
else:
LLM_REGISTRY.register(keys, cls)
return cls
return decorator

View file

@ -5,6 +5,8 @@
@File : metagpt_api.py
@Desc : MetaGPT LLM provider.
"""
from openai.types import CompletionUsage
from metagpt.configs.llm_config import LLMType
from metagpt.provider import OpenAILLM
from metagpt.provider.llm_provider_registry import register_provider
@ -12,4 +14,7 @@ from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMType.METAGPT)
class MetaGPTLLM(OpenAILLM):
pass
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
# The current billing is based on usage frequency. If there is a future billing logic based on the
# number of tokens, please refine the logic here accordingly.
return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)

View file

@ -4,22 +4,12 @@
import json
from requests import ConnectionError
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import TokenCostManager
@ -36,26 +26,17 @@ class OllamaLLM(BaseLLM):
self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self._cost_manager = TokenCostManager()
self.cost_manager = TokenCostManager()
def __init_ollama(self, config: LLMConfig):
assert config.base_url, "ollama base url is required!"
self.model = config.model
self.pricing_plan = self.model
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"ollama updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})
@ -69,7 +50,7 @@ class OllamaLLM(BaseLLM):
chunk = chunk.decode(encoding)
return json.loads(chunk)
async def _achat_completion(self, messages: list[dict]) -> dict:
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
@ -82,9 +63,9 @@ class OllamaLLM(BaseLLM):
return resp
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages)
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
stream_resp, _, _ = await self.client.arequest(
method=self.http_method,
url=self.suffix_url,
@ -110,17 +91,3 @@ class OllamaLLM(BaseLLM):
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)

View file

@ -1,47 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : self-host open llm model with openai-compatible interface
from openai.types import CompletionUsage
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
@register_provider(LLMType.OPEN_LLM)
class OpenLLM(OpenAILLM):
def __init__(self, config: LLMConfig):
super().__init__(config)
self._cost_manager = TokenCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
return kwargs
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not self.config.calc_usage:
return usage
try:
usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
except Exception as e:
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -6,10 +6,11 @@
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from __future__ import annotations
import json
import re
from typing import AsyncIterator, Optional, Union
from typing import Optional, Union
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
@ -28,8 +29,13 @@ from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.common import CodeParser, decode_image, process_message
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.common import (
CodeParser,
decode_image,
log_and_reraise,
process_message,
)
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -38,33 +44,20 @@ from metagpt.utils.token_counter import (
)
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning(
"""
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
"""
)
raise retry_state.outcome.exception()
@register_provider(LLMType.OPENAI)
@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL])
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
self.pricing_plan = self.config.pricing_plan or self.model
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
@ -86,22 +79,41 @@ class OpenAILLM(BaseLLM):
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
)
usage = None
collected_messages = []
async for chunk in response:
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
yield chunk_message
finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], "finish_reason") else None
log_llm_stream(chunk_message)
collected_messages.append(chunk_message)
if finish_reason:
if hasattr(chunk, "usage"):
# Some services have usage as an attribute of the chunk, such as Fireworks
usage = CompletionUsage(**chunk.usage)
elif hasattr(chunk.choices[0], "usage"):
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
usage = CompletionUsage(**chunk.choices[0].usage)
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
if not usage:
# Some services do not provide the usage attribute, such as OpenAI or OpenLLM
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self._get_max_tokens(messages),
"n": 1,
# "n": 1, # Some services do not provide this parameter, such as mistral
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": 0.3,
"temperature": self.config.temperature,
"model": self.model,
"timeout": max(self.config.timeout, timeout),
}
@ -128,18 +140,7 @@ class OpenAILLM(BaseLLM):
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""when streaming, print each token in place."""
if stream:
resp = self._achat_completion_stream(messages, timeout=timeout)
collected_messages = []
async for i in resp:
log_llm_stream(i)
collected_messages.append(i)
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
await self._achat_completion_stream(messages, timeout=timeout)
rsp = await self._achat_completion(messages, timeout=timeout)
return self.get_choice_text(rsp)
@ -239,23 +240,13 @@ class OpenAILLM(BaseLLM):
return usage
try:
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
usage.prompt_tokens = count_message_tokens(messages, self.pricing_plan)
usage.completion_tokens = count_string_tokens(rsp, self.pricing_plan)
except Exception as e:
logger.warning(f"usage calculation failed: {e}")
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return self.config.max_token

View file

@ -0,0 +1,131 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
import copy
import os
import qianfan
from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import (
QIANFAN_ENDPOINT_TOKEN_COSTS,
QIANFAN_MODEL_TOKEN_COSTS,
)
@register_provider(LLMType.QIANFAN)
class QianFanLLM(BaseLLM):
"""
Refs
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
"""
def __init__(self, config: LLMConfig):
self.config = config
self.use_system_prompt = False # only some ERNIE-x related models support system_prompt
self.__init_qianfan()
self.cost_manager = CostManager(token_costs=self.token_costs)
def __init_qianfan(self):
if self.config.access_key and self.config.secret_key:
# for system level auth, use access_key and secret_key, recommended by official
# set environment variable due to official recommendation
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
elif self.config.api_key and self.config.secret_key:
# for application level auth, use api_key and secret_key
# set environment variable due to official recommendation
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
else:
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
support_system_pairs = [
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
("ERNIE-Bot-8k", "ernie_bot_8k"),
("ERNIE-Bot", "completions"),
("ERNIE-Bot-turbo", "eb-instant"),
("ERNIE-Speed", "ernie_speed"),
("EB-turbo-AppBuilder", "ai_apaas"),
]
if self.config.model in [pair[0] for pair in support_system_pairs]:
# only some ERNIE models support
self.use_system_prompt = True
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
self.use_system_prompt = True
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)
# self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
self.aclient: ChatCompletion = qianfan.ChatCompletion()
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
"messages": messages,
"stream": stream,
}
if self.config.temperature > 0:
# different model has default temperature. only set when it's specified.
kwargs["temperature"] = self.config.temperature
if self.config.endpoint:
kwargs["endpoint"] = self.config.endpoint
elif self.config.model:
kwargs["model"] = self.config.model
if self.use_system_prompt:
# if the model support system prompt, extract and pass it
if messages[0]["role"] == "system":
kwargs["messages"] = messages[1:]
kwargs["system"] = messages[0]["content"] # set system prompt here
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
model_or_endpoint = self.config.model or self.config.endpoint
local_calc_usage = model_or_endpoint in self.token_costs
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
def get_choice_text(self, resp: JsonBody) -> str:
return resp.get("result", "")
def completion(self, messages: list[dict]) -> JsonBody:
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
collected_content = []
usage = {}
async for chunk in resp:
content = chunk.body.get("result", "")
usage = chunk.body.get("usage", {})
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content

View file

@ -31,12 +31,18 @@ class SparkLLM(BaseLLM):
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
pass
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
# 不支持
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
w = GetMessageFromWeb(messages, self.config)
return w.run()
async def _achat_completion(self, messages: list[dict], timeout=3):
pass
async def acompletion(self, messages: list[dict], timeout=3):
# 不支持异步
w = GetMessageFromWeb(messages, self.config)

View file

@ -5,21 +5,12 @@
from enum import Enum
from typing import Optional
from requests import ConnectionError
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from zhipuai.types.chat.chat_completion import Completion
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
from metagpt.utils.cost_manager import CostManager
@ -47,22 +38,13 @@ class ZhiPuAILLM(BaseLLM):
assert self.config.api_key
self.api_key = self.config.api_key
self.model = self.config.model # so far, it support glm-3-turbo、glm-4
self.pricing_plan = self.config.pricing_plan or self.model
self.llm = ZhiPuModelAPI(api_key=self.api_key)
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
@ -96,17 +78,3 @@ class ZhiPuAILLM(BaseLLM):
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)

View file

@ -1,6 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Build a symbols repository from source code.
This script is designed to create a symbols repository from the provided source code.
@Time : 2023/11/17 17:58
@Author : alexanderwu
@File : repo_parser.py
@ -15,15 +19,26 @@ from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
from metagpt.utils.common import any_to_str, aread
from metagpt.utils.common import any_to_str, aread, remove_white_spaces
from metagpt.utils.exceptions import handle_exception
class RepoFileInfo(BaseModel):
"""
Repository data element that represents information about a file.
Attributes:
file (str): The name or path of the file.
classes (List): A list of class names present in the file.
functions (List): A list of function names present in the file.
globals (List): A list of global variable names present in the file.
page_info (List): A list of page-related information associated with the file.
"""
file: str
classes: List = Field(default_factory=list)
functions: List = Field(default_factory=list)
@ -32,6 +47,17 @@ class RepoFileInfo(BaseModel):
class CodeBlockInfo(BaseModel):
"""
Repository data element representing information about a code block.
Attributes:
lineno (int): The starting line number of the code block.
end_lineno (int): The ending line number of the code block.
type_name (str): The type or category of the code block.
tokens (List): A list of tokens present in the code block.
properties (Dict): A dictionary containing additional properties associated with the code block.
"""
lineno: int
end_lineno: int
type_name: str
@ -39,31 +65,395 @@ class CodeBlockInfo(BaseModel):
properties: Dict = Field(default_factory=dict)
class ClassInfo(BaseModel):
class DotClassAttribute(BaseModel):
"""
Repository data element representing a class attribute in dot format.
Attributes:
name (str): The name of the class attribute.
type_ (str): The type of the class attribute.
default_ (str): The default value of the class attribute.
description (str): A description of the class attribute.
compositions (List[str]): A list of compositions associated with the class attribute.
"""
name: str = ""
type_: str = ""
default_: str = ""
description: str
compositions: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotClassAttribute":
"""
Parses dot format text and returns a DotClassAttribute object.
Args:
v (str): Dot format text to be parsed.
Returns:
DotClassAttribute: An instance of the DotClassAttribute class representing the parsed data.
"""
val = ""
meet_colon = False
meet_equals = False
for c in v:
if c == ":":
meet_colon = True
elif c == "=":
meet_equals = True
if not meet_colon:
val += ":"
meet_colon = True
val += c
if not meet_colon:
val += ":"
if not meet_equals:
val += "="
cix = val.find(":")
eix = val.rfind("=")
name = val[0:cix].strip()
type_ = val[cix + 1 : eix]
default_ = val[eix + 1 :].strip()
type_ = remove_white_spaces(type_) # remove white space
if type_ == "NoneType":
type_ = ""
if "Literal[" in type_:
pre_l, literal, post_l = cls._split_literal(type_)
composition_val = pre_l + "Literal" + post_l # replace Literal[...] with Literal
type_ = pre_l + literal + post_l
else:
type_ = re.sub(r"['\"]+", "", type_) # remove '"
composition_val = type_
if default_ == "None":
default_ = ""
compositions = cls.parse_compositions(composition_val)
return cls(name=name, type_=type_, default_=default_, description=v, compositions=compositions)
@staticmethod
def parse_compositions(types_part) -> List[str]:
"""
Parses the type definition code block of source code and returns a list of compositions.
Args:
types_part: The type definition code block to be parsed.
Returns:
List[str]: A list of compositions extracted from the type definition code block.
"""
if not types_part:
return []
modified_string = re.sub(r"[\[\],\(\)]", "|", types_part)
types = modified_string.split("|")
filters = {
"str",
"frozenset",
"set",
"int",
"float",
"complex",
"bool",
"dict",
"list",
"Union",
"Dict",
"Set",
"Tuple",
"NoneType",
"None",
"Any",
"Optional",
"Iterator",
"Literal",
"List",
}
result = set()
for t in types:
t = re.sub(r"['\"]+", "", t.strip())
if t and t not in filters:
result.add(t)
return list(result)
@staticmethod
def _split_literal(v):
"""
Parses the literal definition code block and returns three parts: pre-part, literal-part, and post-part.
Args:
v: The literal definition code block to be parsed.
Returns:
Tuple[str, str, str]: A tuple containing the pre-part, literal-part, and post-part of the code block.
"""
tag = "Literal["
bix = v.find(tag)
eix = len(v) - 1
counter = 1
for i in range(bix + len(tag), len(v) - 1):
c = v[i]
if c == "[":
counter += 1
continue
if c == "]":
counter -= 1
if counter > 0:
continue
eix = i
break
pre_l = v[0:bix]
post_l = v[eix + 1 :]
pre_l = re.sub(r"['\"]", "", pre_l) # remove '"
pos_l = re.sub(r"['\"]", "", post_l) # remove '"
return pre_l, v[bix : eix + 1], pos_l
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class DotClassInfo(BaseModel):
"""
Repository data element representing information about a class in dot format.
Attributes:
name (str): The name of the class.
package (Optional[str]): The package to which the class belongs (optional).
attributes (Dict[str, DotClassAttribute]): A dictionary of attributes associated with the class.
methods (Dict[str, DotClassMethod]): A dictionary of methods associated with the class.
compositions (List[str]): A list of compositions associated with the class.
aggregations (List[str]): A list of aggregations associated with the class.
"""
name: str
package: Optional[str] = None
attributes: Dict[str, str] = Field(default_factory=dict)
methods: Dict[str, str] = Field(default_factory=dict)
attributes: Dict[str, DotClassAttribute] = Field(default_factory=dict)
methods: Dict[str, DotClassMethod] = Field(default_factory=dict)
compositions: List[str] = Field(default_factory=list)
aggregations: List[str] = Field(default_factory=list)
@field_validator("compositions", "aggregations", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class ClassRelationship(BaseModel):
class DotClassRelationship(BaseModel):
"""
Repository data element representing a relationship between two classes in dot format.
Attributes:
src (str): The source class of the relationship.
dest (str): The destination class of the relationship.
relationship (str): The type or nature of the relationship.
label (Optional[str]): An optional label associated with the relationship.
"""
src: str = ""
dest: str = ""
relationship: str = ""
label: Optional[str] = None
class DotReturn(BaseModel):
"""
Repository data element representing a function or method return type in dot format.
Attributes:
type_ (str): The type of the return.
description (str): A description of the return type.
compositions (List[str]): A list of compositions associated with the return type.
"""
type_: str = ""
description: str
compositions: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotReturn" | None:
"""
Parses the return type part of dot format text and returns a DotReturn object.
Args:
v (str): The dot format text containing the return type part to be parsed.
Returns:
DotReturn | None: An instance of the DotReturn class representing the parsed return type,
or None if parsing fails.
"""
if not v:
return DotReturn(description=v)
type_ = remove_white_spaces(v)
compositions = DotClassAttribute.parse_compositions(type_)
return cls(type_=type_, description=v, compositions=compositions)
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class DotClassMethod(BaseModel):
name: str
args: List[DotClassAttribute] = Field(default_factory=list)
return_args: Optional[DotReturn] = None
description: str
aggregations: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotClassMethod":
"""
Parses a dot format method text and returns a DotClassMethod object.
Args:
v (str): The dot format text containing method information to be parsed.
Returns:
DotClassMethod: An instance of the DotClassMethod class representing the parsed method.
"""
bix = v.find("(")
eix = v.rfind(")")
rix = v.rfind(":")
if rix < 0 or rix < eix:
rix = eix
name_part = v[0:bix].strip()
args_part = v[bix + 1 : eix].strip()
return_args_part = v[rix + 1 :].strip()
name = cls._parse_name(name_part)
args = cls._parse_args(args_part)
return_args = DotReturn.parse(return_args_part)
aggregations = set()
for i in args:
aggregations.update(set(i.compositions))
aggregations.update(set(return_args.compositions))
return cls(name=name, args=args, description=v, return_args=return_args, aggregations=list(aggregations))
@staticmethod
def _parse_name(v: str) -> str:
"""
Parses the dot format method name part and returns the method name.
Args:
v (str): The dot format text containing the method name part to be parsed.
Returns:
str: The parsed method name.
"""
tags = [">", "</"]
if tags[0] in v:
bix = v.find(tags[0]) + len(tags[0])
eix = v.rfind(tags[1])
return v[bix:eix].strip()
return v.strip()
@staticmethod
def _parse_args(v: str) -> List[DotClassAttribute]:
"""
Parses the dot format method arguments part and returns the parsed arguments.
Args:
v (str): The dot format text containing the arguments part to be parsed.
Returns:
str: The parsed method arguments.
"""
if not v:
return []
parts = []
bix = 0
counter = 0
for i in range(0, len(v)):
c = v[i]
if c == "[":
counter += 1
continue
elif c == "]":
counter -= 1
continue
elif c == "," and counter == 0:
parts.append(v[bix:i].strip())
bix = i + 1
parts.append(v[bix:].strip())
attrs = []
for p in parts:
if p:
attr = DotClassAttribute.parse(p)
attrs.append(attr)
return attrs
class RepoParser(BaseModel):
"""
Tool to build a symbols repository from a project directory.
Attributes:
base_directory (Path): The base directory of the project.
"""
base_directory: Path = Field(default=None)
@classmethod
@handle_exception(exception_type=Exception, default_return=[])
def _parse_file(cls, file_path: Path) -> list:
"""Parse a Python file in the repository."""
"""
Parses a Python file in the repository.
Args:
file_path (Path): The path to the Python file to be parsed.
Returns:
list: A list containing the parsed symbols from the file.
"""
return ast.parse(file_path.read_text()).body
def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
"""Extract class, function, and global variable information from the AST."""
"""
Extracts class, function, and global variable information from the Abstract Syntax Tree (AST).
Args:
tree: The Abstract Syntax Tree (AST) of the Python file.
file_path: The path to the Python file.
Returns:
RepoFileInfo: A RepoFileInfo object containing the extracted information.
"""
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
@ -81,11 +471,17 @@ class RepoParser(BaseModel):
return file_info
def generate_symbols(self) -> List[RepoFileInfo]:
"""
Builds a symbol repository from '.py' and '.js' files in the project directory.
Returns:
List[RepoFileInfo]: A list of RepoFileInfo objects containing the extracted information.
"""
files_classes = []
directory = self.base_directory
matching_files = []
extensions = ["*.py", "*.js"]
extensions = ["*.py"]
for ext in extensions:
matching_files += directory.rglob(ext)
for path in matching_files:
@ -95,19 +491,38 @@ class RepoParser(BaseModel):
return files_classes
def generate_json_structure(self, output_path):
"""Generate a JSON file documenting the repository structure."""
def generate_json_structure(self, output_path: Path):
"""
Generates a JSON file documenting the repository structure.
Args:
output_path (Path): The path to the JSON file to be generated.
"""
files_classes = [i.model_dump() for i in self.generate_symbols()]
output_path.write_text(json.dumps(files_classes, indent=4))
def generate_dataframe_structure(self, output_path):
"""Generate a DataFrame documenting the repository structure and save as CSV."""
def generate_dataframe_structure(self, output_path: Path):
"""
Generates a DataFrame documenting the repository structure and saves it as a CSV file.
Args:
output_path (Path): The path to the CSV file to be generated.
"""
files_classes = [i.model_dump() for i in self.generate_symbols()]
df = pd.DataFrame(files_classes)
df.to_csv(output_path, index=False)
def generate_structure(self, output_path=None, mode="json") -> Path:
"""Generate the structure of the repository as a specified format."""
def generate_structure(self, output_path: str | Path = None, mode="json") -> Path:
"""
Generates the structure of the repository in a specified format.
Args:
output_path (str | Path): The path to the output file or directory. Default is None.
mode (str): The output format mode. Options: "json" (default), "csv", etc.
Returns:
Path: The path to the generated output file or directory.
"""
output_file = self.base_directory / f"{self.base_directory.name}-structure.{mode}"
output_path = Path(output_path) if output_path else output_file
@ -119,6 +534,16 @@ class RepoParser(BaseModel):
@staticmethod
def node_to_str(node) -> CodeBlockInfo | None:
"""
Parses and converts an Abstract Syntax Tree (AST) node to a CodeBlockInfo object.
Args:
node: The AST node to be converted.
Returns:
CodeBlockInfo | None: A CodeBlockInfo object representing the parsed AST node,
or None if the conversion fails.
"""
if isinstance(node, ast.Try):
return None
if any_to_str(node) == any_to_str(ast.Expr):
@ -159,9 +584,19 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_expr(node) -> List:
"""
Parses an expression Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing an expression.
Returns:
List: A list containing the parsed information from the expression node.
"""
funcs = {
any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
any_to_str(ast.Tuple): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
}
func = funcs.get(any_to_str(node.value))
if func:
@ -170,12 +605,30 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_name(n):
"""
Gets the 'name' value of an Abstract Syntax Tree (AST) node.
Args:
n: The AST node.
Returns:
The 'name' value of the AST node.
"""
if n.asname:
return f"{n.name} as {n.asname}"
return n.name
@staticmethod
def _parse_if(n):
"""
Parses an 'if' statement Abstract Syntax Tree (AST) node.
Args:
n: The AST node representing an 'if' statement.
Returns:
None or Parsed information from the 'if' statement node.
"""
tokens = []
try:
if isinstance(n.test, ast.BoolOp):
@ -187,10 +640,14 @@ class RepoParser(BaseModel):
v = RepoParser._parse_variable(n.test.left)
if v:
tokens.append(v)
for item in n.test.comparators:
v = RepoParser._parse_variable(item)
if v:
tokens.append(v)
if isinstance(n.test, ast.Name):
v = RepoParser._parse_variable(n.test)
tokens.append(v)
if hasattr(n.test, "comparators"):
for item in n.test.comparators:
v = RepoParser._parse_variable(item)
if v:
tokens.append(v)
return tokens
except Exception as e:
logger.warning(f"Unsupported if: {n}, err:{e}")
@ -198,6 +655,15 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_if_compare(n):
"""
Parses an 'if' condition Abstract Syntax Tree (AST) node.
Args:
n: The AST node representing an 'if' condition.
Returns:
None or Parsed information from the 'if' condition node.
"""
if hasattr(n, "left"):
return RepoParser._parse_variable(n.left)
else:
@ -205,6 +671,15 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_variable(node):
"""
Parses a variable Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing a variable.
Returns:
None or Parsed information from the variable node.
"""
try:
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
@ -213,7 +688,7 @@ class RepoParser(BaseModel):
if hasattr(x.value, "id")
else f"{x.attr}",
any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
any_to_str(ast.Tuple): lambda x: "",
any_to_str(ast.Tuple): lambda x: [d.value for d in x.dims],
}
func = funcs.get(any_to_str(node))
if not func:
@ -224,22 +699,42 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_assign(node):
"""
Parses an assignment Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing an assignment.
Returns:
None or Parsed information from the assignment node.
"""
return [RepoParser._parse_variable(t) for t in node.targets]
async def rebuild_class_views(self, path: str | Path = None):
"""
Executes `pylint` to reconstruct the dot format class view repository file.
Args:
path (str | Path): The path to the target directory or file. Default is None.
"""
if not path:
path = self.base_directory
path = Path(path)
if not path.exists():
return
init_file = path / "__init__.py"
if not init_file.exists():
raise ValueError("Failed to import module __init__ with error:No module named __init__.")
command = f"pyreverse {str(path)} -o dot"
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
output_dir = path / "__dot__"
output_dir.mkdir(parents=True, exist_ok=True)
result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir))
if result.returncode != 0:
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_view_pathname = output_dir / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
packages_pathname = output_dir / "packages.dot"
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
@ -247,7 +742,17 @@ class RepoParser(BaseModel):
packages_pathname.unlink(missing_ok=True)
return class_views, relationship_views, package_root
async def _parse_classes(self, class_view_pathname):
@staticmethod
async def _parse_classes(class_view_pathname: Path) -> List[DotClassInfo]:
"""
Parses a dot format class view repository file.
Args:
class_view_pathname (Path): The path to the dot format class view repository file.
Returns:
List[DotClassInfo]: A list of DotClassInfo objects representing the parsed classes.
"""
class_views = []
if not class_view_pathname.exists():
return class_views
@ -258,22 +763,38 @@ class RepoParser(BaseModel):
if not package_name:
continue
class_name, members, functions = re.split(r"(?<!\\)\|", info)
class_info = ClassInfo(name=class_name)
class_info = DotClassInfo(name=class_name)
class_info.package = package_name
for m in members.split("\n"):
if not m:
continue
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
class_info.attributes[member_name] = m
attr = DotClassAttribute.parse(m)
class_info.attributes[attr.name] = attr
for i in attr.compositions:
if i not in class_info.compositions:
class_info.compositions.append(i)
for f in functions.split("\n"):
if not f:
continue
function_name, _ = f.split("(", 1)
class_info.methods[function_name] = f
method = DotClassMethod.parse(f)
class_info.methods[method.name] = method
for i in method.aggregations:
if i not in class_info.compositions and i not in class_info.aggregations:
class_info.aggregations.append(i)
class_views.append(class_info)
return class_views
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
@staticmethod
async def _parse_class_relationships(class_view_pathname: Path) -> List[DotClassRelationship]:
"""
Parses a dot format class view repository file.
Args:
class_view_pathname (Path): The path to the dot format class view repository file.
Returns:
List[DotClassRelationship]: A list of DotClassRelationship objects representing the parsed class relationships.
"""
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
@ -287,7 +808,16 @@ class RepoParser(BaseModel):
return relationship_views
@staticmethod
def _split_class_line(line):
def _split_class_line(line: str) -> (str, str):
"""
Parses a dot format line about class info and returns the class name part and class members part.
Args:
line (str): The dot format line containing class information.
Returns:
Tuple[str, str]: A tuple containing the class name part and class members part.
"""
part_splitor = '" ['
if part_splitor not in line:
return None, None
@ -305,14 +835,25 @@ class RepoParser(BaseModel):
return class_name, info
@staticmethod
def _split_relationship_line(line):
def _split_relationship_line(line: str) -> DotClassRelationship:
"""
Parses a dot format line about the relationship of two classes and returns 'Generalize', 'Composite',
or 'Aggregate'.
Args:
line (str): The dot format line containing relationship information.
Returns:
DotClassRelationship: The object of relationship representing either 'Generalize', 'Composite',
or 'Aggregate' relationship.
"""
splitters = [" -> ", " [", "];"]
idxs = []
for tag in splitters:
if tag not in line:
return None
idxs.append(line.find(tag))
ret = ClassRelationship()
ret = DotClassRelationship()
ret.src = line[0 : idxs[0]].strip('"')
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
@ -330,7 +871,16 @@ class RepoParser(BaseModel):
return ret
@staticmethod
def _get_label(line):
def _get_label(line: str) -> str:
"""
Parses a dot format line and returns the label information.
Args:
line (str): The dot format line containing label information.
Returns:
str: The label information parsed from the line.
"""
tag = 'label="'
if tag not in line:
return ""
@ -340,6 +890,15 @@ class RepoParser(BaseModel):
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
"""
Creates a mapping table between source code files' paths and module names.
Args:
path (str | Path): The path to the source code files or directory.
Returns:
Dict[str, str]: A dictionary mapping source code file paths to their corresponding module names.
"""
mappings = {
str(path).replace("/", "."): str(path),
}
@ -363,8 +922,21 @@ class RepoParser(BaseModel):
@staticmethod
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
) -> (List[ClassInfo], List[ClassRelationship], str):
class_views: List[DotClassInfo], relationship_views: List[DotClassRelationship], path: str | Path
) -> (List[DotClassInfo], List[DotClassRelationship], str):
"""
Augments namespaces to the path-prefixed classes and relationships.
Args:
class_views (List[DotClassInfo]): List of DotClassInfo objects representing class views.
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects representing
relationships.
path (str | Path): The path to the source code files or directory.
Returns:
Tuple[List[DotClassInfo], List[DotClassRelationship], str]: A tuple containing the augmented class views,
relationships, and the root path of the package.
"""
if not class_views:
return [], [], ""
c = class_views[0]
@ -383,28 +955,49 @@ class RepoParser(BaseModel):
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
for i in range(len(relationship_views)):
v = relationship_views[i]
for _, v in enumerate(relationship_views):
v.src = RepoParser._repair_ns(v.src, new_mappings)
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
relationship_views[i] = v
return class_views, relationship_views, root_path
return class_views, relationship_views, str(path)[: len(root_path)]
@staticmethod
def _repair_ns(package, mappings):
def _repair_ns(package: str, mappings: Dict[str, str]) -> str:
"""
Replaces the package-prefix with the namespace-prefix.
Args:
package (str): The package to be repaired.
mappings (Dict[str, str]): A dictionary mapping source code file paths to their corresponding packages.
Returns:
str: The repaired namespace.
"""
file_ns = package
ix = 0
while file_ns != "":
if file_ns not in mappings:
ix = file_ns.rfind(".")
file_ns = file_ns[0:ix]
continue
break
if file_ns == "":
return ""
internal_ns = package[ix + 1 :]
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
return ns
@staticmethod
def _find_root(full_key, package) -> str:
def _find_root(full_key: str, package: str) -> str:
"""
Returns the package root path based on the key, which is the full path, and the package information.
Args:
full_key (str): The full key representing the full path.
package (str): The package information.
Returns:
str: The package root path.
"""
left = full_key
while left != "":
if left in package:
@ -417,5 +1010,14 @@ class RepoParser(BaseModel):
return "." + full_key[0:ix]
def is_func(node):
def is_func(node) -> bool:
"""
Returns True if the given node represents a function.
Args:
node: The Abstract Syntax Tree (AST) node.
Returns:
bool: True if the node represents a function, False otherwise.
"""
return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))

View file

@ -65,7 +65,7 @@ class Assistant(Role):
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
rsp = await self.llm.aask(prompt, ["You are an action classifier"])
rsp = await self.llm.aask(prompt, ["You are an action classifier"], stream=False)
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
return await self._plan(rsp, last_talk=last_talk)

View file

@ -5,9 +5,9 @@ from typing import Literal, Union
from pydantic import Field, model_validator
from metagpt.actions.mi.ask_review import ReviewConst
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
from metagpt.actions.mi.write_analysis_code import CheckData, WriteCodeWithTools
from metagpt.actions.di.ask_review import ReviewConst
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
from metagpt.actions.di.write_analysis_code import CheckData, WriteCodeWithTools
from metagpt.logs import logger
from metagpt.prompts.mi.write_analysis_code import DATA_INFO
from metagpt.roles import Role
@ -32,9 +32,9 @@ Output a json following the format:
"""
class Interpreter(Role):
name: str = "Ivy"
profile: str = "Interpreter"
class DataInterpreter(Role):
name: str = "David"
profile: str = "DataInterpreter"
auto_run: bool = True
use_plan: bool = True
use_reflection: bool = False

View file

@ -20,7 +20,6 @@
from __future__ import annotations
import json
import os
from collections import defaultdict
from pathlib import Path
from typing import Set
@ -32,7 +31,6 @@ from metagpt.actions.summarize_code import SummarizeCode
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
from metagpt.const import (
CODE_PLAN_AND_CHANGE_FILE_REPO,
CODE_PLAN_AND_CHANGE_FILENAME,
REQUIREMENT_FILENAME,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
@ -119,10 +117,10 @@ class Engineer(Role):
dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path}
if self.config.inc:
dependencies.add(os.path.join(CODE_PLAN_AND_CHANGE_FILE_REPO, CODE_PLAN_AND_CHANGE_FILENAME))
dependencies.add(coding_context.code_plan_and_change_doc.root_relative_path)
await self.project_repo.srcs.save(
filename=coding_context.filename,
dependencies=dependencies,
dependencies=list(dependencies),
content=coding_context.code_doc.content,
)
msg = Message(
@ -206,7 +204,6 @@ class Engineer(Role):
async def _act_code_plan_and_change(self):
"""Write code plan and change that guides subsequent WriteCode and WriteCodeReview"""
logger.info("Writing code plan and change..")
node = await self.rc.todo.run()
code_plan_and_change = node.instruct_content.model_dump_json()
dependencies = {
@ -215,11 +212,12 @@ class Engineer(Role):
self.rc.todo.i_context.design_filename,
self.rc.todo.i_context.task_filename,
}
code_plan_and_change_filepath = Path(self.rc.todo.i_context.design_filename)
await self.project_repo.docs.code_plan_and_change.save(
filename=self.rc.todo.i_context.filename, content=code_plan_and_change, dependencies=dependencies
filename=code_plan_and_change_filepath.name, content=code_plan_and_change, dependencies=dependencies
)
await self.project_repo.resources.code_plan_and_change.save(
filename=Path(self.rc.todo.i_context.filename).with_suffix(".md").name,
filename=code_plan_and_change_filepath.with_suffix(".md").name,
content=node.content,
dependencies=dependencies,
)
@ -269,15 +267,24 @@ class Engineer(Role):
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
task_doc = None
design_doc = None
code_plan_and_change_doc = None
for i in dependencies:
if str(i.parent) == TASK_FILE_REPO:
task_doc = await self.project_repo.docs.task.get(i.name)
elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
design_doc = await self.project_repo.docs.system_design.get(i.name)
elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO:
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name)
if not task_doc or not design_doc:
logger.error(f'Detected source code "{filename}" from an unknown origin.')
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
context = CodingContext(
filename=filename,
design_doc=design_doc,
task_doc=task_doc,
code_doc=old_code_doc,
code_plan_and_change_doc=code_plan_and_change_doc,
)
return context
async def _new_coding_doc(self, filename, dependency):
@ -296,6 +303,7 @@ class Engineer(Role):
for filename in changed_task_files:
design_doc = await self.project_repo.docs.system_design.get(filename)
task_doc = await self.project_repo.docs.task.get(filename)
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename)
task_list = self._parse_tasks(task_doc)
for task_filename in task_list:
old_code_doc = await self.project_repo.srcs.get(task_filename)
@ -303,9 +311,18 @@ class Engineer(Role):
old_code_doc = Document(
root_path=str(self.project_repo.src_relative_path), filename=task_filename, content=""
)
context = CodingContext(
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
)
if not code_plan_and_change_doc:
context = CodingContext(
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
)
else:
context = CodingContext(
filename=task_filename,
design_doc=design_doc,
task_doc=task_doc,
code_doc=old_code_doc,
code_plan_and_change_doc=code_plan_and_change_doc,
)
coding_doc = Document(
root_path=str(self.project_repo.src_relative_path),
filename=task_filename,
@ -342,9 +359,17 @@ class Engineer(Role):
summarizations[ctx].append(filename)
for ctx, filenames in summarizations.items():
ctx.codes_filenames = filenames
self.summarize_todos.append(SummarizeCode(i_context=ctx, context=self.context, llm=self.llm))
new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm)
for i, act in enumerate(self.summarize_todos):
if act.i_context.task_filename == new_summarize.i_context.task_filename:
self.summarize_todos[i] = new_summarize
new_summarize = None
break
if new_summarize:
self.summarize_todos.append(new_summarize)
if self.summarize_todos:
self.set_todo(self.summarize_todos[0])
self.summarize_todos.pop(0)
async def _new_code_plan_and_change_action(self):
"""Create a WriteCodePlanAndChange action for subsequent to-do actions."""

View file

@ -281,7 +281,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
i = action
self._init_action(i)
self.actions.append(i)
self.states.append(f"{len(self.actions)}. {action}")
self.states.append(f"{len(self.actions) - 1}. {action}")
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True):
"""Set strategy of the Role reacting to observed Message. Variation lies in how

View file

@ -37,7 +37,6 @@ from pydantic import (
)
from metagpt.const import (
CODE_PLAN_AND_CHANGE_FILENAME,
MESSAGE_ROUTE_CAUSE_BY,
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
@ -47,6 +46,7 @@ from metagpt.const import (
TASK_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import DotClassInfo
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import (
@ -613,6 +613,7 @@ class CodingContext(BaseContext):
design_doc: Optional[Document] = None
task_doc: Optional[Document] = None
code_doc: Optional[Document] = None
code_plan_and_change_doc: Optional[Document] = None
class TestingContext(BaseContext):
@ -667,7 +668,6 @@ class BugFixContext(BaseContext):
class CodePlanAndChangeContext(BaseModel):
filename: str = CODE_PLAN_AND_CHANGE_FILENAME
requirement: str = ""
prd_filename: str = ""
design_filename: str = ""
@ -691,54 +691,64 @@ class CodePlanAndChangeContext(BaseModel):
# mermaid class view
class ClassMeta(BaseModel):
class UMLClassMeta(BaseModel):
name: str = ""
abstraction: bool = False
static: bool = False
visibility: str = ""
@staticmethod
def name_to_visibility(name: str) -> str:
if name == "__init__":
return "+"
if name.startswith("__"):
return "-"
elif name.startswith("_"):
return "#"
return "+"
class ClassAttribute(ClassMeta):
class UMLClassAttribute(UMLClassMeta):
value_type: str = ""
default_value: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
if self.value_type:
content += self.value_type + " "
content += self.name
content += self.value_type.replace(" ", "") + " "
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
content += name
if self.default_value:
content += "="
if self.value_type not in ["str", "string", "String"]:
content += self.default_value
else:
content += '"' + self.default_value.replace('"', "") + '"'
if self.abstraction:
content += "*"
if self.static:
content += "$"
# if self.abstraction:
# content += "*"
# if self.static:
# content += "$"
return content
class ClassMethod(ClassMeta):
args: List[ClassAttribute] = Field(default_factory=list)
class UMLClassMethod(UMLClassMeta):
args: List[UMLClassAttribute] = Field(default_factory=list)
return_type: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
if self.return_type:
content += ":" + self.return_type
if self.abstraction:
content += "*"
if self.static:
content += "$"
content += " " + self.return_type.replace(" ", "")
# if self.abstraction:
# content += "*"
# if self.static:
# content += "$"
return content
class ClassView(ClassMeta):
attributes: List[ClassAttribute] = Field(default_factory=list)
methods: List[ClassMethod] = Field(default_factory=list)
class UMLClassView(UMLClassMeta):
attributes: List[UMLClassAttribute] = Field(default_factory=list)
methods: List[UMLClassMethod] = Field(default_factory=list)
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
@ -748,3 +758,21 @@ class ClassView(ClassMeta):
content += v.get_mermaid(align=align + 1) + "\n"
content += "".join(["\t" for i in range(align)]) + "}\n"
return content
@classmethod
def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
visibility = UMLClassView.name_to_visibility(dot_class_info.name)
class_view = cls(name=dot_class_info.name, visibility=visibility)
for i in dot_class_info.attributes.values():
visibility = UMLClassAttribute.name_to_visibility(i.name)
attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
class_view.attributes.append(attr)
for i in dot_class_info.methods.values():
visibility = UMLClassMethod.name_to_visibility(i.name)
method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
for j in i.args:
arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
method.args.append(arg)
method.return_type = i.return_args.type_
class_view.methods.append(method)
return class_view

View file

@ -2,14 +2,11 @@
# -*- coding: utf-8 -*-
import asyncio
import shutil
from pathlib import Path
import typer
from metagpt.config2 import config
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
from metagpt.context import Context
from metagpt.const import CONFIG_ROOT
from metagpt.utils.project_repo import ProjectRepo
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@ -30,6 +27,8 @@ def generate_repo(
recover_path=None,
) -> ProjectRepo:
"""Run the startup logic. Can be called from CLI or other Python scripts."""
from metagpt.config2 import config
from metagpt.context import Context
from metagpt.roles import (
Architect,
Engineer,
@ -122,7 +121,17 @@ def startup(
)
def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
llm:
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
base_url: "https://api.openai.com/v1" # or forward url / other llm url
api_key: "YOUR_API_KEY"
"""
def copy_config_to():
"""Initialize the configuration file for MetaGPT."""
target_path = CONFIG_ROOT / "config2.yaml"
@ -136,7 +145,7 @@ def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
print(f"Existing configuration file backed up at {backup_path}")
# 复制文件
shutil.copy(str(config_path), target_path)
target_path.write_text(DEFAULT_CONFIG, encoding="utf-8")
print(f"Configuration file initialized at {target_path}")

View file

@ -4,8 +4,8 @@ import json
from pydantic import BaseModel, Field
from metagpt.actions.mi.ask_review import AskReview, ReviewConst
from metagpt.actions.mi.write_plan import (
from metagpt.actions.di.ask_review import AskReview, ReviewConst
from metagpt.actions.di.write_plan import (
WritePlan,
precheck_update_plan_from_rsp,
update_plan_from_rsp,

View file

@ -49,8 +49,8 @@ class TOTSolver(BaseSolver):
raise NotImplementedError
class InterpreterSolver(BaseSolver):
"""InterpreterSolver: Write&Run code in the graph"""
class DataInterpreterSolver(BaseSolver):
"""DataInterpreterSolver: Write&Run code in the graph"""
async def solve(self):
raise NotImplementedError

View file

@ -23,10 +23,10 @@ import platform
import re
import sys
import traceback
import typing
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, List, Literal, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import loguru
@ -423,23 +423,109 @@ def is_send_to(message: "Message", addresses: set):
def any_to_name(val):
"""
Convert a value to its name by extracting the last part of the dotted path.
:param val: The value to convert.
:return: The name of the value.
"""
return any_to_str(val).split(".")[-1]
def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def concat_namespace(*args, delimiter: str = ":") -> str:
"""Concatenate fields to create a unique namespace prefix.
Example:
>>> concat_namespace('prefix', 'field1', 'field2', delimiter=":")
'prefix:field1:field2'
"""
return delimiter.join(str(value) for value in args)
def split_namespace(ns_class_name: str) -> List[str]:
return ns_class_name.split(":")
def split_namespace(ns_class_name: str, delimiter: str = ":", maxsplit: int = 1) -> List[str]:
"""Split a namespace-prefixed name into its namespace-prefix and name parts.
Example:
>>> split_namespace('prefix:classname')
['prefix', 'classname']
>>> split_namespace('prefix:module:class', delimiter=":", maxsplit=2)
['prefix', 'module', 'class']
"""
return ns_class_name.split(delimiter, maxsplit=maxsplit)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
def auto_namespace(name: str, delimiter: str = ":") -> str:
"""Automatically handle namespace-prefixed names.
If the input name is empty, returns a default namespace prefix and name.
If the input name is not namespace-prefixed, adds a default namespace prefix.
Otherwise, returns the input name unchanged.
Example:
>>> auto_namespace('classname')
'?:classname'
>>> auto_namespace('prefix:classname')
'prefix:classname'
>>> auto_namespace('')
'?:?'
>>> auto_namespace('?:custom')
'?:custom'
"""
if not name:
return f"?{delimiter}?"
v = split_namespace(name, delimiter=delimiter)
if len(v) < 2:
return f"?{delimiter}{name}"
return name
def add_affix(text: str, affix: Literal["brace", "url", "none"] = "brace"):
"""Add affix to encapsulate data.
Example:
>>> add_affix("data", affix="brace")
'{data}'
>>> add_affix("example.com", affix="url")
'%7Bexample.com%7D'
>>> add_affix("text", affix="none")
'text'
"""
mappings = {
"brace": lambda x: "{" + x + "}",
"url": lambda x: quote("{" + x + "}"),
}
encoder = mappings.get(affix, lambda x: x)
return encoder(text)
def remove_affix(text, affix: Literal["brace", "url", "none"] = "brace"):
"""Remove affix to extract encapsulated data.
Args:
text (str): The input text with affix to be removed.
affix (str, optional): The type of affix used. Defaults to "brace".
Supported affix types: "brace" for removing curly braces, "url" for URL decoding within curly braces.
Returns:
str: The text with affix removed.
Example:
>>> remove_affix('{data}', affix="brace")
'data'
>>> remove_affix('%7Bexample.com%7D', affix="url")
'example.com'
>>> remove_affix('text', affix="none")
'text'
"""
mappings = {"brace": lambda x: x[1:-1], "url": lambda x: unquote(x)[1:-1]}
decoder = mappings.get(affix, lambda x: x)
return decoder(text)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
@ -626,6 +712,54 @@ def list_files(root: str | Path) -> List[Path]:
return files
def parse_json_code_block(markdown_text: str) -> List[str]:
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
return [v.strip() for v in json_blocks]
def remove_white_spaces(v: str) -> str:
return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)
async def aread_bin(filename: str | Path) -> bytes:
"""Read binary file asynchronously.
Args:
filename (Union[str, Path]): The name or path of the file to be read.
Returns:
bytes: The content of the file as bytes.
Example:
>>> content = await aread_bin('example.txt')
b'This is the content of the file.'
>>> content = await aread_bin(Path('example.txt'))
b'This is the content of the file.'
"""
async with aiofiles.open(str(filename), mode="rb") as reader:
content = await reader.read()
return content
async def awrite_bin(filename: str | Path, data: bytes):
"""Write binary file asynchronously.
Args:
filename (Union[str, Path]): The name or path of the file to be written.
data (bytes): The binary data to be written to the file.
Example:
>>> await awrite_bin('output.bin', b'This is binary data.')
>>> await awrite_bin(Path('output.bin'), b'Another set of binary data.')
"""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(pathname), mode="wb") as writer:
await writer.write(data)
def is_coroutine_func(func: Callable) -> bool:
return inspect.iscoroutinefunction(func)
@ -689,3 +823,14 @@ def process_message(messages: Union[str, Message, list[dict], list[Message], lis
else:
raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
return processed_messages
def log_and_reraise(retry_state: RetryCallState):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning(
"""
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
"""
)
raise retry_state.outcome.exception()

View file

@ -6,12 +6,13 @@
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
"""
import re
from typing import NamedTuple
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.utils.token_counter import TOKEN_COSTS
from metagpt.utils.token_counter import FIREWORKS_GRADE_TOKEN_COSTS, TOKEN_COSTS
class Costs(NamedTuple):
@ -29,6 +30,7 @@ class CostManager(BaseModel):
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
@ -39,14 +41,17 @@ class CostManager(BaseModel):
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
if prompt_tokens + completion_tokens == 0 or not model:
return
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
if model not in TOKEN_COSTS:
if model not in self.token_costs:
logger.warning(f"Model {model} not found in TOKEN_COSTS.")
return
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
prompt_tokens * self.token_costs[model]["prompt"]
+ completion_tokens * self.token_costs[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
@ -101,3 +106,44 @@ class TokenCostManager(CostManager):
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
class FireworksCostManager(CostManager):
def model_grade_token_costs(self, model: str) -> dict[str, float]:
def _get_model_size(model: str) -> float:
size = re.findall(".*-([0-9.]+)b", model)
size = float(size[0]) if len(size) > 0 else -1
return size
if "mixtral-8x7b" in model:
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["mixtral-8x7b"]
else:
model_size = _get_model_size(model)
if 0 < model_size <= 16:
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["16"]
elif 16 < model_size <= 80:
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["80"]
else:
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["-1"]
return token_costs
def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
"""
Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
token_costs = self.model_grade_token_costs(model)
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.4f}"
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)

View file

@ -4,7 +4,9 @@
@Time : 2023/12/19
@Author : mashenquan
@File : di_graph_repository.py
@Desc : Graph repository based on DiGraph
@Desc : Graph repository based on DiGraph.
This script defines a graph repository class based on a directed graph (DiGraph), providing functionalities
specific to handling directed relationships between entities.
"""
from __future__ import annotations
@ -19,20 +21,41 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
"""Graph repository based on DiGraph."""
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
"""Insert a new triple into the directed graph repository.
Args:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
Example:
await my_di_graph_repo.insert(subject="Node1", predicate="connects_to", object_="Node2")
# Adds a directed relationship: Node1 connects_to Node2
"""
self._repo.add_edge(subject, object_, predicate=predicate)
async def upsert(self, subject: str, predicate: str, object_: str):
pass
async def update(self, subject: str, predicate: str, object_: str):
pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
"""Retrieve triples from the directed graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
List[SPO]: A list of SPO objects representing the selected triples.
Example:
selected_triples = await my_di_graph_repo.select(subject="Node1", predicate="connects_to")
# Retrieves directed relationships where Node1 is the subject and the predicate is 'connects_to'.
"""
result = []
for s, o, p in self._repo.edges(data="predicate"):
if subject and subject != s:
@ -44,12 +67,41 @@ class DiGraphRepository(GraphRepository):
result.append(SPO(subject=s, predicate=p, object_=o))
return result
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
"""Delete triples from the directed graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
int: The number of triples deleted from the repository.
Example:
deleted_count = await my_di_graph_repo.delete(subject="Node1", predicate="connects_to")
# Deletes directed relationships where Node1 is the subject and the predicate is 'connects_to'.
"""
rows = await self.select(subject=subject, predicate=predicate, object_=object_)
if not rows:
return 0
for r in rows:
self._repo.remove_edge(r.subject, r.object_)
return len(rows)
def json(self) -> str:
"""Convert the directed graph repository to a JSON-formatted string."""
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
return data
async def save(self, path: str | Path = None):
"""Save the directed graph repository to a JSON file.
Args:
path (Union[str, Path], optional): The directory path where the JSON file will be saved.
If not provided, the default path is taken from the 'root' key in the keyword arguments.
"""
data = self.json()
path = path or self._kwargs.get("root")
if not path.exists():
@ -58,12 +110,21 @@ class DiGraphRepository(GraphRepository):
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
"""Load a directed graph repository from a JSON file."""
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
"""Create and load a directed graph repository from a JSON file.
Args:
pathname (Union[str, Path]): The path to the JSON file to be loaded.
Returns:
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
"""
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
@ -74,9 +135,16 @@ class DiGraphRepository(GraphRepository):
@property
def root(self) -> str:
"""Return the root directory path for the graph repository files."""
return self._kwargs.get("root")
@property
def pathname(self) -> Path:
"""Return the path and filename to the graph repository file."""
p = Path(self.root) / self.name
return p.with_suffix(".json")
@property
def repo(self):
"""Get the underlying directed graph repository."""
return self._repo

View file

@ -4,21 +4,28 @@
@Time : 2023/12/19
@Author : mashenquan
@File : graph_repository.py
@Desc : Superclass for graph repository.
@Desc : Superclass for graph repository. This script defines a superclass for a graph repository, providing a
foundation for specific implementations.
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace, split_namespace
class GraphKeyword:
"""Basic words for a Graph database.
This class defines a set of basic words commonly used in the context of a Graph database.
"""
IS = "is"
OF = "Of"
ON = "On"
@ -28,51 +35,149 @@ class GraphKeyword:
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_METHOD = "class_method"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_METHOD = "has_class_method"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_DETAIL = "has_detail"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
HAS_CLASS_USE_CASE = "has_class_use_case"
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
HAS_PARTICIPANT = "has_participant"
class SPO(BaseModel):
"""Graph repository record type.
This class represents a record in a graph repository with three components:
- Subject: The subject of the triple.
- Predicate: The predicate describing the relationship between the subject and the object.
- Object: The object of the triple.
Attributes:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
Example:
spo_record = SPO(subject="Node1", predicate="connects_to", object_="Node2")
# Represents a triple: Node1 connects_to Node2
"""
subject: str
predicate: str
object_: str
class GraphRepository(ABC):
"""Abstract base class for a Graph Repository.
This class defines the interface for a graph repository, providing methods for inserting, selecting,
deleting, and saving graph data. Concrete implementations of this class must provide functionality
for these operations.
"""
def __init__(self, name: str, **kwargs):
self._repo_name = name
self._kwargs = kwargs
@abstractmethod
async def insert(self, subject: str, predicate: str, object_: str):
pass
"""Insert a new triple into the graph repository.
@abstractmethod
async def upsert(self, subject: str, predicate: str, object_: str):
pass
Args:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
@abstractmethod
async def update(self, subject: str, predicate: str, object_: str):
Example:
await my_repository.insert(subject="Node1", predicate="connects_to", object_="Node2")
# Inserts a triple: Node1 connects_to Node2 into the graph repository.
"""
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
"""Retrieve triples from the graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
List[SPO]: A list of SPO objects representing the selected triples.
Example:
selected_triples = await my_repository.select(subject="Node1", predicate="connects_to")
# Retrieves triples where Node1 is the subject and the predicate is 'connects_to'.
"""
pass
@abstractmethod
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
"""Delete triples from the graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
int: The number of triples deleted from the repository.
Example:
deleted_count = await my_repository.delete(subject="Node1", predicate="connects_to")
# Deletes triples where Node1 is the subject and the predicate is 'connects_to'.
"""
pass
@abstractmethod
async def save(self):
"""Save any changes made to the graph repository.
Example:
await my_repository.save()
# Persists any changes made to the graph repository.
"""
pass
@property
def name(self) -> str:
"""Get the name of the graph repository."""
return self._repo_name
@staticmethod
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
"""Insert information of RepoFileInfo into the specified graph repository.
This function updates the provided graph repository with information from the given RepoFileInfo object.
The function inserts triples related to various dimensions such as file type, class, class method, function,
global variable, and page info.
Triple Patterns:
- (?, is, [file type])
- (?, has class, ?)
- (?, is, [class])
- (?, has class method, ?)
- (?, has function, ?)
- (?, is, [function])
- (?, is, global variable)
- (?, has page info, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
file_info (RepoFileInfo): The RepoFileInfo object containing information to be inserted.
Example:
await update_graph_db_with_file_info(my_graph_repo, my_file_info)
# Updates 'my_graph_repo' with information from 'my_file_info'.
"""
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
@ -95,13 +200,13 @@ class GraphRepository(ABC):
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
for f in file_info.functions:
# file -> function
@ -133,7 +238,34 @@ class GraphRepository(ABC):
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
"""Insert dot format class information into the specified graph repository.
This function updates the provided graph repository with class information from the given list of DotClassInfo objects.
The function inserts triples related to various aspects of class views, including source code, file type, class,
class property, class detail, method, composition, and aggregation.
Triple Patterns:
- (?, is, source code)
- (?, is, file type)
- (?, has class, ?)
- (?, is, class)
- (?, has class property, ?)
- (?, is, class property)
- (?, has detail, ?)
- (?, has method, ?)
- (?, is composite of, ?)
- (?, is aggregate of, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
class_views (List[DotClassInfo]): List of DotClassInfo objects containing class information to be inserted.
Example:
await update_graph_db_with_class_views(my_graph_repo, [class_info1, class_info2])
# Updates 'my_graph_repo' with class information from the provided list of DotClassInfo objects.
"""
for c in class_views:
filename, _ = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
@ -146,6 +278,7 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
@ -160,33 +293,61 @@ class GraphRepository(ABC):
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.HAS_DETAIL,
object_=vt.model_dump_json(),
)
for fn, desc in c.methods.items():
if "</I>" in desc and "<I>" not in desc:
logger.error(desc)
for fn, ft in c.methods.items():
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
predicate=GraphKeyword.HAS_DETAIL,
object_=ft.model_dump_json(),
)
for i in c.compositions:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
)
for i in c.aggregations:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
):
"""Insert class relationships and labels into the specified graph repository.
This function updates the provided graph repository with class relationship information from the given list
of DotClassRelationship objects. The function inserts triples representing relationships and labels between
classes.
Triple Patterns:
- (?, is relationship of, ?)
- (?, is relationship on, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects containing
class relationship information to be inserted.
Example:
await update_graph_db_with_class_relationship_views(my_graph_repo, [relationship1, relationship2])
# Updates 'my_graph_repo' with class relationship information from the provided list of DotClassRelationship objects.
"""
for r in relationship_views:
await graph_db.insert(
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
@ -198,3 +359,32 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)
@staticmethod
async def rebuild_composition_relationship(graph_db: "GraphRepository"):
"""Append namespace-prefixed information to relationship SPO (Subject-Predicate-Object) objects in the graph
repository.
This function updates the provided graph repository by appending namespace-prefixed information to existing
relationship SPO objects.
Args:
graph_db (GraphRepository): The graph repository object to be updated.
"""
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mapping = defaultdict(list)
for c in classes:
name = split_namespace(c.subject)[-1]
mapping[name].append(c.subject)
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
for r in rows:
ns, class_ = split_namespace(r.object_)
if ns != "?":
continue
val = mapping[class_]
if len(val) != 1:
continue
ns_name = val[0]
await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)

View file

@ -33,6 +33,7 @@ from metagpt.const import (
TASK_PDF_FILE_REPO,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
VISUAL_GRAPH_REPO_FILE_REPO,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
code_summary: FileRepository
sd_output: FileRepository
code_plan_and_change: FileRepository
graph_repo: FileRepository
def __init__(self, git_repo):
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
class ProjectRepo(FileRepository):
@ -133,6 +136,7 @@ class ProjectRepo(FileRepository):
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
return bool(code_files)
def with_src_path(self, path: str | Path) -> ProjectRepo:
try:

View file

@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
arr = output.split("\n")
new_arr = []
@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
# problem, `"""` or `'''` without `,`
new_line = f",{line}"
elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
# backslash problem like \" in the output
char = rline[col_no - 1]
nearest_char_idx = rline[col_no:].find(char)
new_line = (
rline[: col_no - 1]
+ "\\"
+ rline[col_no - 1 : col_no + nearest_char_idx]
+ "\\"
+ rline[col_no + nearest_char_idx :]
)
elif '",' not in line and "," not in line and '"' not in line:
new_line = f'{line}",'
elif not line.endswith(","):

View file

@ -35,9 +35,111 @@ TOKEN_COSTS = {
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
"moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens
"moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
"moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
"open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025},
"open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007},
"mistral-small-latest": {"prompt": 0.002, "completion": 0.006},
"mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081},
"mistral-large-latest": {"prompt": 0.008, "completion": 0.024},
"claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024},
"claude-2.0": {"prompt": 0.008, "completion": 0.024},
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
}
"""
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
"""
QIANFAN_MODEL_TOKEN_COSTS = {
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
"ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
}
QIANFAN_ENDPOINT_TOKEN_COSTS = {
"completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
"ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
"completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
"eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
"ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
"ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
"bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
"llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
"llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
"llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
"chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
"aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
"mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
"sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
"codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
"xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
"qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
"qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
"qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
"chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
"yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}
"""
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Different model has different detail page. Attention, some model are free for a limited time.
"""
DASHSCOPE_TOKEN_COSTS = {
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
"qwen-max": {"prompt": 0.0, "completion": 0.0},
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
"baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
"baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
"baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
"chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
"chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
"ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free
"dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
"belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
"moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
"chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
"billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
}
FIREWORKS_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
@ -61,6 +163,19 @@ TOKEN_MAX = {
"glm-3-turbo": 128000,
"glm-4": 128000,
"gemini-pro": 32768,
"moonshot-v1-8k": 8192,
"moonshot-v1-32k": 32768,
"moonshot-v1-128k": 128000,
"open-mistral-7b": 8192,
"open-mixtral-8x7b": 32768,
"mistral-small-latest": 32768,
"mistral-medium-latest": 32768,
"mistral-large-latest": 32768,
"claude-instant-1.2": 100000,
"claude-2.0": 100000,
"claude-2.1": 200000,
"claude-3-sonnet-20240229": 200000,
"claude-3-opus-20240229": 200000,
}

View file

@ -0,0 +1,162 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : visualize_graph.py
@Desc : Visualization tool to visualize the class diagrams or sequence diagrams of the graph repository.
"""
from __future__ import annotations
import re
from abc import ABC
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel, Field
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.schema import UMLClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class _VisualClassView(BaseModel):
"""Protected class used by VisualGraphRepo internally.
Attributes:
package (str): The package associated with the class.
uml (Optional[UMLClassView]): Optional UMLClassView associated with the class.
generalizations (List[str]): List of generalizations for the class.
compositions (List[str]): List of compositions for the class.
aggregations (List[str]): List of aggregations for the class.
"""
package: str
uml: Optional[UMLClassView] = None
generalizations: List[str] = Field(default_factory=list)
compositions: List[str] = Field(default_factory=list)
aggregations: List[str] = Field(default_factory=list)
def get_mermaid(self, align: int = 1) -> str:
"""Creates a Markdown Mermaid class diagram text.
Args:
align (int): Indent count used for alignment.
Returns:
str: The Markdown text representing the Mermaid class diagram.
"""
if not self.uml:
return ""
prefix = "\t" * align
mermaid_txt = self.uml.get_mermaid(align=align)
for i in self.generalizations:
mermaid_txt += f"{prefix}{i} <|-- {self.name}\n"
for i in self.compositions:
mermaid_txt += f"{prefix}{i} *-- {self.name}\n"
for i in self.aggregations:
mermaid_txt += f"{prefix}{i} o-- {self.name}\n"
return mermaid_txt
@property
def name(self) -> str:
"""Returns the class name without the namespace prefix."""
return split_namespace(self.package)[-1]
class VisualGraphRepo(ABC):
"""Abstract base class for VisualGraphRepo."""
graph_db: GraphRepository
def __init__(self, graph_db):
self.graph_db = graph_db
class VisualDiGraphRepo(VisualGraphRepo):
"""Implementation of VisualGraphRepo for DiGraph graph repository.
This class extends VisualGraphRepo to provide specific functionality for a graph repository using DiGraph.
"""
@classmethod
async def load_from(cls, filename: str | Path):
"""Load a VisualDiGraphRepo instance from a file."""
graph_db = await DiGraphRepository.load_from(str(filename))
return cls(graph_db=graph_db)
async def get_mermaid_class_view(self) -> str:
"""
Returns a Markdown Mermaid class diagram code block object.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mermaid_txt = "classDiagram\n"
for r in rows:
v = await self._get_class_view(ns_class_name=r.subject)
mermaid_txt += v.get_mermaid()
return mermaid_txt
async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
"""Returns the Markdown Mermaid class diagram code block object for the specified class."""
rows = await self.graph_db.select(subject=ns_class_name)
class_view = _VisualClassView(package=ns_class_name)
for r in rows:
if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
class_view.uml = UMLClassView.model_validate_json(r.object_)
elif r.predicate == GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.generalizations.append(name)
elif r.predicate == GraphKeyword.IS + COMPOSITION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.compositions.append(name)
elif r.predicate == GraphKeyword.IS + AGGREGATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.aggregations.append(name)
return class_view
async def get_mermaid_sequence_views(self) -> List[(str, str)]:
"""Returns all Markdown sequence diagrams with their corresponding graph repository keys."""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
sequence_views.append((r.subject, r.object_))
return sequence_views
@staticmethod
def _refine_name(name: str) -> str:
"""Removes impurity content from the given name.
Example:
>>> _refine_name("int")
""
>>> _refine_name('"Class1"')
'Class1'
>>> _refine_name("pkg.Class1")
"Class1"
"""
name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
return ""
if "." in name:
name = name.split(".")[-1]
return name
async def get_mermaid_sequence_view_versions(self) -> List[(str, str)]:
"""Returns all versioned Markdown sequence diagrams with their corresponding graph repository keys."""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
for r in rows:
sequence_views.append((r.subject, r.object_))
return sequence_views

View file

@ -11,10 +11,11 @@ typer==0.9.0
# godot==0.1.1
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.4.0
langchain==0.0.352
langchain==0.1.8
sqlalchemy==2.0.0 # along with langchain
loguru==0.6.0
meilisearch==0.21.0
numpy>=1.24.3
numpy>=1.24.3,<1.25.0
openai==1.6.0
openpyxl
beautifulsoup4==4.12.2
@ -27,13 +28,13 @@ python_docx==0.8.11
PyYAML==6.0.1
# sentence_transformers==2.2.2
setuptools==65.6.3
tenacity==8.2.2
tenacity==8.2.3
tiktoken==0.5.2
tqdm==4.65.0
#unstructured[local-inference]
# selenium>4
# webdriver_manager<3.9
anthropic==0.8.1
anthropic==0.18.1
typing-inspect==0.8.0
libcst==1.0.1
qdrant-client==1.7.0
@ -54,7 +55,7 @@ rich==13.6.0
nbclient==0.9.0
nbformat==5.9.2
ipython==8.17.2
ipykernel==6.27.0
ipykernel==6.27.1
scikit_learn==1.3.2
typing-extensions==4.9.0
socksio~=1.0.0
@ -68,3 +69,5 @@ anytree
ipywidgets==8.1.1
Pillow
imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.2
dashscope==1.14.1

View file

@ -57,7 +57,7 @@ extras_require["dev"] = (["pylint~=3.0.3", "black~=23.3.0", "isort~=5.12.0", "pr
setup(
name="metagpt",
version="0.7.0",
version="0.7.4",
description="The Multi-Agent Framework",
long_description=long_description,
long_description_content_type="text/markdown",

View file

@ -12,6 +12,7 @@ import logging
import os
import re
import uuid
from pathlib import Path
from typing import Callable
import aiohttp.web
@ -270,3 +271,11 @@ def mermaid_mocker(aiohttp_mocker, mermaid_rsp_cache):
aiohttp_mocker.rsp_cache = mermaid_rsp_cache
aiohttp_mocker.check_funcs = check_funcs
yield check_funcs
@pytest.fixture
def git_dir():
"""Fixture to get the unittest directory."""
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
git_dir.mkdir(parents=True, exist_ok=True)
return git_dir

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -149,7 +149,7 @@ sequenceDiagram
The requirement analysis suggests the need for a clean and intuitive interface. Since we are using a command-line interface, we need to ensure that the text-based UI is as user-friendly as possible. Further clarification on whether a graphical user interface (GUI) is expected in the future would be helpful for planning the extendability of the game."""
TASKS_SAMPLE = """
TASK_SAMPLE = """
## Required Python packages
- random==2.2.1
@ -345,7 +345,7 @@ REFINED_DESIGN_JSON = {
"Anything UNCLEAR": "",
}
REFINED_TASKS_JSON = {
REFINED_TASK_JSON = {
"Required Python packages": ["random==2.2.1", "Tkinter==8.6"],
"Required Other language third-party packages": ["No third-party dependencies required"],
"Refined Logic Analysis": [
@ -373,7 +373,14 @@ REFINED_TASKS_JSON = {
}
CODE_PLAN_AND_CHANGE_SAMPLE = {
"Code Plan And Change": '\n1. Plan for gui.py: Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.\n```python\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```\n\n2. Plan for main.py: Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.\n```python\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n'
"Development Plan": [
"Develop the GUI using Tkinter to replace the command-line interface. Start by setting up the main window and event handling. Then, add widgets for displaying the game status, results, and feedback. Implement interactive elements for difficulty selection and visualize the guess history. Finally, create animations for guess feedback and ensure responsiveness across different screen sizes.",
"Modify the main.py to initialize the GUI and start the event-driven game loop. Ensure that the GUI is the primary interface for user interaction.",
],
"Incremental Change": [
"""```diff\nclass GUI:\n- pass\n+ def __init__(self):\n+ self.setup_window()\n+\n+ def setup_window(self):\n+ # Initialize the main window using Tkinter\n+ pass\n+\n+ def bind_events(self):\n+ # Bind button clicks and other events\n+ pass\n+\n+ def update_feedback(self, message: str):\n+ # Update the feedback label with the given message\n+ pass\n+\n+ def update_attempts(self, attempts: int):\n+ # Update the attempts label with the number of attempts\n+ pass\n+\n+ def update_history(self, history: list):\n+ # Update the history view with the list of past guesses\n+ pass\n+\n+ def show_difficulty_selector(self):\n+ # Show buttons or a dropdown for difficulty selection\n+ pass\n+\n+ def animate_guess_result(self, correct: bool):\n+ # Trigger an animation for correct or incorrect guesses\n+ pass\n```""",
"""```diff\nclass Main:\n def main(self):\n- user_interface = UI()\n- user_interface.start()\n+ graphical_user_interface = GUI()\n+ graphical_user_interface.setup_window()\n+ graphical_user_interface.bind_events()\n+ # Start the Tkinter main loop\n+ pass\n\n if __name__ == "__main__":\n main_instance = Main()\n main_instance.main()\n```\n\n3. Plan for ui.py: Refactor ui.py to work with the new GUI class. Remove command-line interactions and delegate display and input tasks to the GUI.\n```python\nclass UI:\n- def display_message(self, message: str):\n- print(message)\n+\n+ def display_message(self, message: str):\n+ # This method will now pass the message to the GUI to display\n+ pass\n\n- def get_user_input(self, prompt: str) -> str:\n- return input(prompt)\n+\n+ def get_user_input(self, prompt: str) -> str:\n+ # This method will now trigger the GUI to get user input\n+ pass\n\n- def show_attempts(self, attempts: int):\n- print(f"Number of attempts: {attempts}")\n+\n+ def show_attempts(self, attempts: int):\n+ # This method will now update the GUI with the number of attempts\n+ pass\n\n- def show_history(self, history: list):\n- print("Guess history:")\n- for guess in history:\n- print(guess)\n+\n+ def show_history(self, history: list):\n+ # This method will now update the GUI with the guess history\n+ pass\n```\n\n4. Plan for game.py: Ensure game.py remains mostly unchanged as it contains the core game logic. However, make minor adjustments if necessary to integrate with the new GUI.\n```python\nclass Game:\n # No changes required for now\n```\n""",
],
}
REFINED_CODE_INPUT_SAMPLE = """

File diff suppressed because one or more lines are too long

View file

@ -1,6 +1,6 @@
import pytest
from metagpt.actions.mi.ask_review import AskReview
from metagpt.actions.di.ask_review import AskReview
@pytest.mark.asyncio

Some files were not shown because too many files have changed in this diff Show more