mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-23 15:48:11 +02:00
Merge branch 'main' into code_interpreter
This commit is contained in:
commit
38f21137ec
146 changed files with 4466 additions and 1375 deletions
|
|
@ -22,9 +22,9 @@ from metagpt.actions.write_code_review import WriteCodeReview
|
|||
from metagpt.actions.write_prd import WritePRD
|
||||
from metagpt.actions.write_prd_review import WritePRDReview
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.mi.write_analysis_code import WriteCodeWithTools
|
||||
from metagpt.actions.mi.write_plan import WritePlan
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.di.write_analysis_code import WriteCodeWithTools
|
||||
from metagpt.actions.di.write_plan import WritePlan
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.mermaid import MMC1, MMC2
|
||||
|
||||
IMPLEMENTATION_APPROACH = ActionNode(
|
||||
|
|
@ -109,14 +108,3 @@ REFINED_NODES = [
|
|||
|
||||
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
|
||||
REFINED_DESIGN_NODE = ActionNode.from_children("RefinedDesignAPI", REFINED_NODES)
|
||||
|
||||
|
||||
def main():
|
||||
prompt = DESIGN_API_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
prompt = REFINED_DESIGN_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from __future__ import annotations
|
|||
import json
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.prompts.mi.write_analysis_code import (
|
||||
from metagpt.prompts.di.write_analysis_code import (
|
||||
CHECK_DATA_PROMPT,
|
||||
DEBUG_REFLECTION_EXAMPLE,
|
||||
INTERPRETER_SYSTEM_MSG,
|
||||
|
|
@ -8,7 +8,6 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
|
||||
REQUIRED_PYTHON_PACKAGES = ActionNode(
|
||||
key="Required Python packages",
|
||||
|
|
@ -119,14 +118,3 @@ REFINED_NODES = [
|
|||
|
||||
PM_NODE = ActionNode.from_children("PM_NODE", NODES)
|
||||
REFINED_PM_NODE = ActionNode.from_children("REFINED_PM_NODE", REFINED_NODES)
|
||||
|
||||
|
||||
def main():
|
||||
prompt = PM_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
prompt = REFINED_PM_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : rebuild_class_view.py
|
||||
@Desc : Rebuild class view info
|
||||
@Desc : Reconstructs class diagram from a source code project.
|
||||
Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
|
||||
"""
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
import aiofiles
|
||||
|
||||
|
|
@ -21,86 +23,144 @@ from metagpt.const import (
|
|||
GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
|
||||
from metagpt.utils.common import split_namespace
|
||||
from metagpt.repo_parser import DotClassInfo, RepoParser
|
||||
from metagpt.schema import UMLClassView
|
||||
from metagpt.utils.common import concat_namespace, split_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class RebuildClassView(Action):
|
||||
"""
|
||||
Reconstructs a graph repository about class diagram from a source code project.
|
||||
|
||||
Attributes:
|
||||
graph_db (Optional[GraphRepository]): The optional graph repository.
|
||||
"""
|
||||
|
||||
graph_db: Optional[GraphRepository] = None
|
||||
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
"""
|
||||
Implementation of `Action`'s `run` method.
|
||||
|
||||
Args:
|
||||
with_messages (Optional[Type]): An optional argument specifying messages to react to.
|
||||
format (str): The format for the prompt schema.
|
||||
"""
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
repo_parser = RepoParser(base_directory=Path(self.i_context))
|
||||
# use pylint
|
||||
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
|
||||
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
|
||||
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
|
||||
await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
|
||||
await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
|
||||
await GraphRepository.rebuild_composition_relationship(self.graph_db)
|
||||
# use ast
|
||||
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
for file_info in symbols:
|
||||
# Align to the same root directory in accordance with `class_views`.
|
||||
file_info.file = self._align_root(file_info.file, direction, diff_path)
|
||||
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
|
||||
await self._create_mermaid_class_views(graph_db=graph_db)
|
||||
await graph_db.save()
|
||||
await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info)
|
||||
await self._create_mermaid_class_views()
|
||||
await self.graph_db.save()
|
||||
|
||||
async def _create_mermaid_class_views(self, graph_db):
|
||||
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
|
||||
async def _create_mermaid_class_views(self) -> str:
|
||||
"""Creates a Mermaid class diagram using data from the `graph_db` graph repository.
|
||||
|
||||
This method utilizes information stored in the graph repository to generate a Mermaid class diagram.
|
||||
Returns:
|
||||
mermaid class diagram file name.
|
||||
"""
|
||||
path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
pathname = path / self.context.git_repo.workdir.name
|
||||
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
|
||||
filename = str(pathname.with_suffix(".class_diagram.mmd"))
|
||||
async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
|
||||
content = "classDiagram\n"
|
||||
logger.debug(content)
|
||||
await writer.write(content)
|
||||
# class names
|
||||
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
class_distinct = set()
|
||||
relationship_distinct = set()
|
||||
for r in rows:
|
||||
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
|
||||
content = await self._create_mermaid_class(r.subject)
|
||||
if content:
|
||||
await writer.write(content)
|
||||
class_distinct.add(r.subject)
|
||||
for r in rows:
|
||||
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
|
||||
content, distinct = await self._create_mermaid_relationship(r.subject)
|
||||
if content:
|
||||
logger.debug(content)
|
||||
await writer.write(content)
|
||||
relationship_distinct.update(distinct)
|
||||
logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")
|
||||
|
||||
@staticmethod
|
||||
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
|
||||
if self.i_context:
|
||||
r_filename = Path(filename).relative_to(self.context.git_repo.workdir)
|
||||
await self.graph_db.insert(
|
||||
subject=self.i_context, predicate="hasMermaidClassDiagramFile", object_=str(r_filename)
|
||||
)
|
||||
logger.info(f"{self.i_context} hasMermaidClassDiagramFile {filename}")
|
||||
return filename
|
||||
|
||||
async def _create_mermaid_class(self, ns_class_name) -> str:
|
||||
"""Generates a Mermaid class diagram for a specific class using data from the `graph_db` graph repository.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed name of the class for which the Mermaid class diagram is to be created.
|
||||
|
||||
Returns:
|
||||
str: A Mermaid code block object in markdown representing the class diagram.
|
||||
"""
|
||||
fields = split_namespace(ns_class_name)
|
||||
if len(fields) > 2:
|
||||
# Ignore sub-class
|
||||
return
|
||||
return ""
|
||||
|
||||
class_view = ClassView(name=fields[1])
|
||||
rows = await graph_db.select(subject=ns_class_name)
|
||||
for r in rows:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
|
||||
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
|
||||
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
|
||||
attribute = ClassAttribute(
|
||||
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
|
||||
)
|
||||
class_view.attributes.append(attribute)
|
||||
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
|
||||
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
|
||||
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
|
||||
class_view.methods.append(method)
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
|
||||
if not rows:
|
||||
return ""
|
||||
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
|
||||
class_view = UMLClassView.load_dot_class_info(dot_class_info)
|
||||
|
||||
# update graph db
|
||||
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
# update uml view
|
||||
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
# update uml isCompositeOf
|
||||
for c in dot_class_info.compositions:
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name,
|
||||
predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF,
|
||||
object_=concat_namespace("?", c),
|
||||
)
|
||||
|
||||
# update uml isAggregateOf
|
||||
for a in dot_class_info.aggregations:
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name,
|
||||
predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF,
|
||||
object_=concat_namespace("?", a),
|
||||
)
|
||||
|
||||
content = class_view.get_mermaid(align=1)
|
||||
logger.debug(content)
|
||||
await file_writer.write(content)
|
||||
distinct.add(ns_class_name)
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
|
||||
async def _create_mermaid_relationship(self, ns_class_name: str) -> Tuple[Optional[str], Optional[Set]]:
|
||||
"""Generates a Mermaid class relationship diagram for a specific class using data from the `graph_db` graph repository.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which the Mermaid relationship diagram is to be created.
|
||||
|
||||
Returns:
|
||||
Tuple[str, Set]: A tuple containing the relationship diagram as a string and a set of deduplication.
|
||||
"""
|
||||
s_fields = split_namespace(ns_class_name)
|
||||
if len(s_fields) > 2:
|
||||
# Ignore sub-class
|
||||
return
|
||||
return None, None
|
||||
|
||||
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
|
||||
mappings = {
|
||||
|
|
@ -109,8 +169,9 @@ class RebuildClassView(Action):
|
|||
AGGREGATION: " o-- ",
|
||||
}
|
||||
content = ""
|
||||
distinct = set()
|
||||
for p, v in predicates.items():
|
||||
rows = await graph_db.select(subject=ns_class_name, predicate=p)
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=p)
|
||||
for r in rows:
|
||||
o_fields = split_namespace(r.object_)
|
||||
if len(o_fields) > 2:
|
||||
|
|
@ -121,86 +182,26 @@ class RebuildClassView(Action):
|
|||
distinct.add(link)
|
||||
content += f"\t{link}\n"
|
||||
|
||||
if content:
|
||||
logger.debug(content)
|
||||
await file_writer.write(content)
|
||||
|
||||
@staticmethod
|
||||
def _parse_name(name: str, language="python"):
|
||||
pattern = re.compile(r"<I>(.*?)<\/I>")
|
||||
result = re.search(pattern, name)
|
||||
|
||||
abstraction = ""
|
||||
if result:
|
||||
name = result.group(1)
|
||||
abstraction = "*"
|
||||
if name.startswith("__"):
|
||||
visibility = "-"
|
||||
elif name.startswith("_"):
|
||||
visibility = "#"
|
||||
else:
|
||||
visibility = "+"
|
||||
return name, visibility, abstraction
|
||||
|
||||
@staticmethod
|
||||
async def _parse_variable_type(ns_name, graph_db) -> str:
|
||||
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
|
||||
if not rows:
|
||||
return ""
|
||||
vals = rows[0].object_.replace("'", "").split(":")
|
||||
if len(vals) == 1:
|
||||
return ""
|
||||
val = vals[-1].strip()
|
||||
return "" if val == "NoneType" else val + " "
|
||||
|
||||
@staticmethod
|
||||
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
|
||||
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
|
||||
if not rows:
|
||||
return
|
||||
info = rows[0].object_.replace("'", "")
|
||||
|
||||
fs_tag = "("
|
||||
ix = info.find(fs_tag)
|
||||
fe_tag = "):"
|
||||
eix = info.rfind(fe_tag)
|
||||
if eix < 0:
|
||||
fe_tag = ")"
|
||||
eix = info.rfind(fe_tag)
|
||||
args_info = info[ix + len(fs_tag) : eix].strip()
|
||||
method.return_type = info[eix + len(fe_tag) :].strip()
|
||||
if method.return_type == "None":
|
||||
method.return_type = ""
|
||||
if "(" in method.return_type:
|
||||
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
|
||||
|
||||
# parse args
|
||||
if not args_info:
|
||||
return
|
||||
splitter_ixs = []
|
||||
cost = 0
|
||||
for i in range(len(args_info)):
|
||||
if args_info[i] == "[":
|
||||
cost += 1
|
||||
elif args_info[i] == "]":
|
||||
cost -= 1
|
||||
if args_info[i] == "," and cost == 0:
|
||||
splitter_ixs.append(i)
|
||||
splitter_ixs.append(len(args_info))
|
||||
args = []
|
||||
ix = 0
|
||||
for eix in splitter_ixs:
|
||||
args.append(args_info[ix:eix])
|
||||
ix = eix + 1
|
||||
for arg in args:
|
||||
parts = arg.strip().split(":")
|
||||
if len(parts) == 1:
|
||||
method.args.append(ClassAttribute(name=parts[0].strip()))
|
||||
continue
|
||||
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
|
||||
return content, distinct
|
||||
|
||||
@staticmethod
|
||||
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
|
||||
"""Returns the difference between the root path and the path information represented in the package name.
|
||||
|
||||
Args:
|
||||
path_root (Path): The root path.
|
||||
package_root (Path): The package root path.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: A tuple containing the representation of the difference ("+", "-", "=") and the path detail of the differing part.
|
||||
|
||||
Example:
|
||||
>>> _diff_path(path_root=Path("/Users/x/github/MetaGPT"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
|
||||
"-", "metagpt"
|
||||
|
||||
>>> _diff_path(path_root=Path("/Users/x/github/MetaGPT/metagpt"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
|
||||
"=", "."
|
||||
"""
|
||||
if len(str(path_root)) > len(str(package_root)):
|
||||
return "+", str(path_root.relative_to(package_root))
|
||||
if len(str(path_root)) < len(str(package_root)):
|
||||
|
|
@ -208,7 +209,24 @@ class RebuildClassView(Action):
|
|||
return "=", "."
|
||||
|
||||
@staticmethod
|
||||
def _align_root(path: str, direction: str, diff_path: str):
|
||||
def _align_root(path: str, direction: str, diff_path: str) -> str:
|
||||
"""Aligns the path to the same root represented by `diff_path`.
|
||||
|
||||
Args:
|
||||
path (str): The path to be aligned.
|
||||
direction (str): The direction of alignment ('+', '-', '=').
|
||||
diff_path (str): The path representing the difference.
|
||||
|
||||
Returns:
|
||||
str: The aligned path.
|
||||
|
||||
Example:
|
||||
>>> _align_root(path="metagpt/software_company.py", direction="+", diff_path="MetaGPT")
|
||||
"MetaGPT/metagpt/software_company.py"
|
||||
|
||||
>>> _align_root(path="metagpt/software_company.py", direction="-", diff_path="metagpt")
|
||||
"software_company.py"
|
||||
"""
|
||||
if direction == "=":
|
||||
return path
|
||||
if direction == "+":
|
||||
|
|
|
|||
|
|
@ -4,34 +4,214 @@
|
|||
@Time : 2024/1/4
|
||||
@Author : mashenquan
|
||||
@File : rebuild_sequence_view.py
|
||||
@Desc : Rebuild sequence view info
|
||||
@Desc : Reconstruct sequence view information through reverse engineering.
|
||||
Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import aread, list_files
|
||||
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
|
||||
from metagpt.schema import UMLClassView
|
||||
from metagpt.utils.common import (
|
||||
add_affix,
|
||||
aread,
|
||||
auto_namespace,
|
||||
concat_namespace,
|
||||
general_after_log,
|
||||
list_files,
|
||||
parse_json_code_block,
|
||||
read_file_block,
|
||||
split_namespace,
|
||||
)
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword
|
||||
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class ReverseUseCase(BaseModel):
|
||||
"""
|
||||
Represents a reverse engineered use case.
|
||||
|
||||
Attributes:
|
||||
description (str): A description of the reverse use case.
|
||||
inputs (List[str]): List of inputs for the reverse use case.
|
||||
outputs (List[str]): List of outputs for the reverse use case.
|
||||
actors (List[str]): List of actors involved in the reverse use case.
|
||||
steps (List[str]): List of steps for the reverse use case.
|
||||
reason (str): The reason behind the reverse use case.
|
||||
"""
|
||||
|
||||
description: str
|
||||
inputs: List[str]
|
||||
outputs: List[str]
|
||||
actors: List[str]
|
||||
steps: List[str]
|
||||
reason: str
|
||||
|
||||
|
||||
class ReverseUseCaseDetails(BaseModel):
|
||||
"""
|
||||
Represents details of a reverse engineered use case.
|
||||
|
||||
Attributes:
|
||||
description (str): A description of the reverse use case details.
|
||||
use_cases (List[ReverseUseCase]): List of reverse use cases.
|
||||
relationship (List[str]): List of relationships associated with the reverse use case details.
|
||||
"""
|
||||
|
||||
description: str
|
||||
use_cases: List[ReverseUseCase]
|
||||
relationship: List[str]
|
||||
|
||||
|
||||
class RebuildSequenceView(Action):
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
entries = await RebuildSequenceView._search_main_entry(graph_db)
|
||||
for entry in entries:
|
||||
await self._rebuild_sequence_view(entry, graph_db)
|
||||
await graph_db.save()
|
||||
"""
|
||||
Represents an action to reconstruct sequence view through reverse engineering.
|
||||
|
||||
@staticmethod
|
||||
async def _search_main_entry(graph_db) -> List:
|
||||
rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
Attributes:
|
||||
graph_db (Optional[GraphRepository]): An optional instance of GraphRepository for graph database operations.
|
||||
"""
|
||||
|
||||
graph_db: Optional[GraphRepository] = None
|
||||
|
||||
async def run(self, with_messages=None, format=config.prompt_schema):
|
||||
"""
|
||||
Implementation of `Action`'s `run` method.
|
||||
|
||||
Args:
|
||||
with_messages (Optional[Type]): An optional argument specifying messages to react to.
|
||||
format (str): The format for the prompt schema.
|
||||
"""
|
||||
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
|
||||
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
if not self.i_context:
|
||||
entries = await self._search_main_entry()
|
||||
else:
|
||||
entries = [SPO(subject=self.i_context, predicate="", object_="")]
|
||||
for entry in entries:
|
||||
await self._rebuild_main_sequence_view(entry)
|
||||
while await self._merge_sequence_view(entry):
|
||||
pass
|
||||
await self.graph_db.save()
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _rebuild_main_sequence_view(self, entry: SPO):
|
||||
"""
|
||||
Reconstruct the sequence diagram for the __main__ entry of the source code through reverse engineering.
|
||||
|
||||
Args:
|
||||
entry (SPO): The SPO (Subject, Predicate, Object) object in the graph database that is related to the
|
||||
subject `__name__:__main__`.
|
||||
"""
|
||||
filename = entry.subject.split(":", 1)[0]
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
classes = []
|
||||
prefix = filename + ":"
|
||||
for r in rows:
|
||||
if prefix in r.subject:
|
||||
classes.append(r)
|
||||
await self._rebuild_use_case(r.subject)
|
||||
participants = await self._search_participants(split_namespace(entry.subject)[0])
|
||||
class_details = []
|
||||
class_views = []
|
||||
for c in classes:
|
||||
detail = await self._get_class_detail(c.subject)
|
||||
if not detail:
|
||||
continue
|
||||
class_details.append(detail)
|
||||
view = await self._get_uml_class_view(c.subject)
|
||||
if view:
|
||||
class_views.append(view)
|
||||
|
||||
actors = await self._get_participants(c.subject)
|
||||
participants.update(set(actors))
|
||||
|
||||
use_case_blocks = []
|
||||
for c in classes:
|
||||
use_cases = await self._get_class_use_cases(c.subject)
|
||||
use_case_blocks.append(use_cases)
|
||||
prompt_blocks = ["## Use Cases\n" + "\n".join(use_case_blocks)]
|
||||
block = "## Participants\n"
|
||||
for p in participants:
|
||||
block += f"- {p}\n"
|
||||
prompt_blocks.append(block)
|
||||
block = "## Mermaid Class Views\n```mermaid\n"
|
||||
block += "\n\n".join([c.get_mermaid() for c in class_views])
|
||||
block += "\n```\n"
|
||||
prompt_blocks.append(block)
|
||||
block = "## Source Code\n```python\n"
|
||||
block += await self._get_source_code(filename)
|
||||
block += "\n```\n"
|
||||
prompt_blocks.append(block)
|
||||
prompt = "\n---\n".join(prompt_blocks)
|
||||
|
||||
rsp = await self.llm.aask(
|
||||
msg=prompt,
|
||||
system_msgs=[
|
||||
"You are a python code to Mermaid Sequence Diagram translator in function detail.",
|
||||
"Translate the given markdown text to a Mermaid Sequence Diagram.",
|
||||
"Return the merged Mermaid sequence diagram in a markdown code block format.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
|
||||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
for r in rows:
|
||||
if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
|
||||
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject,
|
||||
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
|
||||
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
|
||||
)
|
||||
for c in classes:
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
|
||||
)
|
||||
await self._save_sequence_view(subject=entry.subject, content=sequence_view)
|
||||
|
||||
async def _merge_sequence_view(self, entry: SPO) -> bool:
|
||||
"""
|
||||
Augments additional information to the provided SPO (Subject, Predicate, Object) entry in the sequence diagram.
|
||||
|
||||
Args:
|
||||
entry (SPO): The SPO object representing the relationship in the graph database.
|
||||
|
||||
Returns:
|
||||
bool: True if additional information has been augmented, otherwise False.
|
||||
"""
|
||||
new_participant = await self._search_new_participant(entry)
|
||||
if not new_participant:
|
||||
return False
|
||||
|
||||
await self._merge_participant(entry, new_participant)
|
||||
return True
|
||||
|
||||
async def _search_main_entry(self) -> List:
|
||||
"""
|
||||
Asynchronously searches for the SPO object that is related to `__name__:__main__`.
|
||||
|
||||
Returns:
|
||||
List: A list containing information about the main entry in the sequence diagram.
|
||||
"""
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
tag = "__name__:__main__"
|
||||
entries = []
|
||||
for r in rows:
|
||||
|
|
@ -39,24 +219,395 @@ class RebuildSequenceView(Action):
|
|||
entries.append(r)
|
||||
return entries
|
||||
|
||||
async def _rebuild_sequence_view(self, entry, graph_db):
|
||||
filename = entry.subject.split(":", 1)[0]
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _rebuild_use_case(self, ns_class_name: str):
|
||||
"""
|
||||
Asynchronously reconstructs the use case for the provided namespace-prefixed class name.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which the use case is to be reconstructed.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
|
||||
if rows:
|
||||
return
|
||||
content = await aread(filename=src_filename, encoding="utf-8")
|
||||
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
|
||||
data = await self.llm.aask(
|
||||
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
|
||||
|
||||
detail = await self._get_class_detail(ns_class_name)
|
||||
if not detail:
|
||||
return
|
||||
participants = set()
|
||||
participants.update(set(detail.compositions))
|
||||
participants.update(set(detail.aggregations))
|
||||
class_view = await self._get_uml_class_view(ns_class_name)
|
||||
source_code = await self._get_source_code(ns_class_name)
|
||||
|
||||
# prompt_blocks = [
|
||||
# "## Instruction\n"
|
||||
# "You are a python code to UML 2.0 Use Case translator.\n"
|
||||
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
|
||||
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
|
||||
# 'conflict with the information in "Mermaid Class Views".\n'
|
||||
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
|
||||
# "system interactions with the internal system.\n"
|
||||
# ]
|
||||
prompt_blocks = []
|
||||
block = "## Participants\n"
|
||||
for p in participants:
|
||||
block += f"- {p}\n"
|
||||
prompt_blocks.append(block)
|
||||
block = "## Mermaid Class Views\n```mermaid\n"
|
||||
block += class_view.get_mermaid()
|
||||
block += "\n```\n"
|
||||
prompt_blocks.append(block)
|
||||
block = "## Source Code\n```python\n"
|
||||
block += source_code
|
||||
block += "\n```\n"
|
||||
prompt_blocks.append(block)
|
||||
prompt = "\n---\n".join(prompt_blocks)
|
||||
|
||||
rsp = await self.llm.aask(
|
||||
msg=prompt,
|
||||
system_msgs=[
|
||||
"You are a python code to UML 2.0 Use Case translator.",
|
||||
'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".',
|
||||
"The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
|
||||
'conflict with the information in "Mermaid Class Views".',
|
||||
'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
|
||||
"system interactions with the internal system.",
|
||||
"Return a markdown JSON object with:\n"
|
||||
'- a "description" key to explain what the whole source code want to do;\n'
|
||||
'- a "use_cases" key list all use cases, each use case in the list should including a `description` '
|
||||
"key describes about what the use case to do, a `inputs` key lists the input names of the use case "
|
||||
"from external sources, a `outputs` key lists the output names of the use case to external sources, "
|
||||
"a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how "
|
||||
"the use case works step by step, a `reason` key explaining under what circumstances would the "
|
||||
"external system execute this use case.\n"
|
||||
'- a "relationship" key lists all the descriptions of relationship among these use cases.\n',
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
code_blocks = parse_json_code_block(rsp)
|
||||
for block in code_blocks:
|
||||
detail = ReverseUseCaseDetails.model_validate_json(block)
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
|
||||
)
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _rebuild_sequence_view(self, ns_class_name: str):
|
||||
"""
|
||||
Asynchronously reconstructs the sequence diagram for the provided namespace-prefixed class name.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which the sequence diagram is to be reconstructed.
|
||||
"""
|
||||
await self._rebuild_use_case(ns_class_name)
|
||||
|
||||
prompts_blocks = []
|
||||
use_case_markdown = await self._get_class_use_cases(ns_class_name)
|
||||
if not use_case_markdown: # external class
|
||||
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
|
||||
return
|
||||
block = f"## Use Cases\n{use_case_markdown}"
|
||||
prompts_blocks.append(block)
|
||||
|
||||
participants = await self._get_participants(ns_class_name)
|
||||
block = "## Participants\n" + "\n".join([f"- {s}" for s in participants])
|
||||
prompts_blocks.append(block)
|
||||
|
||||
view = await self._get_uml_class_view(ns_class_name)
|
||||
block = "## Mermaid Class Views\n```mermaid\n"
|
||||
block += view.get_mermaid()
|
||||
block += "\n```\n"
|
||||
prompts_blocks.append(block)
|
||||
|
||||
block = "## Source Code\n```python\n"
|
||||
block += await self._get_source_code(ns_class_name)
|
||||
block += "\n```\n"
|
||||
prompts_blocks.append(block)
|
||||
prompt = "\n---\n".join(prompts_blocks)
|
||||
|
||||
rsp = await self.llm.aask(
|
||||
prompt,
|
||||
system_msgs=[
|
||||
"You are a Mermaid Sequence Diagram translator in function detail.",
|
||||
"Translate the markdown text to a Mermaid Sequence Diagram.",
|
||||
"Return a markdown mermaid code block.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
|
||||
async def _get_participants(self, ns_class_name: str) -> List[str]:
|
||||
"""
|
||||
Asynchronously returns the participants list of the sequence diagram for the provided namespace-prefixed SPO
|
||||
object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve the participants list.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of participants in the sequence diagram.
|
||||
"""
|
||||
participants = set()
|
||||
detail = await self._get_class_detail(ns_class_name)
|
||||
if not detail:
|
||||
return []
|
||||
participants.update(set(detail.compositions))
|
||||
participants.update(set(detail.aggregations))
|
||||
return list(participants)
|
||||
|
||||
async def _get_class_use_cases(self, ns_class_name: str) -> str:
|
||||
"""
|
||||
Asynchronously assembles the context about the use case information of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve use case information.
|
||||
|
||||
Returns:
|
||||
str: A string containing the assembled context about the use case information.
|
||||
"""
|
||||
block = ""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
|
||||
for i, r in enumerate(rows):
|
||||
detail = ReverseUseCaseDetails.model_validate_json(r.object_)
|
||||
block += f"\n### {i + 1}. {detail.description}"
|
||||
for j, use_case in enumerate(detail.use_cases):
|
||||
block += f"\n#### {i + 1}.{j + 1}. {use_case.description}\n"
|
||||
block += "\n##### Inputs\n" + "\n".join([f"- {s}" for s in use_case.inputs])
|
||||
block += "\n##### Outputs\n" + "\n".join([f"- {s}" for s in use_case.outputs])
|
||||
block += "\n##### Actors\n" + "\n".join([f"- {s}" for s in use_case.actors])
|
||||
block += "\n##### Steps\n" + "\n".join([f"- {s}" for s in use_case.steps])
|
||||
block += "\n#### Use Case Relationship\n" + "\n".join([f"- {s}" for s in detail.relationship])
|
||||
return block + "\n"
|
||||
|
||||
async def _get_class_detail(self, ns_class_name: str) -> DotClassInfo | None:
|
||||
"""
|
||||
Asynchronously retrieves the dot format class details of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve class details.
|
||||
|
||||
Returns:
|
||||
Union[DotClassInfo, None]: A DotClassInfo object representing the dot format class details,
|
||||
or None if the details are not available.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
|
||||
if not rows:
|
||||
return None
|
||||
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
|
||||
return dot_class_info
|
||||
|
||||
async def _get_uml_class_view(self, ns_class_name: str) -> UMLClassView | None:
|
||||
"""
|
||||
Asynchronously retrieves the UML 2.0 format class details of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve UML class details.
|
||||
|
||||
Returns:
|
||||
Union[UMLClassView, None]: A UMLClassView object representing the UML 2.0 format class details,
|
||||
or None if the details are not available.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
|
||||
if not rows:
|
||||
return None
|
||||
class_view = UMLClassView.model_validate_json(rows[0].object_)
|
||||
return class_view
|
||||
|
||||
async def _get_source_code(self, ns_class_name: str) -> str:
|
||||
"""
|
||||
Asynchronously retrieves the source code of the namespace-prefixed SPO object.
|
||||
|
||||
Args:
|
||||
ns_class_name (str): The namespace-prefixed class name for which to retrieve the source code.
|
||||
|
||||
Returns:
|
||||
str: A string containing the source code of the specified namespace-prefixed class.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
filename = split_namespace(ns_class_name=ns_class_name)[0]
|
||||
if not rows:
|
||||
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
|
||||
if not src_filename:
|
||||
return ""
|
||||
return await aread(filename=src_filename, encoding="utf-8")
|
||||
code_block_info = CodeBlockInfo.model_validate_json(rows[0].object_)
|
||||
return await read_file_block(
|
||||
filename=filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno
|
||||
)
|
||||
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
|
||||
logger.info(data)
|
||||
|
||||
@staticmethod
|
||||
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
|
||||
"""
|
||||
Convert package name to the full path of the module.
|
||||
|
||||
Args:
|
||||
root (Union[str, Path]): The root path or string representing the package.
|
||||
pathname (Union[str, Path]): The pathname or string representing the module.
|
||||
|
||||
Returns:
|
||||
Union[Path, None]: The full path of the module, or None if the path cannot be determined.
|
||||
|
||||
Examples:
|
||||
If `root`(workdir) is "/User/xxx/github/MetaGPT/metagpt", and the `pathname` is
|
||||
"metagpt/management/skill_manager.py", then the returned value will be
|
||||
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
|
||||
"""
|
||||
if re.match(r"^/.+", pathname):
|
||||
return pathname
|
||||
files = list_files(root=root)
|
||||
postfix = "/" + str(pathname)
|
||||
for i in files:
|
||||
if str(i).endswith(postfix):
|
||||
return i
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
|
||||
"""
|
||||
Parses the provided Mermaid sequence diagram and returns the list of participants.
|
||||
|
||||
Args:
|
||||
mermaid_sequence_diagram (str): The Mermaid sequence diagram string to be parsed.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of participants extracted from the sequence diagram.
|
||||
"""
|
||||
pattern = r"participant ([a-zA-Z\.0-9_]+)"
|
||||
matches = re.findall(pattern, mermaid_sequence_diagram)
|
||||
matches = [re.sub(r"[\\/'\"]+", "", i) for i in matches]
|
||||
return matches
|
||||
|
||||
async def _search_new_participant(self, entry: SPO) -> str | None:
|
||||
"""
|
||||
Asynchronously retrieves a participant whose sequence diagram has not been augmented.
|
||||
|
||||
Args:
|
||||
entry (SPO): The SPO object representing the relationship in the graph database.
|
||||
|
||||
Returns:
|
||||
Union[str, None]: A participant whose sequence diagram has not been augmented, or None if not found.
|
||||
"""
|
||||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
if not rows:
|
||||
return None
|
||||
sequence_view = rows[0].object_
|
||||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
|
||||
merged_participants = []
|
||||
for r in rows:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
merged_participants.append(name)
|
||||
participants = self.parse_participant(sequence_view)
|
||||
for p in participants:
|
||||
if p in merged_participants:
|
||||
continue
|
||||
return p
|
||||
return None
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=20),
|
||||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _merge_participant(self, entry: SPO, class_name: str):
|
||||
"""
|
||||
Augments the sequence diagram of `class_name` to the sequence diagram of `entry`.
|
||||
|
||||
Args:
|
||||
entry (SPO): The SPO object representing the base sequence diagram.
|
||||
class_name (str): The class name whose sequence diagram is to be augmented.
|
||||
"""
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
participants = []
|
||||
for r in rows:
|
||||
name = split_namespace(r.subject)[-1]
|
||||
if name == class_name:
|
||||
participants.append(r)
|
||||
if len(participants) == 0: # external participants
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
|
||||
)
|
||||
return
|
||||
if len(participants) > 1:
|
||||
for r in participants:
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
|
||||
)
|
||||
return
|
||||
|
||||
participant = participants[0]
|
||||
await self._rebuild_sequence_view(participant.subject)
|
||||
sequence_views = await self.graph_db.select(
|
||||
subject=participant.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW
|
||||
)
|
||||
if not sequence_views: # external class
|
||||
return
|
||||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
prompt = f"```mermaid\n{sequence_views[0].object_}\n```\n---\n```mermaid\n{rows[0].object_}\n```"
|
||||
|
||||
rsp = await self.llm.aask(
|
||||
prompt,
|
||||
system_msgs=[
|
||||
"You are a tool to merge sequence diagrams into one.",
|
||||
"Participants with the same name are considered identical.",
|
||||
"Return the merged Mermaid sequence diagram in a markdown code block format.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
|
||||
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
for r in rows:
|
||||
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
|
||||
)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject,
|
||||
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
|
||||
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
|
||||
)
|
||||
await self.graph_db.insert(
|
||||
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
|
||||
)
|
||||
await self._save_sequence_view(subject=entry.subject, content=sequence_view)
|
||||
|
||||
async def _save_sequence_view(self, subject: str, content: str):
|
||||
pattern = re.compile(r"[^a-zA-Z0-9]")
|
||||
name = re.sub(pattern, "_", subject)
|
||||
filename = Path(name).with_suffix(".sequence_diagram.mmd")
|
||||
await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content)
|
||||
|
||||
async def _search_participants(self, filename: str) -> Set:
|
||||
content = await self._get_source_code(filename)
|
||||
|
||||
rsp = await self.llm.aask(
|
||||
msg=content,
|
||||
system_msgs=[
|
||||
"You are a tool for listing all class names used in a source file.",
|
||||
"Return a markdown JSON object with: "
|
||||
'- a "class_names" key containing the list of class names used in the file; '
|
||||
'- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
|
||||
],
|
||||
)
|
||||
|
||||
class _Data(BaseModel):
|
||||
class_names: List[str]
|
||||
reasons: List
|
||||
|
||||
json_blocks = parse_json_code_block(rsp)
|
||||
data = _Data.model_validate_json(json_blocks[0])
|
||||
return set(data.class_names)
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/4
|
||||
@Author : mashenquan
|
||||
@File : rebuild_sequence_view_an.py
|
||||
"""
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.utils.mermaid import MMC2
|
||||
|
||||
CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
|
||||
key="Program call flow",
|
||||
expected_type=str,
|
||||
instruction='Translate the "context" content into "format example" format.',
|
||||
example=MMC2,
|
||||
)
|
||||
|
|
@ -50,6 +50,7 @@ class ArgumentsParingAction(Action):
|
|||
rsp = await self.llm.aask(
|
||||
msg=prompt,
|
||||
system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
|
||||
stream=False,
|
||||
)
|
||||
logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
|
||||
self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class TalkAction(Action):
|
|||
|
||||
async def run(self, with_message=None, **kwargs) -> Message:
|
||||
msg, format_msgs, system_msgs = self.aask_args
|
||||
rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs)
|
||||
rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs, stream=False)
|
||||
self.rsp = Message(content=rsp, role="assistant", cause_by=self)
|
||||
return self.rsp
|
||||
|
||||
|
|
|
|||
|
|
@ -23,11 +23,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST
|
||||
from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE
|
||||
from metagpt.const import (
|
||||
BUGFIX_FILENAME,
|
||||
CODE_PLAN_AND_CHANGE_FILENAME,
|
||||
REQUIREMENT_FILENAME,
|
||||
)
|
||||
from metagpt.const import BUGFIX_FILENAME, REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext, Document, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -98,8 +94,6 @@ class WriteCode(Action):
|
|||
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
|
||||
coding_context = CodingContext.loads(self.i_context.content)
|
||||
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
|
||||
code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
|
||||
code_plan_and_change = code_plan_and_change_doc.content if code_plan_and_change_doc else ""
|
||||
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
summary_doc = None
|
||||
if coding_context.design_doc and coding_context.design_doc.filename:
|
||||
|
|
@ -111,7 +105,7 @@ class WriteCode(Action):
|
|||
|
||||
if bug_feedback:
|
||||
code_context = coding_context.code_doc.content
|
||||
elif code_plan_and_change:
|
||||
elif self.config.inc:
|
||||
code_context = await self.get_codes(
|
||||
coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True
|
||||
)
|
||||
|
|
@ -122,10 +116,10 @@ class WriteCode(Action):
|
|||
project_repo=self.repo.with_src_path(self.context.src_workspace),
|
||||
)
|
||||
|
||||
if code_plan_and_change:
|
||||
if self.config.inc:
|
||||
prompt = REFINED_TEMPLATE.format(
|
||||
user_requirement=requirement_doc.content if requirement_doc else "",
|
||||
code_plan_and_change=code_plan_and_change,
|
||||
code_plan_and_change=str(coding_context.code_plan_and_change_doc),
|
||||
design=coding_context.design_doc.content if coding_context.design_doc else "",
|
||||
task=coding_context.task_doc.content if coding_context.task_doc else "",
|
||||
code=code_context,
|
||||
|
|
|
|||
|
|
@ -6,30 +6,44 @@
|
|||
@File : write_code_plan_and_change_an.py
|
||||
"""
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodePlanAndChangeContext
|
||||
|
||||
CODE_PLAN_AND_CHANGE = ActionNode(
|
||||
key="Code Plan And Change",
|
||||
expected_type=str,
|
||||
instruction="Developing comprehensive and step-by-step incremental development plan, and write Incremental "
|
||||
"Change by making a code draft that how to implement incremental development including detailed steps based on the "
|
||||
"context. Note: Track incremental changes using mark of '+' or '-' for add/modify/delete code, and conforms to the "
|
||||
"output format of git diff",
|
||||
example="""
|
||||
1. Plan for calculator.py: Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, multiplication, and division. Additionally, implement robust error handling for the division operation to mitigate potential issues related to division by zero.
|
||||
```python
|
||||
DEVELOPMENT_PLAN = ActionNode(
|
||||
key="Development Plan",
|
||||
expected_type=List[str],
|
||||
instruction="Develop a comprehensive and step-by-step incremental development plan, providing the detail "
|
||||
"changes to be implemented at each step based on the order of 'Task List'",
|
||||
example=[
|
||||
"Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, ...",
|
||||
"Update the existing codebase in main.py to incorporate new API endpoints for subtraction, ...",
|
||||
],
|
||||
)
|
||||
|
||||
INCREMENTAL_CHANGE = ActionNode(
|
||||
key="Incremental Change",
|
||||
expected_type=List[str],
|
||||
instruction="Write Incremental Change by making a code draft that how to implement incremental development "
|
||||
"including detailed steps based on the context. Note: Track incremental changes using the marks `+` and `-` to "
|
||||
"indicate additions and deletions, and ensure compliance with the output format of `git diff`",
|
||||
example=[
|
||||
'''```diff
|
||||
--- Old/calculator.py
|
||||
+++ New/calculator.py
|
||||
|
||||
class Calculator:
|
||||
self.result = number1 + number2
|
||||
return self.result
|
||||
|
||||
- def sub(self, number1, number2) -> float:
|
||||
+ def subtract(self, number1: float, number2: float) -> float:
|
||||
+ '''
|
||||
+ """
|
||||
+ Subtracts the second number from the first and returns the result.
|
||||
+
|
||||
+ Args:
|
||||
|
|
@ -38,13 +52,13 @@ class Calculator:
|
|||
+
|
||||
+ Returns:
|
||||
+ float: The difference of number1 and number2.
|
||||
+ '''
|
||||
+ """
|
||||
+ self.result = number1 - number2
|
||||
+ return self.result
|
||||
+
|
||||
def multiply(self, number1: float, number2: float) -> float:
|
||||
- pass
|
||||
+ '''
|
||||
+ """
|
||||
+ Multiplies two numbers and returns the result.
|
||||
+
|
||||
+ Args:
|
||||
|
|
@ -53,15 +67,15 @@ class Calculator:
|
|||
+
|
||||
+ Returns:
|
||||
+ float: The product of number1 and number2.
|
||||
+ '''
|
||||
+ """
|
||||
+ self.result = number1 * number2
|
||||
+ return self.result
|
||||
+
|
||||
def divide(self, number1: float, number2: float) -> float:
|
||||
- pass
|
||||
+ '''
|
||||
+ """
|
||||
+ ValueError: If the second number is zero.
|
||||
+ '''
|
||||
+ """
|
||||
+ if number2 == 0:
|
||||
+ raise ValueError('Cannot divide by zero')
|
||||
+ self.result = number1 / number2
|
||||
|
|
@ -75,10 +89,11 @@ class Calculator:
|
|||
+ print("Result is already zero, no need to clear.")
|
||||
+
|
||||
self.result = 0.0
|
||||
```
|
||||
```''',
|
||||
"""```diff
|
||||
--- Old/main.py
|
||||
+++ New/main.py
|
||||
|
||||
2. Plan for main.py: Integrate new API endpoints for subtraction, multiplication, and division into the existing codebase of `main.py`. Then, ensure seamless integration with the overall application architecture and maintain consistency with coding standards.
|
||||
```python
|
||||
def add_numbers():
|
||||
result = calculator.add_numbers(num1, num2)
|
||||
return jsonify({'result': result}), 200
|
||||
|
|
@ -106,6 +121,7 @@ def add_numbers():
|
|||
if __name__ == '__main__':
|
||||
app.run()
|
||||
```""",
|
||||
],
|
||||
)
|
||||
|
||||
CODE_PLAN_AND_CHANGE_CONTEXT = """
|
||||
|
|
@ -172,14 +188,16 @@ Role: You are a professional engineer; The main goal is to complete incremental
|
|||
2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.
|
||||
3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import.
|
||||
4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
|
||||
5. Follow Code Plan And Change: If there is any Incremental Change that is marked by the git diff format using '+' and '-' for add/modify/delete code, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the plan.
|
||||
5. Follow Code Plan And Change: If there is any "Incremental Change" that is marked by the git diff format with '+' and '-' symbols, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the "Development Plan".
|
||||
6. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
|
||||
7. Before using a external variable/module, make sure you import it first.
|
||||
8. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.
|
||||
9. Attention: Retain details that are not related to incremental development but are important for maintaining the consistency and clarity of the old code.
|
||||
"""
|
||||
|
||||
WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", [CODE_PLAN_AND_CHANGE])
|
||||
CODE_PLAN_AND_CHANGE = [DEVELOPMENT_PLAN, INCREMENTAL_CHANGE]
|
||||
|
||||
WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", CODE_PLAN_AND_CHANGE)
|
||||
|
||||
|
||||
class WriteCodePlanAndChange(Action):
|
||||
|
|
@ -192,14 +210,14 @@ class WriteCodePlanAndChange(Action):
|
|||
prd_doc = await self.repo.docs.prd.get(filename=self.i_context.prd_filename)
|
||||
design_doc = await self.repo.docs.system_design.get(filename=self.i_context.design_filename)
|
||||
task_doc = await self.repo.docs.task.get(filename=self.i_context.task_filename)
|
||||
code_text = await self.get_old_codes()
|
||||
context = CODE_PLAN_AND_CHANGE_CONTEXT.format(
|
||||
requirement=self.i_context.requirement,
|
||||
prd=prd_doc.content,
|
||||
design=design_doc.content,
|
||||
task=task_doc.content,
|
||||
code=code_text,
|
||||
code=await self.get_old_codes(),
|
||||
)
|
||||
logger.info("Writing code plan and change..")
|
||||
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
|
||||
async def get_old_codes(self) -> str:
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.const import CODE_PLAN_AND_CHANGE_FILENAME, REQUIREMENT_FILENAME
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -149,29 +149,21 @@ class WriteCodeReview(Action):
|
|||
use_inc=self.config.inc,
|
||||
)
|
||||
|
||||
if not self.config.inc:
|
||||
context = "\n".join(
|
||||
[
|
||||
"## System Design\n" + str(self.i_context.design_doc) + "\n",
|
||||
"## Task\n" + task_content + "\n",
|
||||
"## Code Files\n" + code_context + "\n",
|
||||
]
|
||||
)
|
||||
else:
|
||||
ctx_list = [
|
||||
"## System Design\n" + str(self.i_context.design_doc) + "\n",
|
||||
"## Task\n" + task_content + "\n",
|
||||
"## Code Files\n" + code_context + "\n",
|
||||
]
|
||||
if self.config.inc:
|
||||
requirement_doc = await self.repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
code_plan_and_change_doc = await self.repo.get(filename=CODE_PLAN_AND_CHANGE_FILENAME)
|
||||
context = "\n".join(
|
||||
[
|
||||
"## User New Requirements\n" + str(requirement_doc) + "\n",
|
||||
"## Code Plan And Change\n" + str(code_plan_and_change_doc) + "\n",
|
||||
"## System Design\n" + str(self.i_context.design_doc) + "\n",
|
||||
"## Task\n" + task_content + "\n",
|
||||
"## Code Files\n" + code_context + "\n",
|
||||
]
|
||||
)
|
||||
insert_ctx_list = [
|
||||
"## User New Requirements\n" + str(requirement_doc) + "\n",
|
||||
"## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n",
|
||||
]
|
||||
ctx_list = insert_ctx_list + ctx_list
|
||||
|
||||
context_prompt = PROMPT_TEMPLATE.format(
|
||||
context=context,
|
||||
context="\n".join(ctx_list),
|
||||
code=iterative_code,
|
||||
filename=self.i_context.code_doc.filename,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ REFINED_PRODUCT_GOALS = ActionNode(
|
|||
key="Refined Product Goals",
|
||||
expected_type=List[str],
|
||||
instruction="Update and expand the original product goals to reflect the evolving needs due to incremental "
|
||||
"development.Ensure that the refined goals align with the current project direction and contribute to its success.",
|
||||
"development. Ensure that the refined goals align with the current project direction and contribute to its success.",
|
||||
example=[
|
||||
"Enhance user engagement through new features",
|
||||
"Optimize performance for scalability",
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from metagpt.utils.yaml_model import YamlModel
|
|||
class LLMType(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
CLAUDE = "claude" # alias name of anthropic
|
||||
SPARK = "spark"
|
||||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
|
|
@ -24,6 +25,10 @@ class LLMType(Enum):
|
|||
METAGPT = "metagpt"
|
||||
AZURE = "azure"
|
||||
OLLAMA = "ollama"
|
||||
QIANFAN = "qianfan" # Baidu BCE
|
||||
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
|
||||
MOONSHOT = "moonshot"
|
||||
MISTRAL = "mistral"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OPENAI
|
||||
|
|
@ -36,12 +41,18 @@ class LLMConfig(YamlModel):
|
|||
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
api_key: str = "sk-"
|
||||
api_type: LLMType = LLMType.OPENAI
|
||||
base_url: str = "https://api.openai.com/v1"
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
|
||||
pricing_plan: Optional[str] = None # Cost Settlement Plan Parameters.
|
||||
|
||||
# For Cloud Service Provider like Baidu/ Alibaba
|
||||
access_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
endpoint: Optional[str] = None # for self-deployed model on the cloud
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
app_id: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ MESSAGE_ROUTE_TO_NONE = "<none>"
|
|||
REQUIREMENT_FILENAME = "requirement.txt"
|
||||
BUGFIX_FILENAME = "bugfix.txt"
|
||||
PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt"
|
||||
CODE_PLAN_AND_CHANGE_FILENAME = "code_plan_and_change.json"
|
||||
|
||||
DOCS_FILE_REPO = "docs"
|
||||
PRDS_FILE_REPO = "docs/prd"
|
||||
|
|
@ -105,6 +104,7 @@ CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
|
|||
RESOURCES_FILE_REPO = "resources"
|
||||
SD_OUTPUT_FILE_REPO = "resources/sd_output"
|
||||
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
|
||||
VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db"
|
||||
CLASS_VIEW_FILE_REPO = "docs/class_view"
|
||||
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
|
|
|
|||
|
|
@ -12,10 +12,14 @@ from typing import Any, Optional
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import create_llm_instance
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.cost_manager import (
|
||||
CostManager,
|
||||
FireworksCostManager,
|
||||
TokenCostManager,
|
||||
)
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
|
|
@ -80,12 +84,21 @@ class Context(BaseModel):
|
|||
# self._llm = None
|
||||
# return self._llm
|
||||
|
||||
def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
|
||||
"""Return a CostManager instance"""
|
||||
if llm_config.api_type == LLMType.FIREWORKS:
|
||||
return FireworksCostManager()
|
||||
elif llm_config.api_type == LLMType.OPEN_LLM:
|
||||
return TokenCostManager()
|
||||
else:
|
||||
return self.cost_manager
|
||||
|
||||
def llm(self) -> BaseLLM:
|
||||
"""Return a LLM instance, fixme: support cache"""
|
||||
# if self._llm is None:
|
||||
self._llm = create_llm_instance(self.config.llm)
|
||||
if self._llm.cost_manager is None:
|
||||
self._llm.cost_manager = self.cost_manager
|
||||
self._llm.cost_manager = self._select_costmanager(self.config.llm)
|
||||
return self._llm
|
||||
|
||||
def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
|
||||
|
|
@ -93,5 +106,5 @@ class Context(BaseModel):
|
|||
# if self._llm is None:
|
||||
llm = create_llm_instance(llm_config)
|
||||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self.cost_manager
|
||||
llm.cost_manager = self._select_costmanager(llm_config)
|
||||
return llm
|
||||
|
|
|
|||
|
|
@ -11,15 +11,16 @@ from pathlib import Path
|
|||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
from langchain.document_loaders import (
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain_community.document_loaders import (
|
||||
TextLoader,
|
||||
UnstructuredPDFLoader,
|
||||
UnstructuredWordDocumentLoader,
|
||||
)
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from tqdm import tqdm
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import RepoParser
|
||||
|
||||
|
||||
|
|
@ -130,9 +131,12 @@ class IndexableDocument(Document):
|
|||
if isinstance(data, pd.DataFrame):
|
||||
validate_cols(content_col, data)
|
||||
return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col)
|
||||
else:
|
||||
try:
|
||||
content = data_path.read_text()
|
||||
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
|
||||
except Exception as e:
|
||||
logger.debug(f"Load {str(data_path)} error: {e}")
|
||||
content = ""
|
||||
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
|
||||
|
||||
def _get_docs_and_metadatas_by_df(self) -> (list, list):
|
||||
df = self.data
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ class BrainMemory(BaseModel):
|
|||
summaries = [summary, command]
|
||||
msg = "\n".join(summaries)
|
||||
logger.debug(f"title ask:{msg}")
|
||||
response = await llm.aask(msg=msg, system_msgs=[])
|
||||
response = await llm.aask(msg=msg, system_msgs=[], stream=False)
|
||||
logger.debug(f"title rsp: {response}")
|
||||
return response
|
||||
|
||||
|
|
@ -201,11 +201,15 @@ class BrainMemory(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
async def _openai_is_related(text1, text2, llm, **kwargs):
|
||||
command = (
|
||||
f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there "
|
||||
"any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
|
||||
context = f"## Paragraph 1\n{text2}\n---\n## Paragraph 2\n{text1}\n"
|
||||
rsp = await llm.aask(
|
||||
msg=context,
|
||||
system_msgs=[
|
||||
"You are a tool capable of determining whether two paragraphs are semantically related."
|
||||
'Return "TRUE" if "Paragraph 1" is semantically relevant to "Paragraph 2", otherwise return "FALSE".'
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
rsp = await llm.aask(msg=command, system_msgs=[])
|
||||
result = True if "TRUE" in rsp else False
|
||||
p2 = text2.replace("\n", "")
|
||||
p1 = text1.replace("\n", "")
|
||||
|
|
@ -223,12 +227,17 @@ class BrainMemory(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
async def _openai_rewrite(sentence: str, context: str, llm):
|
||||
command = (
|
||||
f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly "
|
||||
f"supplement or rewrite the following text in brief and clear:\n{sentence}"
|
||||
prompt = f"## Context\n{context}\n---\n## Sentence\n{sentence}\n"
|
||||
rsp = await llm.aask(
|
||||
msg=prompt,
|
||||
system_msgs=[
|
||||
'You are a tool augmenting the "Sentence" with information from the "Context".',
|
||||
"Do not supplement the context with information that is not present, especially regarding the subject and object.",
|
||||
"Return the augmented sentence.",
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
rsp = await llm.aask(msg=command, system_msgs=[])
|
||||
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
|
||||
logger.info(f"REWRITE:\nCommand: {prompt}\nRESULT: {rsp}\n")
|
||||
return rsp
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -293,14 +302,14 @@ class BrainMemory(BaseModel):
|
|||
"""Generate text summary"""
|
||||
if len(text) < max_words:
|
||||
return text
|
||||
system_msgs = [
|
||||
"You are a tool for summarizing and abstracting text.",
|
||||
f"Return the summarized text to less than {max_words} words.",
|
||||
]
|
||||
if keep_language:
|
||||
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
|
||||
else:
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
msg = text + "\n\n" + command
|
||||
logger.debug(f"summary ask:{msg}")
|
||||
response = await self.llm.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"summary rsp: {response}")
|
||||
system_msgs.append("The generated summary should be in the same language as the original text.")
|
||||
response = await self.llm.aask(msg=text, system_msgs=system_msgs, stream=False)
|
||||
logger.debug(f"{text}\nsummary rsp: {response}")
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
|
@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL
|
|||
from metagpt.document_store.faiss_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
from metagpt.utils.serialize import deserialize_message, serialize_message
|
||||
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ class MemoryStorage(FaissStore):
|
|||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
|
||||
self.embedding = embedding or OpenAIEmbeddings()
|
||||
self.embedding = embedding or get_embedding()
|
||||
self.store: FAISS = None # Faiss engine
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -6,21 +6,20 @@
|
|||
@File : __init__.py
|
||||
"""
|
||||
|
||||
from metagpt.provider.fireworks_api import FireworksLLM
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
|
||||
__all__ = [
|
||||
"FireworksLLM",
|
||||
"GeminiLLM",
|
||||
"OpenLLM",
|
||||
"OpenAILLM",
|
||||
"ZhiPuAILLM",
|
||||
"AzureOpenAILLM",
|
||||
|
|
@ -28,4 +27,7 @@ __all__ = [
|
|||
"OllamaLLM",
|
||||
"HumanProvider",
|
||||
"SparkLLM",
|
||||
"QianFanLLM",
|
||||
"DashScopeLLM",
|
||||
"AnthropicLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,37 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/21 11:15
|
||||
@Author : Leo Xiao
|
||||
@File : anthropic_api.py
|
||||
"""
|
||||
|
||||
import anthropic
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message, Usage
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
class Claude2:
|
||||
@register_provider([LLMType.ANTHROPIC, LLMType.CLAUDE])
|
||||
class AnthropicLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.__init_anthropic()
|
||||
|
||||
def ask(self, prompt: str) -> str:
|
||||
client = Anthropic(api_key=self.config.api_key)
|
||||
def __init_anthropic(self):
|
||||
self.model = self.config.model
|
||||
self.aclient: AsyncAnthropic = AsyncAnthropic(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
|
||||
res = client.completions.create(
|
||||
model="claude-2",
|
||||
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
|
||||
max_tokens_to_sample=1000,
|
||||
)
|
||||
return res.completion
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.config.max_token,
|
||||
"stream": stream,
|
||||
}
|
||||
if self.use_system_prompt:
|
||||
# if the model support system prompt, extract and pass it
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
return kwargs
|
||||
|
||||
async def aask(self, prompt: str) -> str:
|
||||
aclient = AsyncAnthropic(api_key=self.config.api_key)
|
||||
def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
|
||||
usage = {"prompt_tokens": usage.input_tokens, "completion_tokens": usage.output_tokens}
|
||||
super()._update_costs(usage, model)
|
||||
|
||||
res = await aclient.completions.create(
|
||||
model="claude-2",
|
||||
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
|
||||
max_tokens_to_sample=1000,
|
||||
)
|
||||
return res.completion
|
||||
def get_choice_text(self, resp: Message) -> str:
|
||||
return resp.content[0].text
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
|
||||
resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
|
||||
self._update_costs(resp.usage, self.model)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = Usage(input_tokens=0, output_tokens=0)
|
||||
async for event in stream:
|
||||
event_type = event.type
|
||||
if event_type == "message_start":
|
||||
usage.input_tokens = event.message.usage.input_tokens
|
||||
usage.output_tokens = event.message.usage.output_tokens
|
||||
elif event_type == "content_block_delta":
|
||||
content = event.delta.text
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
elif event_type == "message_delta":
|
||||
usage.output_tokens = event.usage.output_tokens # update final output_tokens
|
||||
|
||||
log_llm_stream("\n")
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@
|
|||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
||||
|
||||
from openai import AsyncAzureOpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
||||
|
|
@ -27,6 +25,7 @@ class AzureOpenAILLM(OpenAILLM):
|
|||
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
|
||||
self.aclient = AsyncAzureOpenAI(**kwargs)
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
self.pricing_plan = self.config.pricing_plan
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(
|
||||
|
|
|
|||
|
|
@ -6,16 +6,29 @@
|
|||
@File : base_llm.py
|
||||
@Desc : mashenquan, 2023/8/22. + try catch
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types import CompletionUsage
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.common import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
|
|
@ -29,6 +42,7 @@ class BaseLLM(ABC):
|
|||
aclient: Optional[Union[AsyncOpenAI]] = None
|
||||
cost_manager: Optional[CostManager] = None
|
||||
model: Optional[str] = None
|
||||
pricing_plan: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig):
|
||||
|
|
@ -67,6 +81,28 @@ class BaseLLM(ABC):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
|
||||
"""update each request's token cost
|
||||
Args:
|
||||
model (str): model name or in some scenarios called endpoint
|
||||
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
|
||||
"""
|
||||
calc_usage = self.config.calc_usage and local_calc_usage
|
||||
model = model or self.model
|
||||
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
|
||||
if calc_usage and self.cost_manager:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: Union[str, list[dict[str, str]]],
|
||||
|
|
@ -108,6 +144,10 @@ class BaseLLM(ABC):
|
|||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3, **kwargs) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
"""_achat_completion implemented by inherited class"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""Asynchronous version of completion
|
||||
|
|
@ -120,8 +160,22 @@ class BaseLLM(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
"""_achat_completion_stream implemented by inherited class"""
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
|
||||
"""Asynchronous version of completion. Return str. Support stream-print"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages, timeout=timeout)
|
||||
resp = await self._achat_completion(messages, timeout=timeout)
|
||||
return self.get_choice_text(resp)
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
"""Required to provide the first text of choice"""
|
||||
|
|
@ -171,6 +225,20 @@ class BaseLLM(ABC):
|
|||
"""
|
||||
return json.loads(self.get_choice_function(rsp)["arguments"], strict=False)
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage | Dict):
|
||||
"""
|
||||
Updates the costs based on the provided usage information.
|
||||
"""
|
||||
if self.config.calc_usage and usage and self.cost_manager:
|
||||
if isinstance(usage, Dict):
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
else:
|
||||
prompt_tokens = usage.prompt_tokens
|
||||
completion_tokens = usage.completion_tokens
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
|
|
|
|||
227
metagpt/provider/dashscope_api.py
Normal file
227
metagpt/provider/dashscope_api.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from typing import Any, AsyncGenerator, Dict, List, Union
|
||||
|
||||
import dashscope
|
||||
from dashscope.aigc.generation import Generation
|
||||
from dashscope.api_entities.aiohttp_request import AioHttpRequest
|
||||
from dashscope.api_entities.api_request_data import ApiRequestData
|
||||
from dashscope.api_entities.api_request_factory import _get_protocol_params
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
GenerationOutput,
|
||||
GenerationResponse,
|
||||
Message,
|
||||
)
|
||||
from dashscope.client.base_api import BaseAioApi
|
||||
from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol
|
||||
from dashscope.common.error import (
|
||||
InputDataRequired,
|
||||
InputRequired,
|
||||
ModelRequired,
|
||||
UnsupportedApiProtocol,
|
||||
)
|
||||
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM, LLMConfig
|
||||
from metagpt.provider.llm_provider_registry import LLMType, register_provider
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS
|
||||
|
||||
|
||||
def build_api_arequest(
|
||||
model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs
|
||||
):
|
||||
(
|
||||
api_protocol,
|
||||
ws_stream_mode,
|
||||
is_binary_input,
|
||||
http_method,
|
||||
stream,
|
||||
async_request,
|
||||
query,
|
||||
headers,
|
||||
request_timeout,
|
||||
form,
|
||||
resources,
|
||||
) = _get_protocol_params(kwargs)
|
||||
task_id = kwargs.pop("task_id", None)
|
||||
if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
|
||||
if not dashscope.base_http_api_url.endswith("/"):
|
||||
http_url = dashscope.base_http_api_url + "/"
|
||||
else:
|
||||
http_url = dashscope.base_http_api_url
|
||||
|
||||
if is_service:
|
||||
http_url = http_url + SERVICE_API_PATH + "/"
|
||||
|
||||
if task_group:
|
||||
http_url += "%s/" % task_group
|
||||
if task:
|
||||
http_url += "%s/" % task
|
||||
if function:
|
||||
http_url += function
|
||||
request = AioHttpRequest(
|
||||
url=http_url,
|
||||
api_key=api_key,
|
||||
http_method=http_method,
|
||||
stream=stream,
|
||||
async_request=async_request,
|
||||
query=query,
|
||||
timeout=request_timeout,
|
||||
task_id=task_id,
|
||||
)
|
||||
else:
|
||||
raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol)
|
||||
|
||||
if headers is not None:
|
||||
request.add_headers(headers=headers)
|
||||
|
||||
if input is None and form is None:
|
||||
raise InputDataRequired("There is no input data and form data")
|
||||
|
||||
request_data = ApiRequestData(
|
||||
model,
|
||||
task_group=task_group,
|
||||
task=task,
|
||||
function=function,
|
||||
input=input,
|
||||
form=form,
|
||||
is_binary_input=is_binary_input,
|
||||
api_protocol=api_protocol,
|
||||
)
|
||||
request_data.add_resources(resources)
|
||||
request_data.add_parameters(**kwargs)
|
||||
request.data = request_data
|
||||
return request
|
||||
|
||||
|
||||
class AGeneration(Generation, BaseAioApi):
|
||||
@classmethod
|
||||
async def acall(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: Any = None,
|
||||
history: list = None,
|
||||
api_key: str = None,
|
||||
messages: List[Message] = None,
|
||||
plugins: Union[str, Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
|
||||
if (prompt is None or not prompt) and (messages is None or not messages):
|
||||
raise InputRequired("prompt or messages is required!")
|
||||
if model is None or not model:
|
||||
raise ModelRequired("Model is required!")
|
||||
task_group, function = "aigc", "generation" # fixed value
|
||||
if plugins is not None:
|
||||
headers = kwargs.pop("headers", {})
|
||||
if isinstance(plugins, str):
|
||||
headers["X-DashScope-Plugin"] = plugins
|
||||
else:
|
||||
headers["X-DashScope-Plugin"] = json.dumps(plugins)
|
||||
kwargs["headers"] = headers
|
||||
input, parameters = cls._build_input_parameters(model, prompt, history, messages, **kwargs)
|
||||
|
||||
api_key, model = BaseAioApi._validate_params(api_key, model)
|
||||
request = build_api_arequest(
|
||||
model=model,
|
||||
input=input,
|
||||
task_group=task_group,
|
||||
task=Generation.task,
|
||||
function=function,
|
||||
api_key=api_key,
|
||||
**kwargs,
|
||||
)
|
||||
response = await request.aio_call()
|
||||
is_stream = kwargs.get("stream", False)
|
||||
if is_stream:
|
||||
|
||||
async def aresp_iterator(response):
|
||||
async for resp in response:
|
||||
yield GenerationResponse.from_api_response(resp)
|
||||
|
||||
return aresp_iterator(response)
|
||||
else:
|
||||
return GenerationResponse.from_api_response(response)
|
||||
|
||||
|
||||
@register_provider(LLMType.DASHSCOPE)
|
||||
class DashScopeLLM(BaseLLM):
|
||||
def __init__(self, llm_config: LLMConfig):
|
||||
self.config = llm_config
|
||||
self.use_system_prompt = False # only some models support system_prompt
|
||||
self.__init_dashscope()
|
||||
self.cost_manager = CostManager(token_costs=self.token_costs)
|
||||
|
||||
def __init_dashscope(self):
|
||||
self.model = self.config.model
|
||||
self.api_key = self.config.api_key
|
||||
self.token_costs = DASHSCOPE_TOKEN_COSTS
|
||||
self.aclient: AGeneration = AGeneration
|
||||
|
||||
# check support system_message models
|
||||
support_system_models = [
|
||||
"qwen-", # all support
|
||||
"llama2-", # all support
|
||||
"baichuan2-7b-chat-v1",
|
||||
"chatglm3-6b",
|
||||
]
|
||||
for support_model in support_system_models:
|
||||
if support_model in self.model:
|
||||
self.use_system_prompt = True
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"api_key": self.api_key,
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"result_format": "message",
|
||||
}
|
||||
if self.config.temperature > 0:
|
||||
# different model has default temperature. only set when it"s specified.
|
||||
kwargs["temperature"] = self.config.temperature
|
||||
if stream:
|
||||
kwargs["incremental_output"] = True
|
||||
return kwargs
|
||||
|
||||
def _check_response(self, resp: GenerationResponse):
|
||||
if resp.status_code != HTTPStatus.OK:
|
||||
raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}")
|
||||
|
||||
def get_choice_text(self, output: GenerationOutput) -> str:
|
||||
return output.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
|
||||
def completion(self, messages: list[dict]) -> GenerationOutput:
|
||||
resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False))
|
||||
self._check_response(resp)
|
||||
|
||||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
|
||||
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
|
||||
self._check_response(resp)
|
||||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in resp:
|
||||
self._check_response(chunk)
|
||||
content = chunk.output.choices[0]["message"]["content"]
|
||||
usage = dict(chunk.usage) # each chunk has usage
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
log_llm_stream("\n")
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
|
@ -1,131 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : fireworks.ai's api
|
||||
|
||||
import re
|
||||
|
||||
from openai import APIConnectionError, AsyncStream
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
MODEL_GRADE_TOKEN_COSTS = {
|
||||
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
|
||||
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
|
||||
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
|
||||
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
|
||||
}
|
||||
|
||||
|
||||
class FireworksCostManager(CostManager):
|
||||
def model_grade_token_costs(self, model: str) -> dict[str, float]:
|
||||
def _get_model_size(model: str) -> float:
|
||||
size = re.findall(".*-([0-9.]+)b", model)
|
||||
size = float(size[0]) if len(size) > 0 else -1
|
||||
return size
|
||||
|
||||
if "mixtral-8x7b" in model:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
|
||||
else:
|
||||
model_size = _get_model_size(model)
|
||||
if 0 < model_size <= 16:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
|
||||
elif 16 < model_size <= 80:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
|
||||
else:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
|
||||
return token_costs
|
||||
|
||||
def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
|
||||
"""
|
||||
Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
|
||||
token_costs = self.model_grade_token_costs(model)
|
||||
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.4f}"
|
||||
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
|
||||
|
||||
@register_provider(LLMType.FIREWORKS)
|
||||
class FireworksLLM(OpenAILLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
super().__init__(config=config)
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager = FireworksCostManager()
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use FireworksCostManager not context.cost_manager
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
)
|
||||
|
||||
collected_content = []
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
choice_delta = choice.delta
|
||||
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
|
||||
if choice_delta.content:
|
||||
collected_content.append(choice_delta.content)
|
||||
print(choice_delta.content, end="")
|
||||
if finish_reason:
|
||||
# fireworks api return usage when finish_reason is not None
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
|
||||
full_content = "".join(collected_content)
|
||||
self._update_costs(usage)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
rsp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(rsp)
|
||||
|
|
@ -13,19 +13,11 @@ from google.generativeai.types.generation_types import (
|
|||
GenerateContentResponse,
|
||||
GenerationConfig,
|
||||
)
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
|
||||
|
||||
class GeminiGenerativeModel(GenerativeModel):
|
||||
|
|
@ -55,6 +47,7 @@ class GeminiLLM(BaseLLM):
|
|||
self.__init_gemini(config)
|
||||
self.config = config
|
||||
self.model = "gemini-pro" # so far only one model
|
||||
self.pricing_plan = self.config.pricing_plan or self.model
|
||||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
|
||||
def __init_gemini(self, config: LLMConfig):
|
||||
|
|
@ -72,16 +65,6 @@ class GeminiLLM(BaseLLM):
|
|||
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
@ -105,16 +88,16 @@ class GeminiLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
|
||||
usage = await self.aget_usage(messages, resp.text)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
|
||||
**self._const_kwargs(messages, stream=True)
|
||||
)
|
||||
|
|
@ -129,17 +112,3 @@ class GeminiLLM(BaseLLM):
|
|||
usage = await self.aget_usage(messages, full_content)
|
||||
self._update_costs(usage)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -35,10 +35,16 @@ class HumanProvider(BaseLLM):
|
|||
) -> str:
|
||||
return self.ask(msg, timeout=timeout)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ class LLMProviderRegistry:
|
|||
return self.providers[enum]
|
||||
|
||||
|
||||
def register_provider(key):
|
||||
def register_provider(keys):
|
||||
"""register provider to registry"""
|
||||
|
||||
def decorator(cls):
|
||||
LLM_REGISTRY.register(key, cls)
|
||||
if isinstance(keys, list):
|
||||
for key in keys:
|
||||
LLM_REGISTRY.register(key, cls)
|
||||
else:
|
||||
LLM_REGISTRY.register(keys, cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
@File : metagpt_api.py
|
||||
@Desc : MetaGPT LLM provider.
|
||||
"""
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -12,4 +14,7 @@ from metagpt.provider.llm_provider_registry import register_provider
|
|||
|
||||
@register_provider(LLMType.METAGPT)
|
||||
class MetaGPTLLM(OpenAILLM):
|
||||
pass
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
# The current billing is based on usage frequency. If there is a future billing logic based on the
|
||||
# number of tokens, please refine the logic here accordingly.
|
||||
return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
|
|
|
|||
|
|
@ -4,22 +4,12 @@
|
|||
|
||||
import json
|
||||
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
|
||||
|
||||
|
|
@ -36,26 +26,17 @@ class OllamaLLM(BaseLLM):
|
|||
self.suffix_url = "/chat"
|
||||
self.http_method = "post"
|
||||
self.use_system_prompt = False
|
||||
self._cost_manager = TokenCostManager()
|
||||
self.cost_manager = TokenCostManager()
|
||||
|
||||
def __init_ollama(self, config: LLMConfig):
|
||||
assert config.base_url, "ollama base url is required!"
|
||||
self.model = config.model
|
||||
self.pricing_plan = self.model
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"ollama updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the resp content from llm response"""
|
||||
assist_msg = resp.get("message", {})
|
||||
|
|
@ -69,7 +50,7 @@ class OllamaLLM(BaseLLM):
|
|||
chunk = chunk.decode(encoding)
|
||||
return json.loads(chunk)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
|
||||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
|
|
@ -82,9 +63,9 @@ class OllamaLLM(BaseLLM):
|
|||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
stream_resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
|
|
@ -110,17 +91,3 @@ class OllamaLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : self-host open llm model with openai-compatible interface
|
||||
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import Costs, TokenCostManager
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
||||
@register_provider(LLMType.OPEN_LLM)
|
||||
class OpenLLM(OpenAILLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
super().__init__(config)
|
||||
self._cost_manager = TokenCostManager()
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
if not self.config.calc_usage:
|
||||
return usage
|
||||
|
||||
try:
|
||||
usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
|
||||
usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
|
||||
except Exception as e:
|
||||
logger.error(f"usage calculation failed!: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use OpenLLMCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
|
@ -6,10 +6,11 @@
|
|||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import AsyncIterator, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
|
@ -28,8 +29,13 @@ from metagpt.logs import log_llm_stream, logger
|
|||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.common import CodeParser, decode_image, process_message
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.common import (
|
||||
CodeParser,
|
||||
decode_image,
|
||||
log_and_reraise,
|
||||
process_message,
|
||||
)
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -38,33 +44,20 @@ from metagpt.utils.token_counter import (
|
|||
)
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
"""
|
||||
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
|
||||
See FAQ 5.8
|
||||
"""
|
||||
)
|
||||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
@register_provider(LLMType.OPENAI)
|
||||
@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL])
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self._init_model()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def _init_model(self):
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
self.pricing_plan = self.config.pricing_plan or self.model
|
||||
kwargs = self._make_client_kwargs()
|
||||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
|
|
@ -86,22 +79,41 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
return params
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages, timeout=timeout), stream=True
|
||||
)
|
||||
|
||||
usage = None
|
||||
collected_messages = []
|
||||
async for chunk in response:
|
||||
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
|
||||
yield chunk_message
|
||||
finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], "finish_reason") else None
|
||||
log_llm_stream(chunk_message)
|
||||
collected_messages.append(chunk_message)
|
||||
if finish_reason:
|
||||
if hasattr(chunk, "usage"):
|
||||
# Some services have usage as an attribute of the chunk, such as Fireworks
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
elif hasattr(chunk.choices[0], "usage"):
|
||||
# The usage of some services is an attribute of chunk.choices[0], such as Moonshot
|
||||
usage = CompletionUsage(**chunk.choices[0].usage)
|
||||
|
||||
log_llm_stream("\n")
|
||||
full_reply_content = "".join(collected_messages)
|
||||
if not usage:
|
||||
# Some services do not provide the usage attribute, such as OpenAI or OpenLLM
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self._get_max_tokens(messages),
|
||||
"n": 1,
|
||||
# "n": 1, # Some services do not provide this parameter, such as mistral
|
||||
# "stop": None, # default it's None and gpt4-v can't have this one
|
||||
"temperature": 0.3,
|
||||
"temperature": self.config.temperature,
|
||||
"model": self.model,
|
||||
"timeout": max(self.config.timeout, timeout),
|
||||
}
|
||||
|
|
@ -128,18 +140,7 @@ class OpenAILLM(BaseLLM):
|
|||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
resp = self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
||||
collected_messages = []
|
||||
async for i in resp:
|
||||
log_llm_stream(i)
|
||||
collected_messages.append(i)
|
||||
log_llm_stream("\n")
|
||||
|
||||
full_reply_content = "".join(collected_messages)
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
await self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
||||
rsp = await self._achat_completion(messages, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
|
@ -239,23 +240,13 @@ class OpenAILLM(BaseLLM):
|
|||
return usage
|
||||
|
||||
try:
|
||||
usage.prompt_tokens = count_message_tokens(messages, self.model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
usage.prompt_tokens = count_message_tokens(messages, self.pricing_plan)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.pricing_plan)
|
||||
except Exception as e:
|
||||
logger.warning(f"usage calculation failed: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage and self.cost_manager:
|
||||
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
return Costs(0, 0, 0, 0)
|
||||
return self.cost_manager.get_costs()
|
||||
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return self.config.max_token
|
||||
|
|
|
|||
131
metagpt/provider/qianfan_api.py
Normal file
131
metagpt/provider/qianfan_api.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models
|
||||
import copy
|
||||
import os
|
||||
|
||||
import qianfan
|
||||
from qianfan import ChatCompletion
|
||||
from qianfan.resources.typing import JsonBody
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import (
|
||||
QIANFAN_ENDPOINT_TOKEN_COSTS,
|
||||
QIANFAN_MODEL_TOKEN_COSTS,
|
||||
)
|
||||
|
||||
|
||||
@register_provider(LLMType.QIANFAN)
|
||||
class QianFanLLM(BaseLLM):
|
||||
"""
|
||||
Refs
|
||||
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
|
||||
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
|
||||
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
|
||||
"""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.use_system_prompt = False # only some ERNIE-x related models support system_prompt
|
||||
self.__init_qianfan()
|
||||
self.cost_manager = CostManager(token_costs=self.token_costs)
|
||||
|
||||
def __init_qianfan(self):
|
||||
if self.config.access_key and self.config.secret_key:
|
||||
# for system level auth, use access_key and secret_key, recommended by official
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
|
||||
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
|
||||
elif self.config.api_key and self.config.secret_key:
|
||||
# for application level auth, use api_key and secret_key
|
||||
# set environment variable due to official recommendation
|
||||
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
|
||||
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
|
||||
else:
|
||||
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
|
||||
|
||||
support_system_pairs = [
|
||||
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
|
||||
("ERNIE-Bot-8k", "ernie_bot_8k"),
|
||||
("ERNIE-Bot", "completions"),
|
||||
("ERNIE-Bot-turbo", "eb-instant"),
|
||||
("ERNIE-Speed", "ernie_speed"),
|
||||
("EB-turbo-AppBuilder", "ai_apaas"),
|
||||
]
|
||||
if self.config.model in [pair[0] for pair in support_system_pairs]:
|
||||
# only some ERNIE models support
|
||||
self.use_system_prompt = True
|
||||
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
|
||||
self.use_system_prompt = True
|
||||
|
||||
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
|
||||
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
|
||||
|
||||
self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
|
||||
self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)
|
||||
|
||||
# self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
|
||||
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
|
||||
self.aclient: ChatCompletion = qianfan.ChatCompletion()
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
}
|
||||
if self.config.temperature > 0:
|
||||
# different model has default temperature. only set when it's specified.
|
||||
kwargs["temperature"] = self.config.temperature
|
||||
if self.config.endpoint:
|
||||
kwargs["endpoint"] = self.config.endpoint
|
||||
elif self.config.model:
|
||||
kwargs["model"] = self.config.model
|
||||
|
||||
if self.use_system_prompt:
|
||||
# if the model support system prompt, extract and pass it
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
model_or_endpoint = self.config.model or self.config.endpoint
|
||||
local_calc_usage = model_or_endpoint in self.token_costs
|
||||
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
|
||||
|
||||
def get_choice_text(self, resp: JsonBody) -> str:
|
||||
return resp.get("result", "")
|
||||
|
||||
def completion(self, messages: list[dict]) -> JsonBody:
|
||||
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
async for chunk in resp:
|
||||
content = chunk.body.get("result", "")
|
||||
usage = chunk.body.get("usage", {})
|
||||
log_llm_stream(content)
|
||||
collected_content.append(content)
|
||||
log_llm_stream("\n")
|
||||
|
||||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
|
@ -31,12 +31,18 @@ class SparkLLM(BaseLLM):
|
|||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
# 不支持
|
||||
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
|
|
|
|||
|
|
@ -5,21 +5,12 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
from zhipuai.types.chat.chat_completion import Completion
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
|
@ -47,22 +38,13 @@ class ZhiPuAILLM(BaseLLM):
|
|||
assert self.config.api_key
|
||||
self.api_key = self.config.api_key
|
||||
self.model = self.config.model # so far, it support glm-3-turbo、glm-4
|
||||
self.pricing_plan = self.config.pricing_plan or self.model
|
||||
self.llm = ZhiPuModelAPI(api_key=self.api_key)
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if self.config.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
|
|
@ -96,17 +78,3 @@ class ZhiPuAILLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Build a symbols repository from source code.
|
||||
|
||||
This script is designed to create a symbols repository from the provided source code.
|
||||
|
||||
@Time : 2023/11/17 17:58
|
||||
@Author : alexanderwu
|
||||
@File : repo_parser.py
|
||||
|
|
@ -15,15 +19,26 @@ from pathlib import Path
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import any_to_str, aread
|
||||
from metagpt.utils.common import any_to_str, aread, remove_white_spaces
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
class RepoFileInfo(BaseModel):
|
||||
"""
|
||||
Repository data element that represents information about a file.
|
||||
|
||||
Attributes:
|
||||
file (str): The name or path of the file.
|
||||
classes (List): A list of class names present in the file.
|
||||
functions (List): A list of function names present in the file.
|
||||
globals (List): A list of global variable names present in the file.
|
||||
page_info (List): A list of page-related information associated with the file.
|
||||
"""
|
||||
|
||||
file: str
|
||||
classes: List = Field(default_factory=list)
|
||||
functions: List = Field(default_factory=list)
|
||||
|
|
@ -32,6 +47,17 @@ class RepoFileInfo(BaseModel):
|
|||
|
||||
|
||||
class CodeBlockInfo(BaseModel):
|
||||
"""
|
||||
Repository data element representing information about a code block.
|
||||
|
||||
Attributes:
|
||||
lineno (int): The starting line number of the code block.
|
||||
end_lineno (int): The ending line number of the code block.
|
||||
type_name (str): The type or category of the code block.
|
||||
tokens (List): A list of tokens present in the code block.
|
||||
properties (Dict): A dictionary containing additional properties associated with the code block.
|
||||
"""
|
||||
|
||||
lineno: int
|
||||
end_lineno: int
|
||||
type_name: str
|
||||
|
|
@ -39,31 +65,395 @@ class CodeBlockInfo(BaseModel):
|
|||
properties: Dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ClassInfo(BaseModel):
|
||||
class DotClassAttribute(BaseModel):
|
||||
"""
|
||||
Repository data element representing a class attribute in dot format.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the class attribute.
|
||||
type_ (str): The type of the class attribute.
|
||||
default_ (str): The default value of the class attribute.
|
||||
description (str): A description of the class attribute.
|
||||
compositions (List[str]): A list of compositions associated with the class attribute.
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
type_: str = ""
|
||||
default_: str = ""
|
||||
description: str
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v: str) -> "DotClassAttribute":
|
||||
"""
|
||||
Parses dot format text and returns a DotClassAttribute object.
|
||||
|
||||
Args:
|
||||
v (str): Dot format text to be parsed.
|
||||
|
||||
Returns:
|
||||
DotClassAttribute: An instance of the DotClassAttribute class representing the parsed data.
|
||||
"""
|
||||
val = ""
|
||||
meet_colon = False
|
||||
meet_equals = False
|
||||
for c in v:
|
||||
if c == ":":
|
||||
meet_colon = True
|
||||
elif c == "=":
|
||||
meet_equals = True
|
||||
if not meet_colon:
|
||||
val += ":"
|
||||
meet_colon = True
|
||||
val += c
|
||||
if not meet_colon:
|
||||
val += ":"
|
||||
if not meet_equals:
|
||||
val += "="
|
||||
|
||||
cix = val.find(":")
|
||||
eix = val.rfind("=")
|
||||
name = val[0:cix].strip()
|
||||
type_ = val[cix + 1 : eix]
|
||||
default_ = val[eix + 1 :].strip()
|
||||
|
||||
type_ = remove_white_spaces(type_) # remove white space
|
||||
if type_ == "NoneType":
|
||||
type_ = ""
|
||||
if "Literal[" in type_:
|
||||
pre_l, literal, post_l = cls._split_literal(type_)
|
||||
composition_val = pre_l + "Literal" + post_l # replace Literal[...] with Literal
|
||||
type_ = pre_l + literal + post_l
|
||||
else:
|
||||
type_ = re.sub(r"['\"]+", "", type_) # remove '"
|
||||
composition_val = type_
|
||||
|
||||
if default_ == "None":
|
||||
default_ = ""
|
||||
compositions = cls.parse_compositions(composition_val)
|
||||
return cls(name=name, type_=type_, default_=default_, description=v, compositions=compositions)
|
||||
|
||||
@staticmethod
|
||||
def parse_compositions(types_part) -> List[str]:
|
||||
"""
|
||||
Parses the type definition code block of source code and returns a list of compositions.
|
||||
|
||||
Args:
|
||||
types_part: The type definition code block to be parsed.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of compositions extracted from the type definition code block.
|
||||
"""
|
||||
if not types_part:
|
||||
return []
|
||||
modified_string = re.sub(r"[\[\],\(\)]", "|", types_part)
|
||||
types = modified_string.split("|")
|
||||
filters = {
|
||||
"str",
|
||||
"frozenset",
|
||||
"set",
|
||||
"int",
|
||||
"float",
|
||||
"complex",
|
||||
"bool",
|
||||
"dict",
|
||||
"list",
|
||||
"Union",
|
||||
"Dict",
|
||||
"Set",
|
||||
"Tuple",
|
||||
"NoneType",
|
||||
"None",
|
||||
"Any",
|
||||
"Optional",
|
||||
"Iterator",
|
||||
"Literal",
|
||||
"List",
|
||||
}
|
||||
result = set()
|
||||
for t in types:
|
||||
t = re.sub(r"['\"]+", "", t.strip())
|
||||
if t and t not in filters:
|
||||
result.add(t)
|
||||
return list(result)
|
||||
|
||||
@staticmethod
|
||||
def _split_literal(v):
|
||||
"""
|
||||
Parses the literal definition code block and returns three parts: pre-part, literal-part, and post-part.
|
||||
|
||||
Args:
|
||||
v: The literal definition code block to be parsed.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, str]: A tuple containing the pre-part, literal-part, and post-part of the code block.
|
||||
"""
|
||||
tag = "Literal["
|
||||
bix = v.find(tag)
|
||||
eix = len(v) - 1
|
||||
counter = 1
|
||||
for i in range(bix + len(tag), len(v) - 1):
|
||||
c = v[i]
|
||||
if c == "[":
|
||||
counter += 1
|
||||
continue
|
||||
if c == "]":
|
||||
counter -= 1
|
||||
if counter > 0:
|
||||
continue
|
||||
eix = i
|
||||
break
|
||||
pre_l = v[0:bix]
|
||||
post_l = v[eix + 1 :]
|
||||
pre_l = re.sub(r"['\"]", "", pre_l) # remove '"
|
||||
pos_l = re.sub(r"['\"]", "", post_l) # remove '"
|
||||
|
||||
return pre_l, v[bix : eix + 1], pos_l
|
||||
|
||||
@field_validator("compositions", mode="after")
|
||||
@classmethod
|
||||
def sort(cls, lst: List) -> List:
|
||||
"""
|
||||
Auto-sorts a list attribute after making changes.
|
||||
|
||||
Args:
|
||||
lst (List): The list attribute to be sorted.
|
||||
|
||||
Returns:
|
||||
List: The sorted list.
|
||||
"""
|
||||
lst.sort()
|
||||
return lst
|
||||
|
||||
|
||||
class DotClassInfo(BaseModel):
|
||||
"""
|
||||
Repository data element representing information about a class in dot format.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the class.
|
||||
package (Optional[str]): The package to which the class belongs (optional).
|
||||
attributes (Dict[str, DotClassAttribute]): A dictionary of attributes associated with the class.
|
||||
methods (Dict[str, DotClassMethod]): A dictionary of methods associated with the class.
|
||||
compositions (List[str]): A list of compositions associated with the class.
|
||||
aggregations (List[str]): A list of aggregations associated with the class.
|
||||
"""
|
||||
|
||||
name: str
|
||||
package: Optional[str] = None
|
||||
attributes: Dict[str, str] = Field(default_factory=dict)
|
||||
methods: Dict[str, str] = Field(default_factory=dict)
|
||||
attributes: Dict[str, DotClassAttribute] = Field(default_factory=dict)
|
||||
methods: Dict[str, DotClassMethod] = Field(default_factory=dict)
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
aggregations: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("compositions", "aggregations", mode="after")
|
||||
@classmethod
|
||||
def sort(cls, lst: List) -> List:
|
||||
"""
|
||||
Auto-sorts a list attribute after making changes.
|
||||
|
||||
Args:
|
||||
lst (List): The list attribute to be sorted.
|
||||
|
||||
Returns:
|
||||
List: The sorted list.
|
||||
"""
|
||||
lst.sort()
|
||||
return lst
|
||||
|
||||
|
||||
class ClassRelationship(BaseModel):
|
||||
class DotClassRelationship(BaseModel):
|
||||
"""
|
||||
Repository data element representing a relationship between two classes in dot format.
|
||||
|
||||
Attributes:
|
||||
src (str): The source class of the relationship.
|
||||
dest (str): The destination class of the relationship.
|
||||
relationship (str): The type or nature of the relationship.
|
||||
label (Optional[str]): An optional label associated with the relationship.
|
||||
"""
|
||||
|
||||
src: str = ""
|
||||
dest: str = ""
|
||||
relationship: str = ""
|
||||
label: Optional[str] = None
|
||||
|
||||
|
||||
class DotReturn(BaseModel):
|
||||
"""
|
||||
Repository data element representing a function or method return type in dot format.
|
||||
|
||||
Attributes:
|
||||
type_ (str): The type of the return.
|
||||
description (str): A description of the return type.
|
||||
compositions (List[str]): A list of compositions associated with the return type.
|
||||
"""
|
||||
|
||||
type_: str = ""
|
||||
description: str
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v: str) -> "DotReturn" | None:
|
||||
"""
|
||||
Parses the return type part of dot format text and returns a DotReturn object.
|
||||
|
||||
Args:
|
||||
v (str): The dot format text containing the return type part to be parsed.
|
||||
|
||||
Returns:
|
||||
DotReturn | None: An instance of the DotReturn class representing the parsed return type,
|
||||
or None if parsing fails.
|
||||
"""
|
||||
if not v:
|
||||
return DotReturn(description=v)
|
||||
type_ = remove_white_spaces(v)
|
||||
compositions = DotClassAttribute.parse_compositions(type_)
|
||||
return cls(type_=type_, description=v, compositions=compositions)
|
||||
|
||||
@field_validator("compositions", mode="after")
|
||||
@classmethod
|
||||
def sort(cls, lst: List) -> List:
|
||||
"""
|
||||
Auto-sorts a list attribute after making changes.
|
||||
|
||||
Args:
|
||||
lst (List): The list attribute to be sorted.
|
||||
|
||||
Returns:
|
||||
List: The sorted list.
|
||||
"""
|
||||
lst.sort()
|
||||
return lst
|
||||
|
||||
|
||||
class DotClassMethod(BaseModel):
|
||||
name: str
|
||||
args: List[DotClassAttribute] = Field(default_factory=list)
|
||||
return_args: Optional[DotReturn] = None
|
||||
description: str
|
||||
aggregations: List[str] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v: str) -> "DotClassMethod":
|
||||
"""
|
||||
Parses a dot format method text and returns a DotClassMethod object.
|
||||
|
||||
Args:
|
||||
v (str): The dot format text containing method information to be parsed.
|
||||
|
||||
Returns:
|
||||
DotClassMethod: An instance of the DotClassMethod class representing the parsed method.
|
||||
"""
|
||||
bix = v.find("(")
|
||||
eix = v.rfind(")")
|
||||
rix = v.rfind(":")
|
||||
if rix < 0 or rix < eix:
|
||||
rix = eix
|
||||
name_part = v[0:bix].strip()
|
||||
args_part = v[bix + 1 : eix].strip()
|
||||
return_args_part = v[rix + 1 :].strip()
|
||||
|
||||
name = cls._parse_name(name_part)
|
||||
args = cls._parse_args(args_part)
|
||||
return_args = DotReturn.parse(return_args_part)
|
||||
aggregations = set()
|
||||
for i in args:
|
||||
aggregations.update(set(i.compositions))
|
||||
aggregations.update(set(return_args.compositions))
|
||||
|
||||
return cls(name=name, args=args, description=v, return_args=return_args, aggregations=list(aggregations))
|
||||
|
||||
@staticmethod
|
||||
def _parse_name(v: str) -> str:
|
||||
"""
|
||||
Parses the dot format method name part and returns the method name.
|
||||
|
||||
Args:
|
||||
v (str): The dot format text containing the method name part to be parsed.
|
||||
|
||||
Returns:
|
||||
str: The parsed method name.
|
||||
"""
|
||||
tags = [">", "</"]
|
||||
if tags[0] in v:
|
||||
bix = v.find(tags[0]) + len(tags[0])
|
||||
eix = v.rfind(tags[1])
|
||||
return v[bix:eix].strip()
|
||||
return v.strip()
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(v: str) -> List[DotClassAttribute]:
|
||||
"""
|
||||
Parses the dot format method arguments part and returns the parsed arguments.
|
||||
|
||||
Args:
|
||||
v (str): The dot format text containing the arguments part to be parsed.
|
||||
|
||||
Returns:
|
||||
str: The parsed method arguments.
|
||||
"""
|
||||
if not v:
|
||||
return []
|
||||
parts = []
|
||||
bix = 0
|
||||
counter = 0
|
||||
for i in range(0, len(v)):
|
||||
c = v[i]
|
||||
if c == "[":
|
||||
counter += 1
|
||||
continue
|
||||
elif c == "]":
|
||||
counter -= 1
|
||||
continue
|
||||
elif c == "," and counter == 0:
|
||||
parts.append(v[bix:i].strip())
|
||||
bix = i + 1
|
||||
parts.append(v[bix:].strip())
|
||||
|
||||
attrs = []
|
||||
for p in parts:
|
||||
if p:
|
||||
attr = DotClassAttribute.parse(p)
|
||||
attrs.append(attr)
|
||||
return attrs
|
||||
|
||||
|
||||
class RepoParser(BaseModel):
|
||||
"""
|
||||
Tool to build a symbols repository from a project directory.
|
||||
|
||||
Attributes:
|
||||
base_directory (Path): The base directory of the project.
|
||||
"""
|
||||
|
||||
base_directory: Path = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
@handle_exception(exception_type=Exception, default_return=[])
|
||||
def _parse_file(cls, file_path: Path) -> list:
|
||||
"""Parse a Python file in the repository."""
|
||||
"""
|
||||
Parses a Python file in the repository.
|
||||
|
||||
Args:
|
||||
file_path (Path): The path to the Python file to be parsed.
|
||||
|
||||
Returns:
|
||||
list: A list containing the parsed symbols from the file.
|
||||
"""
|
||||
return ast.parse(file_path.read_text()).body
|
||||
|
||||
def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
|
||||
"""Extract class, function, and global variable information from the AST."""
|
||||
"""
|
||||
Extracts class, function, and global variable information from the Abstract Syntax Tree (AST).
|
||||
|
||||
Args:
|
||||
tree: The Abstract Syntax Tree (AST) of the Python file.
|
||||
file_path: The path to the Python file.
|
||||
|
||||
Returns:
|
||||
RepoFileInfo: A RepoFileInfo object containing the extracted information.
|
||||
"""
|
||||
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
|
||||
for node in tree:
|
||||
info = RepoParser.node_to_str(node)
|
||||
|
|
@ -81,11 +471,17 @@ class RepoParser(BaseModel):
|
|||
return file_info
|
||||
|
||||
def generate_symbols(self) -> List[RepoFileInfo]:
|
||||
"""
|
||||
Builds a symbol repository from '.py' and '.js' files in the project directory.
|
||||
|
||||
Returns:
|
||||
List[RepoFileInfo]: A list of RepoFileInfo objects containing the extracted information.
|
||||
"""
|
||||
files_classes = []
|
||||
directory = self.base_directory
|
||||
|
||||
matching_files = []
|
||||
extensions = ["*.py", "*.js"]
|
||||
extensions = ["*.py"]
|
||||
for ext in extensions:
|
||||
matching_files += directory.rglob(ext)
|
||||
for path in matching_files:
|
||||
|
|
@ -95,19 +491,38 @@ class RepoParser(BaseModel):
|
|||
|
||||
return files_classes
|
||||
|
||||
def generate_json_structure(self, output_path):
|
||||
"""Generate a JSON file documenting the repository structure."""
|
||||
def generate_json_structure(self, output_path: Path):
|
||||
"""
|
||||
Generates a JSON file documenting the repository structure.
|
||||
|
||||
Args:
|
||||
output_path (Path): The path to the JSON file to be generated.
|
||||
"""
|
||||
files_classes = [i.model_dump() for i in self.generate_symbols()]
|
||||
output_path.write_text(json.dumps(files_classes, indent=4))
|
||||
|
||||
def generate_dataframe_structure(self, output_path):
|
||||
"""Generate a DataFrame documenting the repository structure and save as CSV."""
|
||||
def generate_dataframe_structure(self, output_path: Path):
|
||||
"""
|
||||
Generates a DataFrame documenting the repository structure and saves it as a CSV file.
|
||||
|
||||
Args:
|
||||
output_path (Path): The path to the CSV file to be generated.
|
||||
"""
|
||||
files_classes = [i.model_dump() for i in self.generate_symbols()]
|
||||
df = pd.DataFrame(files_classes)
|
||||
df.to_csv(output_path, index=False)
|
||||
|
||||
def generate_structure(self, output_path=None, mode="json") -> Path:
|
||||
"""Generate the structure of the repository as a specified format."""
|
||||
def generate_structure(self, output_path: str | Path = None, mode="json") -> Path:
|
||||
"""
|
||||
Generates the structure of the repository in a specified format.
|
||||
|
||||
Args:
|
||||
output_path (str | Path): The path to the output file or directory. Default is None.
|
||||
mode (str): The output format mode. Options: "json" (default), "csv", etc.
|
||||
|
||||
Returns:
|
||||
Path: The path to the generated output file or directory.
|
||||
"""
|
||||
output_file = self.base_directory / f"{self.base_directory.name}-structure.{mode}"
|
||||
output_path = Path(output_path) if output_path else output_file
|
||||
|
||||
|
|
@ -119,6 +534,16 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def node_to_str(node) -> CodeBlockInfo | None:
|
||||
"""
|
||||
Parses and converts an Abstract Syntax Tree (AST) node to a CodeBlockInfo object.
|
||||
|
||||
Args:
|
||||
node: The AST node to be converted.
|
||||
|
||||
Returns:
|
||||
CodeBlockInfo | None: A CodeBlockInfo object representing the parsed AST node,
|
||||
or None if the conversion fails.
|
||||
"""
|
||||
if isinstance(node, ast.Try):
|
||||
return None
|
||||
if any_to_str(node) == any_to_str(ast.Expr):
|
||||
|
|
@ -159,9 +584,19 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _parse_expr(node) -> List:
|
||||
"""
|
||||
Parses an expression Abstract Syntax Tree (AST) node.
|
||||
|
||||
Args:
|
||||
node: The AST node representing an expression.
|
||||
|
||||
Returns:
|
||||
List: A list containing the parsed information from the expression node.
|
||||
"""
|
||||
funcs = {
|
||||
any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
|
||||
any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
|
||||
any_to_str(ast.Tuple): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
|
||||
}
|
||||
func = funcs.get(any_to_str(node.value))
|
||||
if func:
|
||||
|
|
@ -170,12 +605,30 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _parse_name(n):
|
||||
"""
|
||||
Gets the 'name' value of an Abstract Syntax Tree (AST) node.
|
||||
|
||||
Args:
|
||||
n: The AST node.
|
||||
|
||||
Returns:
|
||||
The 'name' value of the AST node.
|
||||
"""
|
||||
if n.asname:
|
||||
return f"{n.name} as {n.asname}"
|
||||
return n.name
|
||||
|
||||
@staticmethod
|
||||
def _parse_if(n):
|
||||
"""
|
||||
Parses an 'if' statement Abstract Syntax Tree (AST) node.
|
||||
|
||||
Args:
|
||||
n: The AST node representing an 'if' statement.
|
||||
|
||||
Returns:
|
||||
None or Parsed information from the 'if' statement node.
|
||||
"""
|
||||
tokens = []
|
||||
try:
|
||||
if isinstance(n.test, ast.BoolOp):
|
||||
|
|
@ -187,10 +640,14 @@ class RepoParser(BaseModel):
|
|||
v = RepoParser._parse_variable(n.test.left)
|
||||
if v:
|
||||
tokens.append(v)
|
||||
for item in n.test.comparators:
|
||||
v = RepoParser._parse_variable(item)
|
||||
if v:
|
||||
tokens.append(v)
|
||||
if isinstance(n.test, ast.Name):
|
||||
v = RepoParser._parse_variable(n.test)
|
||||
tokens.append(v)
|
||||
if hasattr(n.test, "comparators"):
|
||||
for item in n.test.comparators:
|
||||
v = RepoParser._parse_variable(item)
|
||||
if v:
|
||||
tokens.append(v)
|
||||
return tokens
|
||||
except Exception as e:
|
||||
logger.warning(f"Unsupported if: {n}, err:{e}")
|
||||
|
|
@ -198,6 +655,15 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _parse_if_compare(n):
|
||||
"""
|
||||
Parses an 'if' condition Abstract Syntax Tree (AST) node.
|
||||
|
||||
Args:
|
||||
n: The AST node representing an 'if' condition.
|
||||
|
||||
Returns:
|
||||
None or Parsed information from the 'if' condition node.
|
||||
"""
|
||||
if hasattr(n, "left"):
|
||||
return RepoParser._parse_variable(n.left)
|
||||
else:
|
||||
|
|
@ -205,6 +671,15 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _parse_variable(node):
|
||||
"""
|
||||
Parses a variable Abstract Syntax Tree (AST) node.
|
||||
|
||||
Args:
|
||||
node: The AST node representing a variable.
|
||||
|
||||
Returns:
|
||||
None or Parsed information from the variable node.
|
||||
"""
|
||||
try:
|
||||
funcs = {
|
||||
any_to_str(ast.Constant): lambda x: x.value,
|
||||
|
|
@ -213,7 +688,7 @@ class RepoParser(BaseModel):
|
|||
if hasattr(x.value, "id")
|
||||
else f"{x.attr}",
|
||||
any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
|
||||
any_to_str(ast.Tuple): lambda x: "",
|
||||
any_to_str(ast.Tuple): lambda x: [d.value for d in x.dims],
|
||||
}
|
||||
func = funcs.get(any_to_str(node))
|
||||
if not func:
|
||||
|
|
@ -224,22 +699,42 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _parse_assign(node):
|
||||
"""
|
||||
Parses an assignment Abstract Syntax Tree (AST) node.
|
||||
|
||||
Args:
|
||||
node: The AST node representing an assignment.
|
||||
|
||||
Returns:
|
||||
None or Parsed information from the assignment node.
|
||||
"""
|
||||
return [RepoParser._parse_variable(t) for t in node.targets]
|
||||
|
||||
async def rebuild_class_views(self, path: str | Path = None):
|
||||
"""
|
||||
Executes `pylint` to reconstruct the dot format class view repository file.
|
||||
|
||||
Args:
|
||||
path (str | Path): The path to the target directory or file. Default is None.
|
||||
"""
|
||||
if not path:
|
||||
path = self.base_directory
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
return
|
||||
init_file = path / "__init__.py"
|
||||
if not init_file.exists():
|
||||
raise ValueError("Failed to import module __init__ with error:No module named __init__.")
|
||||
command = f"pyreverse {str(path)} -o dot"
|
||||
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
|
||||
output_dir = path / "__dot__"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir))
|
||||
if result.returncode != 0:
|
||||
raise ValueError(f"{result}")
|
||||
class_view_pathname = path / "classes.dot"
|
||||
class_view_pathname = output_dir / "classes.dot"
|
||||
class_views = await self._parse_classes(class_view_pathname)
|
||||
relationship_views = await self._parse_class_relationships(class_view_pathname)
|
||||
packages_pathname = path / "packages.dot"
|
||||
packages_pathname = output_dir / "packages.dot"
|
||||
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
|
||||
class_views=class_views, relationship_views=relationship_views, path=path
|
||||
)
|
||||
|
|
@ -247,7 +742,17 @@ class RepoParser(BaseModel):
|
|||
packages_pathname.unlink(missing_ok=True)
|
||||
return class_views, relationship_views, package_root
|
||||
|
||||
async def _parse_classes(self, class_view_pathname):
|
||||
@staticmethod
|
||||
async def _parse_classes(class_view_pathname: Path) -> List[DotClassInfo]:
|
||||
"""
|
||||
Parses a dot format class view repository file.
|
||||
|
||||
Args:
|
||||
class_view_pathname (Path): The path to the dot format class view repository file.
|
||||
|
||||
Returns:
|
||||
List[DotClassInfo]: A list of DotClassInfo objects representing the parsed classes.
|
||||
"""
|
||||
class_views = []
|
||||
if not class_view_pathname.exists():
|
||||
return class_views
|
||||
|
|
@ -258,22 +763,38 @@ class RepoParser(BaseModel):
|
|||
if not package_name:
|
||||
continue
|
||||
class_name, members, functions = re.split(r"(?<!\\)\|", info)
|
||||
class_info = ClassInfo(name=class_name)
|
||||
class_info = DotClassInfo(name=class_name)
|
||||
class_info.package = package_name
|
||||
for m in members.split("\n"):
|
||||
if not m:
|
||||
continue
|
||||
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
|
||||
class_info.attributes[member_name] = m
|
||||
attr = DotClassAttribute.parse(m)
|
||||
class_info.attributes[attr.name] = attr
|
||||
for i in attr.compositions:
|
||||
if i not in class_info.compositions:
|
||||
class_info.compositions.append(i)
|
||||
for f in functions.split("\n"):
|
||||
if not f:
|
||||
continue
|
||||
function_name, _ = f.split("(", 1)
|
||||
class_info.methods[function_name] = f
|
||||
method = DotClassMethod.parse(f)
|
||||
class_info.methods[method.name] = method
|
||||
for i in method.aggregations:
|
||||
if i not in class_info.compositions and i not in class_info.aggregations:
|
||||
class_info.aggregations.append(i)
|
||||
class_views.append(class_info)
|
||||
return class_views
|
||||
|
||||
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
|
||||
@staticmethod
|
||||
async def _parse_class_relationships(class_view_pathname: Path) -> List[DotClassRelationship]:
|
||||
"""
|
||||
Parses a dot format class view repository file.
|
||||
|
||||
Args:
|
||||
class_view_pathname (Path): The path to the dot format class view repository file.
|
||||
|
||||
Returns:
|
||||
List[DotClassRelationship]: A list of DotClassRelationship objects representing the parsed class relationships.
|
||||
"""
|
||||
relationship_views = []
|
||||
if not class_view_pathname.exists():
|
||||
return relationship_views
|
||||
|
|
@ -287,7 +808,16 @@ class RepoParser(BaseModel):
|
|||
return relationship_views
|
||||
|
||||
@staticmethod
|
||||
def _split_class_line(line):
|
||||
def _split_class_line(line: str) -> (str, str):
|
||||
"""
|
||||
Parses a dot format line about class info and returns the class name part and class members part.
|
||||
|
||||
Args:
|
||||
line (str): The dot format line containing class information.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: A tuple containing the class name part and class members part.
|
||||
"""
|
||||
part_splitor = '" ['
|
||||
if part_splitor not in line:
|
||||
return None, None
|
||||
|
|
@ -305,14 +835,25 @@ class RepoParser(BaseModel):
|
|||
return class_name, info
|
||||
|
||||
@staticmethod
|
||||
def _split_relationship_line(line):
|
||||
def _split_relationship_line(line: str) -> DotClassRelationship:
|
||||
"""
|
||||
Parses a dot format line about the relationship of two classes and returns 'Generalize', 'Composite',
|
||||
or 'Aggregate'.
|
||||
|
||||
Args:
|
||||
line (str): The dot format line containing relationship information.
|
||||
|
||||
Returns:
|
||||
DotClassRelationship: The object of relationship representing either 'Generalize', 'Composite',
|
||||
or 'Aggregate' relationship.
|
||||
"""
|
||||
splitters = [" -> ", " [", "];"]
|
||||
idxs = []
|
||||
for tag in splitters:
|
||||
if tag not in line:
|
||||
return None
|
||||
idxs.append(line.find(tag))
|
||||
ret = ClassRelationship()
|
||||
ret = DotClassRelationship()
|
||||
ret.src = line[0 : idxs[0]].strip('"')
|
||||
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
|
||||
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
|
||||
|
|
@ -330,7 +871,16 @@ class RepoParser(BaseModel):
|
|||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _get_label(line):
|
||||
def _get_label(line: str) -> str:
|
||||
"""
|
||||
Parses a dot format line and returns the label information.
|
||||
|
||||
Args:
|
||||
line (str): The dot format line containing label information.
|
||||
|
||||
Returns:
|
||||
str: The label information parsed from the line.
|
||||
"""
|
||||
tag = 'label="'
|
||||
if tag not in line:
|
||||
return ""
|
||||
|
|
@ -340,6 +890,15 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
|
||||
"""
|
||||
Creates a mapping table between source code files' paths and module names.
|
||||
|
||||
Args:
|
||||
path (str | Path): The path to the source code files or directory.
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary mapping source code file paths to their corresponding module names.
|
||||
"""
|
||||
mappings = {
|
||||
str(path).replace("/", "."): str(path),
|
||||
}
|
||||
|
|
@ -363,8 +922,21 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _repair_namespaces(
|
||||
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
|
||||
) -> (List[ClassInfo], List[ClassRelationship], str):
|
||||
class_views: List[DotClassInfo], relationship_views: List[DotClassRelationship], path: str | Path
|
||||
) -> (List[DotClassInfo], List[DotClassRelationship], str):
|
||||
"""
|
||||
Augments namespaces to the path-prefixed classes and relationships.
|
||||
|
||||
Args:
|
||||
class_views (List[DotClassInfo]): List of DotClassInfo objects representing class views.
|
||||
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects representing
|
||||
relationships.
|
||||
path (str | Path): The path to the source code files or directory.
|
||||
|
||||
Returns:
|
||||
Tuple[List[DotClassInfo], List[DotClassRelationship], str]: A tuple containing the augmented class views,
|
||||
relationships, and the root path of the package.
|
||||
"""
|
||||
if not class_views:
|
||||
return [], [], ""
|
||||
c = class_views[0]
|
||||
|
|
@ -383,28 +955,49 @@ class RepoParser(BaseModel):
|
|||
|
||||
for c in class_views:
|
||||
c.package = RepoParser._repair_ns(c.package, new_mappings)
|
||||
for i in range(len(relationship_views)):
|
||||
v = relationship_views[i]
|
||||
for _, v in enumerate(relationship_views):
|
||||
v.src = RepoParser._repair_ns(v.src, new_mappings)
|
||||
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
|
||||
relationship_views[i] = v
|
||||
return class_views, relationship_views, root_path
|
||||
return class_views, relationship_views, str(path)[: len(root_path)]
|
||||
|
||||
@staticmethod
|
||||
def _repair_ns(package, mappings):
|
||||
def _repair_ns(package: str, mappings: Dict[str, str]) -> str:
|
||||
"""
|
||||
Replaces the package-prefix with the namespace-prefix.
|
||||
|
||||
Args:
|
||||
package (str): The package to be repaired.
|
||||
mappings (Dict[str, str]): A dictionary mapping source code file paths to their corresponding packages.
|
||||
|
||||
Returns:
|
||||
str: The repaired namespace.
|
||||
"""
|
||||
file_ns = package
|
||||
ix = 0
|
||||
while file_ns != "":
|
||||
if file_ns not in mappings:
|
||||
ix = file_ns.rfind(".")
|
||||
file_ns = file_ns[0:ix]
|
||||
continue
|
||||
break
|
||||
if file_ns == "":
|
||||
return ""
|
||||
internal_ns = package[ix + 1 :]
|
||||
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
|
||||
return ns
|
||||
|
||||
@staticmethod
|
||||
def _find_root(full_key, package) -> str:
|
||||
def _find_root(full_key: str, package: str) -> str:
|
||||
"""
|
||||
Returns the package root path based on the key, which is the full path, and the package information.
|
||||
|
||||
Args:
|
||||
full_key (str): The full key representing the full path.
|
||||
package (str): The package information.
|
||||
|
||||
Returns:
|
||||
str: The package root path.
|
||||
"""
|
||||
left = full_key
|
||||
while left != "":
|
||||
if left in package:
|
||||
|
|
@ -417,5 +1010,14 @@ class RepoParser(BaseModel):
|
|||
return "." + full_key[0:ix]
|
||||
|
||||
|
||||
def is_func(node):
|
||||
def is_func(node) -> bool:
|
||||
"""
|
||||
Returns True if the given node represents a function.
|
||||
|
||||
Args:
|
||||
node: The Abstract Syntax Tree (AST) node.
|
||||
|
||||
Returns:
|
||||
bool: True if the node represents a function, False otherwise.
|
||||
"""
|
||||
return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class Assistant(Role):
|
|||
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
|
||||
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
|
||||
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
|
||||
rsp = await self.llm.aask(prompt, ["You are an action classifier"])
|
||||
rsp = await self.llm.aask(prompt, ["You are an action classifier"], stream=False)
|
||||
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
|
||||
return await self._plan(rsp, last_talk=last_talk)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from typing import Literal, Union
|
|||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions.mi.ask_review import ReviewConst
|
||||
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.mi.write_analysis_code import CheckData, WriteCodeWithTools
|
||||
from metagpt.actions.di.ask_review import ReviewConst
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.di.write_analysis_code import CheckData, WriteCodeWithTools
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.mi.write_analysis_code import DATA_INFO
|
||||
from metagpt.roles import Role
|
||||
|
|
@ -32,9 +32,9 @@ Output a json following the format:
|
|||
"""
|
||||
|
||||
|
||||
class Interpreter(Role):
|
||||
name: str = "Ivy"
|
||||
profile: str = "Interpreter"
|
||||
class DataInterpreter(Role):
|
||||
name: str = "David"
|
||||
profile: str = "DataInterpreter"
|
||||
auto_run: bool = True
|
||||
use_plan: bool = True
|
||||
use_reflection: bool = False
|
||||
|
|
@ -20,7 +20,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
|
@ -32,7 +31,6 @@ from metagpt.actions.summarize_code import SummarizeCode
|
|||
from metagpt.actions.write_code_plan_and_change_an import WriteCodePlanAndChange
|
||||
from metagpt.const import (
|
||||
CODE_PLAN_AND_CHANGE_FILE_REPO,
|
||||
CODE_PLAN_AND_CHANGE_FILENAME,
|
||||
REQUIREMENT_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
|
|
@ -119,10 +117,10 @@ class Engineer(Role):
|
|||
|
||||
dependencies = {coding_context.design_doc.root_relative_path, coding_context.task_doc.root_relative_path}
|
||||
if self.config.inc:
|
||||
dependencies.add(os.path.join(CODE_PLAN_AND_CHANGE_FILE_REPO, CODE_PLAN_AND_CHANGE_FILENAME))
|
||||
dependencies.add(coding_context.code_plan_and_change_doc.root_relative_path)
|
||||
await self.project_repo.srcs.save(
|
||||
filename=coding_context.filename,
|
||||
dependencies=dependencies,
|
||||
dependencies=list(dependencies),
|
||||
content=coding_context.code_doc.content,
|
||||
)
|
||||
msg = Message(
|
||||
|
|
@ -206,7 +204,6 @@ class Engineer(Role):
|
|||
|
||||
async def _act_code_plan_and_change(self):
|
||||
"""Write code plan and change that guides subsequent WriteCode and WriteCodeReview"""
|
||||
logger.info("Writing code plan and change..")
|
||||
node = await self.rc.todo.run()
|
||||
code_plan_and_change = node.instruct_content.model_dump_json()
|
||||
dependencies = {
|
||||
|
|
@ -215,11 +212,12 @@ class Engineer(Role):
|
|||
self.rc.todo.i_context.design_filename,
|
||||
self.rc.todo.i_context.task_filename,
|
||||
}
|
||||
code_plan_and_change_filepath = Path(self.rc.todo.i_context.design_filename)
|
||||
await self.project_repo.docs.code_plan_and_change.save(
|
||||
filename=self.rc.todo.i_context.filename, content=code_plan_and_change, dependencies=dependencies
|
||||
filename=code_plan_and_change_filepath.name, content=code_plan_and_change, dependencies=dependencies
|
||||
)
|
||||
await self.project_repo.resources.code_plan_and_change.save(
|
||||
filename=Path(self.rc.todo.i_context.filename).with_suffix(".md").name,
|
||||
filename=code_plan_and_change_filepath.with_suffix(".md").name,
|
||||
content=node.content,
|
||||
dependencies=dependencies,
|
||||
)
|
||||
|
|
@ -269,15 +267,24 @@ class Engineer(Role):
|
|||
dependencies = {Path(i) for i in await dependency.get(old_code_doc.root_relative_path)}
|
||||
task_doc = None
|
||||
design_doc = None
|
||||
code_plan_and_change_doc = None
|
||||
for i in dependencies:
|
||||
if str(i.parent) == TASK_FILE_REPO:
|
||||
task_doc = await self.project_repo.docs.task.get(i.name)
|
||||
elif str(i.parent) == SYSTEM_DESIGN_FILE_REPO:
|
||||
design_doc = await self.project_repo.docs.system_design.get(i.name)
|
||||
elif str(i.parent) == CODE_PLAN_AND_CHANGE_FILE_REPO:
|
||||
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(i.name)
|
||||
if not task_doc or not design_doc:
|
||||
logger.error(f'Detected source code "{filename}" from an unknown origin.')
|
||||
raise ValueError(f'Detected source code "{filename}" from an unknown origin.')
|
||||
context = CodingContext(filename=filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc)
|
||||
context = CodingContext(
|
||||
filename=filename,
|
||||
design_doc=design_doc,
|
||||
task_doc=task_doc,
|
||||
code_doc=old_code_doc,
|
||||
code_plan_and_change_doc=code_plan_and_change_doc,
|
||||
)
|
||||
return context
|
||||
|
||||
async def _new_coding_doc(self, filename, dependency):
|
||||
|
|
@ -296,6 +303,7 @@ class Engineer(Role):
|
|||
for filename in changed_task_files:
|
||||
design_doc = await self.project_repo.docs.system_design.get(filename)
|
||||
task_doc = await self.project_repo.docs.task.get(filename)
|
||||
code_plan_and_change_doc = await self.project_repo.docs.code_plan_and_change.get(filename)
|
||||
task_list = self._parse_tasks(task_doc)
|
||||
for task_filename in task_list:
|
||||
old_code_doc = await self.project_repo.srcs.get(task_filename)
|
||||
|
|
@ -303,9 +311,18 @@ class Engineer(Role):
|
|||
old_code_doc = Document(
|
||||
root_path=str(self.project_repo.src_relative_path), filename=task_filename, content=""
|
||||
)
|
||||
context = CodingContext(
|
||||
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
|
||||
)
|
||||
if not code_plan_and_change_doc:
|
||||
context = CodingContext(
|
||||
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
|
||||
)
|
||||
else:
|
||||
context = CodingContext(
|
||||
filename=task_filename,
|
||||
design_doc=design_doc,
|
||||
task_doc=task_doc,
|
||||
code_doc=old_code_doc,
|
||||
code_plan_and_change_doc=code_plan_and_change_doc,
|
||||
)
|
||||
coding_doc = Document(
|
||||
root_path=str(self.project_repo.src_relative_path),
|
||||
filename=task_filename,
|
||||
|
|
@ -342,9 +359,17 @@ class Engineer(Role):
|
|||
summarizations[ctx].append(filename)
|
||||
for ctx, filenames in summarizations.items():
|
||||
ctx.codes_filenames = filenames
|
||||
self.summarize_todos.append(SummarizeCode(i_context=ctx, context=self.context, llm=self.llm))
|
||||
new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm)
|
||||
for i, act in enumerate(self.summarize_todos):
|
||||
if act.i_context.task_filename == new_summarize.i_context.task_filename:
|
||||
self.summarize_todos[i] = new_summarize
|
||||
new_summarize = None
|
||||
break
|
||||
if new_summarize:
|
||||
self.summarize_todos.append(new_summarize)
|
||||
if self.summarize_todos:
|
||||
self.set_todo(self.summarize_todos[0])
|
||||
self.summarize_todos.pop(0)
|
||||
|
||||
async def _new_code_plan_and_change_action(self):
|
||||
"""Create a WriteCodePlanAndChange action for subsequent to-do actions."""
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
i = action
|
||||
self._init_action(i)
|
||||
self.actions.append(i)
|
||||
self.states.append(f"{len(self.actions)}. {action}")
|
||||
self.states.append(f"{len(self.actions) - 1}. {action}")
|
||||
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True):
|
||||
"""Set strategy of the Role reacting to observed Message. Variation lies in how
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ from pydantic import (
|
|||
)
|
||||
|
||||
from metagpt.const import (
|
||||
CODE_PLAN_AND_CHANGE_FILENAME,
|
||||
MESSAGE_ROUTE_CAUSE_BY,
|
||||
MESSAGE_ROUTE_FROM,
|
||||
MESSAGE_ROUTE_TO,
|
||||
|
|
@ -47,6 +46,7 @@ from metagpt.const import (
|
|||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import DotClassInfo
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.serialize import (
|
||||
|
|
@ -613,6 +613,7 @@ class CodingContext(BaseContext):
|
|||
design_doc: Optional[Document] = None
|
||||
task_doc: Optional[Document] = None
|
||||
code_doc: Optional[Document] = None
|
||||
code_plan_and_change_doc: Optional[Document] = None
|
||||
|
||||
|
||||
class TestingContext(BaseContext):
|
||||
|
|
@ -667,7 +668,6 @@ class BugFixContext(BaseContext):
|
|||
|
||||
|
||||
class CodePlanAndChangeContext(BaseModel):
|
||||
filename: str = CODE_PLAN_AND_CHANGE_FILENAME
|
||||
requirement: str = ""
|
||||
prd_filename: str = ""
|
||||
design_filename: str = ""
|
||||
|
|
@ -691,54 +691,64 @@ class CodePlanAndChangeContext(BaseModel):
|
|||
|
||||
|
||||
# mermaid class view
|
||||
class ClassMeta(BaseModel):
|
||||
class UMLClassMeta(BaseModel):
|
||||
name: str = ""
|
||||
abstraction: bool = False
|
||||
static: bool = False
|
||||
visibility: str = ""
|
||||
|
||||
@staticmethod
|
||||
def name_to_visibility(name: str) -> str:
|
||||
if name == "__init__":
|
||||
return "+"
|
||||
if name.startswith("__"):
|
||||
return "-"
|
||||
elif name.startswith("_"):
|
||||
return "#"
|
||||
return "+"
|
||||
|
||||
class ClassAttribute(ClassMeta):
|
||||
|
||||
class UMLClassAttribute(UMLClassMeta):
|
||||
value_type: str = ""
|
||||
default_value: str = ""
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + self.visibility
|
||||
if self.value_type:
|
||||
content += self.value_type + " "
|
||||
content += self.name
|
||||
content += self.value_type.replace(" ", "") + " "
|
||||
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
|
||||
content += name
|
||||
if self.default_value:
|
||||
content += "="
|
||||
if self.value_type not in ["str", "string", "String"]:
|
||||
content += self.default_value
|
||||
else:
|
||||
content += '"' + self.default_value.replace('"', "") + '"'
|
||||
if self.abstraction:
|
||||
content += "*"
|
||||
if self.static:
|
||||
content += "$"
|
||||
# if self.abstraction:
|
||||
# content += "*"
|
||||
# if self.static:
|
||||
# content += "$"
|
||||
return content
|
||||
|
||||
|
||||
class ClassMethod(ClassMeta):
|
||||
args: List[ClassAttribute] = Field(default_factory=list)
|
||||
class UMLClassMethod(UMLClassMeta):
|
||||
args: List[UMLClassAttribute] = Field(default_factory=list)
|
||||
return_type: str = ""
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + self.visibility
|
||||
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
|
||||
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
|
||||
content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
|
||||
if self.return_type:
|
||||
content += ":" + self.return_type
|
||||
if self.abstraction:
|
||||
content += "*"
|
||||
if self.static:
|
||||
content += "$"
|
||||
content += " " + self.return_type.replace(" ", "")
|
||||
# if self.abstraction:
|
||||
# content += "*"
|
||||
# if self.static:
|
||||
# content += "$"
|
||||
return content
|
||||
|
||||
|
||||
class ClassView(ClassMeta):
|
||||
attributes: List[ClassAttribute] = Field(default_factory=list)
|
||||
methods: List[ClassMethod] = Field(default_factory=list)
|
||||
class UMLClassView(UMLClassMeta):
|
||||
attributes: List[UMLClassAttribute] = Field(default_factory=list)
|
||||
methods: List[UMLClassMethod] = Field(default_factory=list)
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
|
||||
|
|
@ -748,3 +758,21 @@ class ClassView(ClassMeta):
|
|||
content += v.get_mermaid(align=align + 1) + "\n"
|
||||
content += "".join(["\t" for i in range(align)]) + "}\n"
|
||||
return content
|
||||
|
||||
@classmethod
|
||||
def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
|
||||
visibility = UMLClassView.name_to_visibility(dot_class_info.name)
|
||||
class_view = cls(name=dot_class_info.name, visibility=visibility)
|
||||
for i in dot_class_info.attributes.values():
|
||||
visibility = UMLClassAttribute.name_to_visibility(i.name)
|
||||
attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
|
||||
class_view.attributes.append(attr)
|
||||
for i in dot_class_info.methods.values():
|
||||
visibility = UMLClassMethod.name_to_visibility(i.name)
|
||||
method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
|
||||
for j in i.args:
|
||||
arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
|
||||
method.args.append(arg)
|
||||
method.return_type = i.return_args.type_
|
||||
class_view.methods.append(method)
|
||||
return class_view
|
||||
|
|
|
|||
|
|
@ -2,14 +2,11 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
|
||||
from metagpt.context import Context
|
||||
from metagpt.const import CONFIG_ROOT
|
||||
from metagpt.utils.project_repo import ProjectRepo
|
||||
|
||||
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
|
||||
|
|
@ -30,6 +27,8 @@ def generate_repo(
|
|||
recover_path=None,
|
||||
) -> ProjectRepo:
|
||||
"""Run the startup logic. Can be called from CLI or other Python scripts."""
|
||||
from metagpt.config2 import config
|
||||
from metagpt.context import Context
|
||||
from metagpt.roles import (
|
||||
Architect,
|
||||
Engineer,
|
||||
|
|
@ -122,7 +121,17 @@ def startup(
|
|||
)
|
||||
|
||||
|
||||
def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
|
||||
DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
|
||||
# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
|
||||
llm:
|
||||
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
|
||||
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
|
||||
base_url: "https://api.openai.com/v1" # or forward url / other llm url
|
||||
api_key: "YOUR_API_KEY"
|
||||
"""
|
||||
|
||||
|
||||
def copy_config_to():
|
||||
"""Initialize the configuration file for MetaGPT."""
|
||||
target_path = CONFIG_ROOT / "config2.yaml"
|
||||
|
||||
|
|
@ -136,7 +145,7 @@ def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
|
|||
print(f"Existing configuration file backed up at {backup_path}")
|
||||
|
||||
# 复制文件
|
||||
shutil.copy(str(config_path), target_path)
|
||||
target_path.write_text(DEFAULT_CONFIG, encoding="utf-8")
|
||||
print(f"Configuration file initialized at {target_path}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import json
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions.mi.ask_review import AskReview, ReviewConst
|
||||
from metagpt.actions.mi.write_plan import (
|
||||
from metagpt.actions.di.ask_review import AskReview, ReviewConst
|
||||
from metagpt.actions.di.write_plan import (
|
||||
WritePlan,
|
||||
precheck_update_plan_from_rsp,
|
||||
update_plan_from_rsp,
|
||||
|
|
|
|||
|
|
@ -49,8 +49,8 @@ class TOTSolver(BaseSolver):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class InterpreterSolver(BaseSolver):
|
||||
"""InterpreterSolver: Write&Run code in the graph"""
|
||||
class DataInterpreterSolver(BaseSolver):
|
||||
"""DataInterpreterSolver: Write&Run code in the graph"""
|
||||
|
||||
async def solve(self):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -23,10 +23,10 @@ import platform
|
|||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
from typing import Any, Callable, List, Literal, Tuple, Union
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
|
|
@ -423,23 +423,109 @@ def is_send_to(message: "Message", addresses: set):
|
|||
def any_to_name(val):
|
||||
"""
|
||||
Convert a value to its name by extracting the last part of the dotted path.
|
||||
|
||||
:param val: The value to convert.
|
||||
|
||||
:return: The name of the value.
|
||||
"""
|
||||
return any_to_str(val).split(".")[-1]
|
||||
|
||||
|
||||
def concat_namespace(*args) -> str:
|
||||
return ":".join(str(value) for value in args)
|
||||
def concat_namespace(*args, delimiter: str = ":") -> str:
|
||||
"""Concatenate fields to create a unique namespace prefix.
|
||||
|
||||
Example:
|
||||
>>> concat_namespace('prefix', 'field1', 'field2', delimiter=":")
|
||||
'prefix:field1:field2'
|
||||
"""
|
||||
return delimiter.join(str(value) for value in args)
|
||||
|
||||
|
||||
def split_namespace(ns_class_name: str) -> List[str]:
|
||||
return ns_class_name.split(":")
|
||||
def split_namespace(ns_class_name: str, delimiter: str = ":", maxsplit: int = 1) -> List[str]:
|
||||
"""Split a namespace-prefixed name into its namespace-prefix and name parts.
|
||||
|
||||
Example:
|
||||
>>> split_namespace('prefix:classname')
|
||||
['prefix', 'classname']
|
||||
|
||||
>>> split_namespace('prefix:module:class', delimiter=":", maxsplit=2)
|
||||
['prefix', 'module', 'class']
|
||||
"""
|
||||
return ns_class_name.split(delimiter, maxsplit=maxsplit)
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
|
||||
def auto_namespace(name: str, delimiter: str = ":") -> str:
|
||||
"""Automatically handle namespace-prefixed names.
|
||||
|
||||
If the input name is empty, returns a default namespace prefix and name.
|
||||
If the input name is not namespace-prefixed, adds a default namespace prefix.
|
||||
Otherwise, returns the input name unchanged.
|
||||
|
||||
Example:
|
||||
>>> auto_namespace('classname')
|
||||
'?:classname'
|
||||
|
||||
>>> auto_namespace('prefix:classname')
|
||||
'prefix:classname'
|
||||
|
||||
>>> auto_namespace('')
|
||||
'?:?'
|
||||
|
||||
>>> auto_namespace('?:custom')
|
||||
'?:custom'
|
||||
"""
|
||||
if not name:
|
||||
return f"?{delimiter}?"
|
||||
v = split_namespace(name, delimiter=delimiter)
|
||||
if len(v) < 2:
|
||||
return f"?{delimiter}{name}"
|
||||
return name
|
||||
|
||||
|
||||
def add_affix(text: str, affix: Literal["brace", "url", "none"] = "brace"):
|
||||
"""Add affix to encapsulate data.
|
||||
|
||||
Example:
|
||||
>>> add_affix("data", affix="brace")
|
||||
'{data}'
|
||||
|
||||
>>> add_affix("example.com", affix="url")
|
||||
'%7Bexample.com%7D'
|
||||
|
||||
>>> add_affix("text", affix="none")
|
||||
'text'
|
||||
"""
|
||||
mappings = {
|
||||
"brace": lambda x: "{" + x + "}",
|
||||
"url": lambda x: quote("{" + x + "}"),
|
||||
}
|
||||
encoder = mappings.get(affix, lambda x: x)
|
||||
return encoder(text)
|
||||
|
||||
|
||||
def remove_affix(text, affix: Literal["brace", "url", "none"] = "brace"):
|
||||
"""Remove affix to extract encapsulated data.
|
||||
|
||||
Args:
|
||||
text (str): The input text with affix to be removed.
|
||||
affix (str, optional): The type of affix used. Defaults to "brace".
|
||||
Supported affix types: "brace" for removing curly braces, "url" for URL decoding within curly braces.
|
||||
|
||||
Returns:
|
||||
str: The text with affix removed.
|
||||
|
||||
Example:
|
||||
>>> remove_affix('{data}', affix="brace")
|
||||
'data'
|
||||
|
||||
>>> remove_affix('%7Bexample.com%7D', affix="url")
|
||||
'example.com'
|
||||
|
||||
>>> remove_affix('text', affix="none")
|
||||
'text'
|
||||
"""
|
||||
mappings = {"brace": lambda x: x[1:-1], "url": lambda x: unquote(x)[1:-1]}
|
||||
decoder = mappings.get(affix, lambda x: x)
|
||||
return decoder(text)
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> Callable[["RetryCallState"], None]:
|
||||
"""
|
||||
Generates a logging function to be used after a call is retried.
|
||||
|
||||
|
|
@ -626,6 +712,54 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
return files
|
||||
|
||||
|
||||
def parse_json_code_block(markdown_text: str) -> List[str]:
|
||||
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
|
||||
return [v.strip() for v in json_blocks]
|
||||
|
||||
|
||||
def remove_white_spaces(v: str) -> str:
|
||||
return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)
|
||||
|
||||
|
||||
async def aread_bin(filename: str | Path) -> bytes:
|
||||
"""Read binary file asynchronously.
|
||||
|
||||
Args:
|
||||
filename (Union[str, Path]): The name or path of the file to be read.
|
||||
|
||||
Returns:
|
||||
bytes: The content of the file as bytes.
|
||||
|
||||
Example:
|
||||
>>> content = await aread_bin('example.txt')
|
||||
b'This is the content of the file.'
|
||||
|
||||
>>> content = await aread_bin(Path('example.txt'))
|
||||
b'This is the content of the file.'
|
||||
"""
|
||||
async with aiofiles.open(str(filename), mode="rb") as reader:
|
||||
content = await reader.read()
|
||||
return content
|
||||
|
||||
|
||||
async def awrite_bin(filename: str | Path, data: bytes):
|
||||
"""Write binary file asynchronously.
|
||||
|
||||
Args:
|
||||
filename (Union[str, Path]): The name or path of the file to be written.
|
||||
data (bytes): The binary data to be written to the file.
|
||||
|
||||
Example:
|
||||
>>> await awrite_bin('output.bin', b'This is binary data.')
|
||||
|
||||
>>> await awrite_bin(Path('output.bin'), b'Another set of binary data.')
|
||||
"""
|
||||
pathname = Path(filename)
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.open(str(pathname), mode="wb") as writer:
|
||||
await writer.write(data)
|
||||
|
||||
|
||||
def is_coroutine_func(func: Callable) -> bool:
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
|
@ -689,3 +823,14 @@ def process_message(messages: Union[str, Message, list[dict], list[Message], lis
|
|||
else:
|
||||
raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
|
||||
return processed_messages
|
||||
|
||||
|
||||
def log_and_reraise(retry_state: RetryCallState):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
"""
|
||||
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
|
||||
See FAQ 5.8
|
||||
"""
|
||||
)
|
||||
raise retry_state.outcome.exception()
|
||||
|
|
|
|||
|
|
@ -6,12 +6,13 @@
|
|||
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import NamedTuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.token_counter import TOKEN_COSTS
|
||||
from metagpt.utils.token_counter import FIREWORKS_GRADE_TOKEN_COSTS, TOKEN_COSTS
|
||||
|
||||
|
||||
class Costs(NamedTuple):
|
||||
|
|
@ -29,6 +30,7 @@ class CostManager(BaseModel):
|
|||
total_budget: float = 0
|
||||
max_budget: float = 10.0
|
||||
total_cost: float = 0
|
||||
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
|
|
@ -39,14 +41,17 @@ class CostManager(BaseModel):
|
|||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
if prompt_tokens + completion_tokens == 0 or not model:
|
||||
return
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
if model not in TOKEN_COSTS:
|
||||
if model not in self.token_costs:
|
||||
logger.warning(f"Model {model} not found in TOKEN_COSTS.")
|
||||
return
|
||||
|
||||
cost = (
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
prompt_tokens * self.token_costs[model]["prompt"]
|
||||
+ completion_tokens * self.token_costs[model]["completion"]
|
||||
) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
|
|
@ -101,3 +106,44 @@ class TokenCostManager(CostManager):
|
|||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
|
||||
|
||||
|
||||
class FireworksCostManager(CostManager):
|
||||
def model_grade_token_costs(self, model: str) -> dict[str, float]:
|
||||
def _get_model_size(model: str) -> float:
|
||||
size = re.findall(".*-([0-9.]+)b", model)
|
||||
size = float(size[0]) if len(size) > 0 else -1
|
||||
return size
|
||||
|
||||
if "mixtral-8x7b" in model:
|
||||
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["mixtral-8x7b"]
|
||||
else:
|
||||
model_size = _get_model_size(model)
|
||||
if 0 < model_size <= 16:
|
||||
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["16"]
|
||||
elif 16 < model_size <= 80:
|
||||
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["80"]
|
||||
else:
|
||||
token_costs = FIREWORKS_GRADE_TOKEN_COSTS["-1"]
|
||||
return token_costs
|
||||
|
||||
def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
|
||||
"""
|
||||
Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
|
||||
token_costs = self.model_grade_token_costs(model)
|
||||
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.4f}"
|
||||
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@
|
|||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : di_graph_repository.py
|
||||
@Desc : Graph repository based on DiGraph
|
||||
@Desc : Graph repository based on DiGraph.
|
||||
This script defines a graph repository class based on a directed graph (DiGraph), providing functionalities
|
||||
specific to handling directed relationships between entities.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -19,20 +21,41 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
|
|||
|
||||
|
||||
class DiGraphRepository(GraphRepository):
|
||||
"""Graph repository based on DiGraph."""
|
||||
|
||||
def __init__(self, name: str, **kwargs):
|
||||
super().__init__(name=name, **kwargs)
|
||||
self._repo = networkx.DiGraph()
|
||||
|
||||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
"""Insert a new triple into the directed graph repository.
|
||||
|
||||
Args:
|
||||
subject (str): The subject of the triple.
|
||||
predicate (str): The predicate describing the relationship.
|
||||
object_ (str): The object of the triple.
|
||||
|
||||
Example:
|
||||
await my_di_graph_repo.insert(subject="Node1", predicate="connects_to", object_="Node2")
|
||||
# Adds a directed relationship: Node1 connects_to Node2
|
||||
"""
|
||||
self._repo.add_edge(subject, object_, predicate=predicate)
|
||||
|
||||
async def upsert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
async def update(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
|
||||
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
|
||||
"""Retrieve triples from the directed graph repository based on specified criteria.
|
||||
|
||||
Args:
|
||||
subject (str, optional): The subject of the triple to filter by.
|
||||
predicate (str, optional): The predicate describing the relationship to filter by.
|
||||
object_ (str, optional): The object of the triple to filter by.
|
||||
|
||||
Returns:
|
||||
List[SPO]: A list of SPO objects representing the selected triples.
|
||||
|
||||
Example:
|
||||
selected_triples = await my_di_graph_repo.select(subject="Node1", predicate="connects_to")
|
||||
# Retrieves directed relationships where Node1 is the subject and the predicate is 'connects_to'.
|
||||
"""
|
||||
result = []
|
||||
for s, o, p in self._repo.edges(data="predicate"):
|
||||
if subject and subject != s:
|
||||
|
|
@ -44,12 +67,41 @@ class DiGraphRepository(GraphRepository):
|
|||
result.append(SPO(subject=s, predicate=p, object_=o))
|
||||
return result
|
||||
|
||||
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
|
||||
"""Delete triples from the directed graph repository based on specified criteria.
|
||||
|
||||
Args:
|
||||
subject (str, optional): The subject of the triple to filter by.
|
||||
predicate (str, optional): The predicate describing the relationship to filter by.
|
||||
object_ (str, optional): The object of the triple to filter by.
|
||||
|
||||
Returns:
|
||||
int: The number of triples deleted from the repository.
|
||||
|
||||
Example:
|
||||
deleted_count = await my_di_graph_repo.delete(subject="Node1", predicate="connects_to")
|
||||
# Deletes directed relationships where Node1 is the subject and the predicate is 'connects_to'.
|
||||
"""
|
||||
rows = await self.select(subject=subject, predicate=predicate, object_=object_)
|
||||
if not rows:
|
||||
return 0
|
||||
for r in rows:
|
||||
self._repo.remove_edge(r.subject, r.object_)
|
||||
return len(rows)
|
||||
|
||||
def json(self) -> str:
|
||||
"""Convert the directed graph repository to a JSON-formatted string."""
|
||||
m = networkx.node_link_data(self._repo)
|
||||
data = json.dumps(m)
|
||||
return data
|
||||
|
||||
async def save(self, path: str | Path = None):
|
||||
"""Save the directed graph repository to a JSON file.
|
||||
|
||||
Args:
|
||||
path (Union[str, Path], optional): The directory path where the JSON file will be saved.
|
||||
If not provided, the default path is taken from the 'root' key in the keyword arguments.
|
||||
"""
|
||||
data = self.json()
|
||||
path = path or self._kwargs.get("root")
|
||||
if not path.exists():
|
||||
|
|
@ -58,12 +110,21 @@ class DiGraphRepository(GraphRepository):
|
|||
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
|
||||
|
||||
async def load(self, pathname: str | Path):
|
||||
"""Load a directed graph repository from a JSON file."""
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
m = json.loads(data)
|
||||
self._repo = networkx.node_link_graph(m)
|
||||
|
||||
@staticmethod
|
||||
async def load_from(pathname: str | Path) -> GraphRepository:
|
||||
"""Create and load a directed graph repository from a JSON file.
|
||||
|
||||
Args:
|
||||
pathname (Union[str, Path]): The path to the JSON file to be loaded.
|
||||
|
||||
Returns:
|
||||
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
|
||||
"""
|
||||
pathname = Path(pathname)
|
||||
name = pathname.with_suffix("").name
|
||||
root = pathname.parent
|
||||
|
|
@ -74,9 +135,16 @@ class DiGraphRepository(GraphRepository):
|
|||
|
||||
@property
|
||||
def root(self) -> str:
|
||||
"""Return the root directory path for the graph repository files."""
|
||||
return self._kwargs.get("root")
|
||||
|
||||
@property
|
||||
def pathname(self) -> Path:
|
||||
"""Return the path and filename to the graph repository file."""
|
||||
p = Path(self.root) / self.name
|
||||
return p.with_suffix(".json")
|
||||
|
||||
@property
|
||||
def repo(self):
|
||||
"""Get the underlying directed graph repository."""
|
||||
return self._repo
|
||||
|
|
|
|||
|
|
@ -4,21 +4,28 @@
|
|||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : graph_repository.py
|
||||
@Desc : Superclass for graph repository.
|
||||
@Desc : Superclass for graph repository. This script defines a superclass for a graph repository, providing a
|
||||
foundation for specific implementations.
|
||||
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace
|
||||
from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace, split_namespace
|
||||
|
||||
|
||||
class GraphKeyword:
|
||||
"""Basic words for a Graph database.
|
||||
|
||||
This class defines a set of basic words commonly used in the context of a Graph database.
|
||||
"""
|
||||
|
||||
IS = "is"
|
||||
OF = "Of"
|
||||
ON = "On"
|
||||
|
|
@ -28,51 +35,149 @@ class GraphKeyword:
|
|||
SOURCE_CODE = "source_code"
|
||||
NULL = "<null>"
|
||||
GLOBAL_VARIABLE = "global_variable"
|
||||
CLASS_FUNCTION = "class_function"
|
||||
CLASS_METHOD = "class_method"
|
||||
CLASS_PROPERTY = "class_property"
|
||||
HAS_CLASS_FUNCTION = "has_class_function"
|
||||
HAS_CLASS_METHOD = "has_class_method"
|
||||
HAS_CLASS_PROPERTY = "has_class_property"
|
||||
HAS_CLASS = "has_class"
|
||||
HAS_DETAIL = "has_detail"
|
||||
HAS_PAGE_INFO = "has_page_info"
|
||||
HAS_CLASS_VIEW = "has_class_view"
|
||||
HAS_SEQUENCE_VIEW = "has_sequence_view"
|
||||
HAS_ARGS_DESC = "has_args_desc"
|
||||
HAS_TYPE_DESC = "has_type_desc"
|
||||
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
|
||||
HAS_CLASS_USE_CASE = "has_class_use_case"
|
||||
IS_COMPOSITE_OF = "is_composite_of"
|
||||
IS_AGGREGATE_OF = "is_aggregate_of"
|
||||
HAS_PARTICIPANT = "has_participant"
|
||||
|
||||
|
||||
class SPO(BaseModel):
|
||||
"""Graph repository record type.
|
||||
|
||||
This class represents a record in a graph repository with three components:
|
||||
- Subject: The subject of the triple.
|
||||
- Predicate: The predicate describing the relationship between the subject and the object.
|
||||
- Object: The object of the triple.
|
||||
|
||||
Attributes:
|
||||
subject (str): The subject of the triple.
|
||||
predicate (str): The predicate describing the relationship.
|
||||
object_ (str): The object of the triple.
|
||||
|
||||
Example:
|
||||
spo_record = SPO(subject="Node1", predicate="connects_to", object_="Node2")
|
||||
# Represents a triple: Node1 connects_to Node2
|
||||
"""
|
||||
|
||||
subject: str
|
||||
predicate: str
|
||||
object_: str
|
||||
|
||||
|
||||
class GraphRepository(ABC):
|
||||
"""Abstract base class for a Graph Repository.
|
||||
|
||||
This class defines the interface for a graph repository, providing methods for inserting, selecting,
|
||||
deleting, and saving graph data. Concrete implementations of this class must provide functionality
|
||||
for these operations.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, **kwargs):
|
||||
self._repo_name = name
|
||||
self._kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
async def insert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
"""Insert a new triple into the graph repository.
|
||||
|
||||
@abstractmethod
|
||||
async def upsert(self, subject: str, predicate: str, object_: str):
|
||||
pass
|
||||
Args:
|
||||
subject (str): The subject of the triple.
|
||||
predicate (str): The predicate describing the relationship.
|
||||
object_ (str): The object of the triple.
|
||||
|
||||
@abstractmethod
|
||||
async def update(self, subject: str, predicate: str, object_: str):
|
||||
Example:
|
||||
await my_repository.insert(subject="Node1", predicate="connects_to", object_="Node2")
|
||||
# Inserts a triple: Node1 connects_to Node2 into the graph repository.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
|
||||
"""Retrieve triples from the graph repository based on specified criteria.
|
||||
|
||||
Args:
|
||||
subject (str, optional): The subject of the triple to filter by.
|
||||
predicate (str, optional): The predicate describing the relationship to filter by.
|
||||
object_ (str, optional): The object of the triple to filter by.
|
||||
|
||||
Returns:
|
||||
List[SPO]: A list of SPO objects representing the selected triples.
|
||||
|
||||
Example:
|
||||
selected_triples = await my_repository.select(subject="Node1", predicate="connects_to")
|
||||
# Retrieves triples where Node1 is the subject and the predicate is 'connects_to'.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
|
||||
"""Delete triples from the graph repository based on specified criteria.
|
||||
|
||||
Args:
|
||||
subject (str, optional): The subject of the triple to filter by.
|
||||
predicate (str, optional): The predicate describing the relationship to filter by.
|
||||
object_ (str, optional): The object of the triple to filter by.
|
||||
|
||||
Returns:
|
||||
int: The number of triples deleted from the repository.
|
||||
|
||||
Example:
|
||||
deleted_count = await my_repository.delete(subject="Node1", predicate="connects_to")
|
||||
# Deletes triples where Node1 is the subject and the predicate is 'connects_to'.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def save(self):
|
||||
"""Save any changes made to the graph repository.
|
||||
|
||||
Example:
|
||||
await my_repository.save()
|
||||
# Persists any changes made to the graph repository.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get the name of the graph repository."""
|
||||
return self._repo_name
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
|
||||
"""Insert information of RepoFileInfo into the specified graph repository.
|
||||
|
||||
This function updates the provided graph repository with information from the given RepoFileInfo object.
|
||||
The function inserts triples related to various dimensions such as file type, class, class method, function,
|
||||
global variable, and page info.
|
||||
|
||||
Triple Patterns:
|
||||
- (?, is, [file type])
|
||||
- (?, has class, ?)
|
||||
- (?, is, [class])
|
||||
- (?, has class method, ?)
|
||||
- (?, has function, ?)
|
||||
- (?, is, [function])
|
||||
- (?, is, global variable)
|
||||
- (?, has page info, ?)
|
||||
|
||||
Args:
|
||||
graph_db (GraphRepository): The graph repository object to be updated.
|
||||
file_info (RepoFileInfo): The RepoFileInfo object containing information to be inserted.
|
||||
|
||||
Example:
|
||||
await update_graph_db_with_file_info(my_graph_repo, my_file_info)
|
||||
# Updates 'my_graph_repo' with information from 'my_file_info'.
|
||||
"""
|
||||
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
|
||||
file_types = {".py": "python", ".js": "javascript"}
|
||||
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
|
||||
|
|
@ -95,13 +200,13 @@ class GraphRepository(ABC):
|
|||
for fn in methods:
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name),
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
predicate=GraphKeyword.HAS_CLASS_METHOD,
|
||||
object_=concat_namespace(file_info.file, class_name, fn),
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
object_=GraphKeyword.CLASS_METHOD,
|
||||
)
|
||||
for f in file_info.functions:
|
||||
# file -> function
|
||||
|
|
@ -133,7 +238,34 @@ class GraphRepository(ABC):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
|
||||
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
|
||||
"""Insert dot format class information into the specified graph repository.
|
||||
|
||||
This function updates the provided graph repository with class information from the given list of DotClassInfo objects.
|
||||
The function inserts triples related to various aspects of class views, including source code, file type, class,
|
||||
class property, class detail, method, composition, and aggregation.
|
||||
|
||||
Triple Patterns:
|
||||
- (?, is, source code)
|
||||
- (?, is, file type)
|
||||
- (?, has class, ?)
|
||||
- (?, is, class)
|
||||
- (?, has class property, ?)
|
||||
- (?, is, class property)
|
||||
- (?, has detail, ?)
|
||||
- (?, has method, ?)
|
||||
- (?, is composite of, ?)
|
||||
- (?, is aggregate of, ?)
|
||||
|
||||
Args:
|
||||
graph_db (GraphRepository): The graph repository object to be updated.
|
||||
class_views (List[DotClassInfo]): List of DotClassInfo objects containing class information to be inserted.
|
||||
|
||||
|
||||
Example:
|
||||
await update_graph_db_with_class_views(my_graph_repo, [class_info1, class_info2])
|
||||
# Updates 'my_graph_repo' with class information from the provided list of DotClassInfo objects.
|
||||
"""
|
||||
for c in class_views:
|
||||
filename, _ = c.package.split(":", 1)
|
||||
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
|
||||
|
|
@ -146,6 +278,7 @@ class GraphRepository(ABC):
|
|||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS,
|
||||
)
|
||||
await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
|
||||
for vn, vt in c.attributes.items():
|
||||
# class -> property
|
||||
await graph_db.insert(
|
||||
|
|
@ -160,33 +293,61 @@ class GraphRepository(ABC):
|
|||
object_=GraphKeyword.CLASS_PROPERTY,
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
|
||||
subject=concat_namespace(c.package, vn),
|
||||
predicate=GraphKeyword.HAS_DETAIL,
|
||||
object_=vt.model_dump_json(),
|
||||
)
|
||||
for fn, desc in c.methods.items():
|
||||
if "</I>" in desc and "<I>" not in desc:
|
||||
logger.error(desc)
|
||||
for fn, ft in c.methods.items():
|
||||
# class -> function
|
||||
await graph_db.insert(
|
||||
subject=c.package,
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
predicate=GraphKeyword.HAS_CLASS_METHOD,
|
||||
object_=concat_namespace(c.package, fn),
|
||||
)
|
||||
# function detail
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
object_=GraphKeyword.CLASS_METHOD,
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, fn),
|
||||
predicate=GraphKeyword.HAS_ARGS_DESC,
|
||||
object_=desc,
|
||||
predicate=GraphKeyword.HAS_DETAIL,
|
||||
object_=ft.model_dump_json(),
|
||||
)
|
||||
for i in c.compositions:
|
||||
await graph_db.insert(
|
||||
subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
|
||||
)
|
||||
for i in c.aggregations:
|
||||
await graph_db.insert(
|
||||
subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_class_relationship_views(
|
||||
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
|
||||
graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
|
||||
):
|
||||
"""Insert class relationships and labels into the specified graph repository.
|
||||
|
||||
This function updates the provided graph repository with class relationship information from the given list
|
||||
of DotClassRelationship objects. The function inserts triples representing relationships and labels between
|
||||
classes.
|
||||
|
||||
Triple Patterns:
|
||||
- (?, is relationship of, ?)
|
||||
- (?, is relationship on, ?)
|
||||
|
||||
Args:
|
||||
graph_db (GraphRepository): The graph repository object to be updated.
|
||||
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects containing
|
||||
class relationship information to be inserted.
|
||||
|
||||
Example:
|
||||
await update_graph_db_with_class_relationship_views(my_graph_repo, [relationship1, relationship2])
|
||||
# Updates 'my_graph_repo' with class relationship information from the provided list of DotClassRelationship objects.
|
||||
|
||||
"""
|
||||
for r in relationship_views:
|
||||
await graph_db.insert(
|
||||
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
|
||||
|
|
@ -198,3 +359,32 @@ class GraphRepository(ABC):
|
|||
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
|
||||
object_=concat_namespace(r.dest, r.label),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def rebuild_composition_relationship(graph_db: "GraphRepository"):
|
||||
"""Append namespace-prefixed information to relationship SPO (Subject-Predicate-Object) objects in the graph
|
||||
repository.
|
||||
|
||||
This function updates the provided graph repository by appending namespace-prefixed information to existing
|
||||
relationship SPO objects.
|
||||
|
||||
Args:
|
||||
graph_db (GraphRepository): The graph repository object to be updated.
|
||||
"""
|
||||
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
mapping = defaultdict(list)
|
||||
for c in classes:
|
||||
name = split_namespace(c.subject)[-1]
|
||||
mapping[name].append(c.subject)
|
||||
|
||||
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
|
||||
for r in rows:
|
||||
ns, class_ = split_namespace(r.object_)
|
||||
if ns != "?":
|
||||
continue
|
||||
val = mapping[class_]
|
||||
if len(val) != 1:
|
||||
continue
|
||||
ns_name = val[0]
|
||||
await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
|
||||
await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ from metagpt.const import (
|
|||
TASK_PDF_FILE_REPO,
|
||||
TEST_CODES_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
VISUAL_GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
|
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
|
|||
code_summary: FileRepository
|
||||
sd_output: FileRepository
|
||||
code_plan_and_change: FileRepository
|
||||
graph_repo: FileRepository
|
||||
|
||||
def __init__(self, git_repo):
|
||||
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
|
||||
|
|
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
|
|||
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
|
||||
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
|
||||
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
|
||||
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
|
||||
|
||||
|
||||
class ProjectRepo(FileRepository):
|
||||
|
|
@ -133,6 +136,7 @@ class ProjectRepo(FileRepository):
|
|||
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
|
||||
if not code_files:
|
||||
return False
|
||||
return bool(code_files)
|
||||
|
||||
def with_src_path(self, path: str | Path) -> ProjectRepo:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
|
|||
logger.info(f"repair_json_format: {'}]'}")
|
||||
elif output.startswith("{") and output.endswith("]"):
|
||||
output = output[:-1] + "}"
|
||||
|
||||
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
|
||||
arr = output.split("\n")
|
||||
new_arr = []
|
||||
|
|
@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
|
|||
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
|
||||
# problem, `"""` or `'''` without `,`
|
||||
new_line = f",{line}"
|
||||
elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
|
||||
# backslash problem like \" in the output
|
||||
char = rline[col_no - 1]
|
||||
nearest_char_idx = rline[col_no:].find(char)
|
||||
new_line = (
|
||||
rline[: col_no - 1]
|
||||
+ "\\"
|
||||
+ rline[col_no - 1 : col_no + nearest_char_idx]
|
||||
+ "\\"
|
||||
+ rline[col_no + nearest_char_idx :]
|
||||
)
|
||||
elif '",' not in line and "," not in line and '"' not in line:
|
||||
new_line = f'{line}",'
|
||||
elif not line.endswith(","):
|
||||
|
|
|
|||
|
|
@ -35,9 +35,111 @@ TOKEN_COSTS = {
|
|||
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
|
||||
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
|
||||
"moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012}, # prompt + completion tokens=0.012¥/k-tokens
|
||||
"moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
|
||||
"moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
|
||||
"open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025},
|
||||
"open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007},
|
||||
"mistral-small-latest": {"prompt": 0.002, "completion": 0.006},
|
||||
"mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081},
|
||||
"mistral-large-latest": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024},
|
||||
"claude-2.0": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
|
||||
}
|
||||
|
||||
|
||||
"""
|
||||
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
|
||||
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
|
||||
"""
|
||||
QIANFAN_MODEL_TOKEN_COSTS = {
|
||||
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
|
||||
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
|
||||
"ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
|
||||
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
|
||||
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
|
||||
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
|
||||
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
QIANFAN_ENDPOINT_TOKEN_COSTS = {
|
||||
"completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
|
||||
"ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
|
||||
"completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
|
||||
"eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
|
||||
"ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
|
||||
"ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
|
||||
"bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
|
||||
"llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
|
||||
"llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
|
||||
"llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
|
||||
"chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
|
||||
"aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
|
||||
"mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
|
||||
"sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
|
||||
"codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
|
||||
"xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
|
||||
"qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
|
||||
"qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
|
||||
"qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
|
||||
"chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
|
||||
"yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
|
||||
}
|
||||
|
||||
"""
|
||||
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
|
||||
Different model has different detail page. Attention, some model are free for a limited time.
|
||||
"""
|
||||
DASHSCOPE_TOKEN_COSTS = {
|
||||
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
|
||||
"qwen-max": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
|
||||
"baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
|
||||
"baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
|
||||
"chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
|
||||
"ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free
|
||||
"dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
"chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
|
||||
"billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
|
||||
}
|
||||
|
||||
|
||||
FIREWORKS_GRADE_TOKEN_COSTS = {
|
||||
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
|
||||
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
|
||||
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
|
||||
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
|
||||
}
|
||||
|
||||
TOKEN_MAX = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
|
|
@ -61,6 +163,19 @@ TOKEN_MAX = {
|
|||
"glm-3-turbo": 128000,
|
||||
"glm-4": 128000,
|
||||
"gemini-pro": 32768,
|
||||
"moonshot-v1-8k": 8192,
|
||||
"moonshot-v1-32k": 32768,
|
||||
"moonshot-v1-128k": 128000,
|
||||
"open-mistral-7b": 8192,
|
||||
"open-mixtral-8x7b": 32768,
|
||||
"mistral-small-latest": 32768,
|
||||
"mistral-medium-latest": 32768,
|
||||
"mistral-large-latest": 32768,
|
||||
"claude-instant-1.2": 100000,
|
||||
"claude-2.0": 100000,
|
||||
"claude-2.1": 200000,
|
||||
"claude-3-sonnet-20240229": 200000,
|
||||
"claude-3-opus-20240229": 200000,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
162
metagpt/utils/visual_graph_repo.py
Normal file
162
metagpt/utils/visual_graph_repo.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19
|
||||
@Author : mashenquan
|
||||
@File : visualize_graph.py
|
||||
@Desc : Visualization tool to visualize the class diagrams or sequence diagrams of the graph repository.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
|
||||
from metagpt.schema import UMLClassView
|
||||
from metagpt.utils.common import split_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class _VisualClassView(BaseModel):
|
||||
"""Protected class used by VisualGraphRepo internally.
|
||||
|
||||
Attributes:
|
||||
package (str): The package associated with the class.
|
||||
uml (Optional[UMLClassView]): Optional UMLClassView associated with the class.
|
||||
generalizations (List[str]): List of generalizations for the class.
|
||||
compositions (List[str]): List of compositions for the class.
|
||||
aggregations (List[str]): List of aggregations for the class.
|
||||
"""
|
||||
|
||||
package: str
|
||||
uml: Optional[UMLClassView] = None
|
||||
generalizations: List[str] = Field(default_factory=list)
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
aggregations: List[str] = Field(default_factory=list)
|
||||
|
||||
def get_mermaid(self, align: int = 1) -> str:
|
||||
"""Creates a Markdown Mermaid class diagram text.
|
||||
|
||||
Args:
|
||||
align (int): Indent count used for alignment.
|
||||
|
||||
Returns:
|
||||
str: The Markdown text representing the Mermaid class diagram.
|
||||
"""
|
||||
if not self.uml:
|
||||
return ""
|
||||
prefix = "\t" * align
|
||||
|
||||
mermaid_txt = self.uml.get_mermaid(align=align)
|
||||
for i in self.generalizations:
|
||||
mermaid_txt += f"{prefix}{i} <|-- {self.name}\n"
|
||||
for i in self.compositions:
|
||||
mermaid_txt += f"{prefix}{i} *-- {self.name}\n"
|
||||
for i in self.aggregations:
|
||||
mermaid_txt += f"{prefix}{i} o-- {self.name}\n"
|
||||
return mermaid_txt
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Returns the class name without the namespace prefix."""
|
||||
return split_namespace(self.package)[-1]
|
||||
|
||||
|
||||
class VisualGraphRepo(ABC):
|
||||
"""Abstract base class for VisualGraphRepo."""
|
||||
|
||||
graph_db: GraphRepository
|
||||
|
||||
def __init__(self, graph_db):
|
||||
self.graph_db = graph_db
|
||||
|
||||
|
||||
class VisualDiGraphRepo(VisualGraphRepo):
|
||||
"""Implementation of VisualGraphRepo for DiGraph graph repository.
|
||||
|
||||
This class extends VisualGraphRepo to provide specific functionality for a graph repository using DiGraph.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
async def load_from(cls, filename: str | Path):
|
||||
"""Load a VisualDiGraphRepo instance from a file."""
|
||||
graph_db = await DiGraphRepository.load_from(str(filename))
|
||||
return cls(graph_db=graph_db)
|
||||
|
||||
async def get_mermaid_class_view(self) -> str:
|
||||
"""
|
||||
Returns a Markdown Mermaid class diagram code block object.
|
||||
"""
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
mermaid_txt = "classDiagram\n"
|
||||
for r in rows:
|
||||
v = await self._get_class_view(ns_class_name=r.subject)
|
||||
mermaid_txt += v.get_mermaid()
|
||||
return mermaid_txt
|
||||
|
||||
async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
|
||||
"""Returns the Markdown Mermaid class diagram code block object for the specified class."""
|
||||
rows = await self.graph_db.select(subject=ns_class_name)
|
||||
class_view = _VisualClassView(package=ns_class_name)
|
||||
for r in rows:
|
||||
if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
|
||||
class_view.uml = UMLClassView.model_validate_json(r.object_)
|
||||
elif r.predicate == GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name = self._refine_name(name)
|
||||
if name:
|
||||
class_view.generalizations.append(name)
|
||||
elif r.predicate == GraphKeyword.IS + COMPOSITION + GraphKeyword.OF:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name = self._refine_name(name)
|
||||
if name:
|
||||
class_view.compositions.append(name)
|
||||
elif r.predicate == GraphKeyword.IS + AGGREGATION + GraphKeyword.OF:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name = self._refine_name(name)
|
||||
if name:
|
||||
class_view.aggregations.append(name)
|
||||
return class_view
|
||||
|
||||
async def get_mermaid_sequence_views(self) -> List[(str, str)]:
|
||||
"""Returns all Markdown sequence diagrams with their corresponding graph repository keys."""
|
||||
sequence_views = []
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
|
||||
for r in rows:
|
||||
sequence_views.append((r.subject, r.object_))
|
||||
return sequence_views
|
||||
|
||||
@staticmethod
|
||||
def _refine_name(name: str) -> str:
|
||||
"""Removes impurity content from the given name.
|
||||
|
||||
Example:
|
||||
>>> _refine_name("int")
|
||||
""
|
||||
|
||||
>>> _refine_name('"Class1"')
|
||||
'Class1'
|
||||
|
||||
>>> _refine_name("pkg.Class1")
|
||||
"Class1"
|
||||
"""
|
||||
name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
|
||||
if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
|
||||
return ""
|
||||
if "." in name:
|
||||
name = name.split(".")[-1]
|
||||
|
||||
return name
|
||||
|
||||
async def get_mermaid_sequence_view_versions(self) -> List[(str, str)]:
|
||||
"""Returns all versioned Markdown sequence diagrams with their corresponding graph repository keys."""
|
||||
sequence_views = []
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
|
||||
for r in rows:
|
||||
sequence_views.append((r.subject, r.object_))
|
||||
return sequence_views
|
||||
Loading…
Add table
Add a link
Reference in a new issue