Merge pull request #960 from iorisa/feature/rebuild

feat: RFC197
This commit is contained in:
Guess 2024-03-05 16:30:35 +08:00 committed by GitHub
commit 5cae13fd0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
40 changed files with 2432 additions and 332 deletions

View file

@ -5,6 +5,13 @@ llm:
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
repair_llm_output: true # when the output is not a valid json, try to repair it
proxy: "YOUR_PROXY" # for LLM API requests
pricing_plan: "" # Optional. If invalid, it will be automatically filled in with the value of the `model`.
# Azure-exclusive pricing plan mappings
# - gpt-3.5-turbo 4k: "gpt-3.5-turbo-1106"
# - gpt-4-turbo: "gpt-4-turbo-preview"
# - gpt-4-turbo-vision: "gpt-4-vision-preview"
# - gpt-4 8k: "gpt-4"
# See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc.

View file

@ -4,10 +4,12 @@
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view.py
@Desc : Rebuild class view info
@Desc : Reconstructs class diagram from a source code project.
Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
"""
import re
from pathlib import Path
from typing import Optional, Set, Tuple
import aiofiles
@ -21,86 +23,144 @@ from metagpt.const import (
GRAPH_REPO_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
from metagpt.utils.common import split_namespace
from metagpt.repo_parser import DotClassInfo, RepoParser
from metagpt.schema import UMLClassView
from metagpt.utils.common import concat_namespace, split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
"""
Reconstructs a graph repository about class diagram from a source code project.
Attributes:
graph_db (Optional[GraphRepository]): The optional graph repository.
"""
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
"""
Implementation of `Action`'s `run` method.
Args:
with_messages (Optional[Type]): An optional argument specifying messages to react to.
format (str): The format for the prompt schema.
"""
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=Path(self.i_context))
# use pylint
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
await GraphRepository.rebuild_composition_relationship(self.graph_db)
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
for file_info in symbols:
# Align to the same root directory in accordance with `class_views`.
file_info.file = self._align_root(file_info.file, direction, diff_path)
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_views(graph_db=graph_db)
await graph_db.save()
await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info)
await self._create_mermaid_class_views()
await self.graph_db.save()
async def _create_mermaid_class_views(self, graph_db):
path = Path(self.context.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
async def _create_mermaid_class_views(self) -> str:
"""Creates a Mermaid class diagram using data from the `graph_db` graph repository.
This method utilizes information stored in the graph repository to generate a Mermaid class diagram.
Returns:
mermaid class diagram file name.
"""
path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
filename = str(pathname.with_suffix(".mmd"))
async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)
await writer.write(content)
# class names
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
class_distinct = set()
relationship_distinct = set()
for r in rows:
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
content = await self._create_mermaid_class(r.subject)
if content:
await writer.write(content)
class_distinct.add(r.subject)
for r in rows:
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
content, distinct = await self._create_mermaid_relationship(r.subject)
if content:
logger.debug(content)
await writer.write(content)
relationship_distinct.update(distinct)
logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}")
@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
if self.i_context:
r_filename = Path(filename).relative_to(self.context.git_repo.workdir)
await self.graph_db.insert(
subject=self.i_context, predicate="hasMermaidClassDiagramFile", object_=str(r_filename)
)
logger.info(f"{self.i_context} hasMermaidClassDiagramFile {filename}")
return filename
async def _create_mermaid_class(self, ns_class_name) -> str:
"""Generates a Mermaid class diagram for a specific class using data from the `graph_db` graph repository.
Args:
ns_class_name (str): The namespace-prefixed name of the class for which the Mermaid class diagram is to be created.
Returns:
str: A Mermaid code block object in markdown representing the class diagram.
"""
fields = split_namespace(ns_class_name)
if len(fields) > 2:
# Ignore sub-class
return
return ""
class_view = ClassView(name=fields[1])
rows = await graph_db.select(subject=ns_class_name)
for r in rows:
name = split_namespace(r.object_)[-1]
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
attribute = ClassAttribute(
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
)
class_view.attributes.append(attribute)
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
class_view.methods.append(method)
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
if not rows:
return ""
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
class_view = UMLClassView.load_dot_class_info(dot_class_info)
# update graph db
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
# update uml view
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
# update uml isCompositeOf
for c in dot_class_info.compositions:
await self.graph_db.insert(
subject=ns_class_name,
predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF,
object_=concat_namespace("?", c),
)
# update uml isAggregateOf
for a in dot_class_info.aggregations:
await self.graph_db.insert(
subject=ns_class_name,
predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF,
object_=concat_namespace("?", a),
)
content = class_view.get_mermaid(align=1)
logger.debug(content)
await file_writer.write(content)
distinct.add(ns_class_name)
return content
@staticmethod
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
async def _create_mermaid_relationship(self, ns_class_name: str) -> Tuple[Optional[str], Optional[Set]]:
"""Generates a Mermaid class relationship diagram for a specific class using data from the `graph_db` graph repository.
Args:
ns_class_name (str): The namespace-prefixed class name for which the Mermaid relationship diagram is to be created.
Returns:
Tuple[str, Set]: A tuple containing the relationship diagram as a string and a set of deduplication.
"""
s_fields = split_namespace(ns_class_name)
if len(s_fields) > 2:
# Ignore sub-class
return
return None, None
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
mappings = {
@ -109,8 +169,9 @@ class RebuildClassView(Action):
AGGREGATION: " o-- ",
}
content = ""
distinct = set()
for p, v in predicates.items():
rows = await graph_db.select(subject=ns_class_name, predicate=p)
rows = await self.graph_db.select(subject=ns_class_name, predicate=p)
for r in rows:
o_fields = split_namespace(r.object_)
if len(o_fields) > 2:
@ -121,86 +182,26 @@ class RebuildClassView(Action):
distinct.add(link)
content += f"\t{link}\n"
if content:
logger.debug(content)
await file_writer.write(content)
@staticmethod
def _parse_name(name: str, language="python"):
pattern = re.compile(r"<I>(.*?)<\/I>")
result = re.search(pattern, name)
abstraction = ""
if result:
name = result.group(1)
abstraction = "*"
if name.startswith("__"):
visibility = "-"
elif name.startswith("_"):
visibility = "#"
else:
visibility = "+"
return name, visibility, abstraction
@staticmethod
async def _parse_variable_type(ns_name, graph_db) -> str:
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
if not rows:
return ""
vals = rows[0].object_.replace("'", "").split(":")
if len(vals) == 1:
return ""
val = vals[-1].strip()
return "" if val == "NoneType" else val + " "
@staticmethod
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
if not rows:
return
info = rows[0].object_.replace("'", "")
fs_tag = "("
ix = info.find(fs_tag)
fe_tag = "):"
eix = info.rfind(fe_tag)
if eix < 0:
fe_tag = ")"
eix = info.rfind(fe_tag)
args_info = info[ix + len(fs_tag) : eix].strip()
method.return_type = info[eix + len(fe_tag) :].strip()
if method.return_type == "None":
method.return_type = ""
if "(" in method.return_type:
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
# parse args
if not args_info:
return
splitter_ixs = []
cost = 0
for i in range(len(args_info)):
if args_info[i] == "[":
cost += 1
elif args_info[i] == "]":
cost -= 1
if args_info[i] == "," and cost == 0:
splitter_ixs.append(i)
splitter_ixs.append(len(args_info))
args = []
ix = 0
for eix in splitter_ixs:
args.append(args_info[ix:eix])
ix = eix + 1
for arg in args:
parts = arg.strip().split(":")
if len(parts) == 1:
method.args.append(ClassAttribute(name=parts[0].strip()))
continue
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
return content, distinct
@staticmethod
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
"""Returns the difference between the root path and the path information represented in the package name.
Args:
path_root (Path): The root path.
package_root (Path): The package root path.
Returns:
Tuple[str, str]: A tuple containing the representation of the difference ("+", "-", "=") and the path detail of the differing part.
Example:
>>> _diff_path(path_root=Path("/Users/x/github/MetaGPT"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
"-", "metagpt"
>>> _diff_path(path_root=Path("/Users/x/github/MetaGPT/metagpt"), package_root=Path("/Users/x/github/MetaGPT/metagpt"))
"=", "."
"""
if len(str(path_root)) > len(str(package_root)):
return "+", str(path_root.relative_to(package_root))
if len(str(path_root)) < len(str(package_root)):
@ -208,7 +209,24 @@ class RebuildClassView(Action):
return "=", "."
@staticmethod
def _align_root(path: str, direction: str, diff_path: str):
def _align_root(path: str, direction: str, diff_path: str) -> str:
"""Aligns the path to the same root represented by `diff_path`.
Args:
path (str): The path to be aligned.
direction (str): The direction of alignment ('+', '-', '=').
diff_path (str): The path representing the difference.
Returns:
str: The aligned path.
Example:
>>> _align_root(path="metagpt/software_company.py", direction="+", diff_path="MetaGPT")
"MetaGPT/metagpt/software_company.py"
>>> _align_root(path="metagpt/software_company.py", direction="-", diff_path="metagpt")
"software_company.py"
"""
if direction == "=":
return path
if direction == "+":

View file

@ -4,34 +4,213 @@
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view.py
@Desc : Rebuild sequence view info
@Desc : Reconstruct sequence view information through reverse engineering.
Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt
"""
from __future__ import annotations
import re
from datetime import datetime
from pathlib import Path
from typing import List
from typing import List, Optional
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
from metagpt.repo_parser import CodeBlockInfo, DotClassInfo
from metagpt.schema import UMLClassView
from metagpt.utils.common import (
add_affix,
aread,
auto_namespace,
concat_namespace,
general_after_log,
list_files,
parse_json_code_block,
read_file_block,
split_namespace,
)
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword
from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository
class ReverseUseCase(BaseModel):
"""
Represents a reverse engineered use case.
Attributes:
description (str): A description of the reverse use case.
inputs (List[str]): List of inputs for the reverse use case.
outputs (List[str]): List of outputs for the reverse use case.
actors (List[str]): List of actors involved in the reverse use case.
steps (List[str]): List of steps for the reverse use case.
reason (str): The reason behind the reverse use case.
"""
description: str
inputs: List[str]
outputs: List[str]
actors: List[str]
steps: List[str]
reason: str
class ReverseUseCaseDetails(BaseModel):
"""
Represents details of a reverse engineered use case.
Attributes:
description (str): A description of the reverse use case details.
use_cases (List[ReverseUseCase]): List of reverse use cases.
relationship (List[str]): List of relationships associated with the reverse use case details.
"""
description: str
use_cases: List[ReverseUseCase]
relationship: List[str]
class RebuildSequenceView(Action):
async def run(self, with_messages=None, format=config.prompt_schema):
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await RebuildSequenceView._search_main_entry(graph_db)
for entry in entries:
await self._rebuild_sequence_view(entry, graph_db)
await graph_db.save()
"""
Represents an action to reconstruct sequence view through reverse engineering.
@staticmethod
async def _search_main_entry(graph_db) -> List:
rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
Attributes:
graph_db (Optional[GraphRepository]): An optional instance of GraphRepository for graph database operations.
"""
graph_db: Optional[GraphRepository] = None
async def run(self, with_messages=None, format=config.prompt_schema):
"""
Implementation of `Action`'s `run` method.
Args:
with_messages (Optional[Type]): An optional argument specifying messages to react to.
format (str): The format for the prompt schema.
"""
graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
if not self.i_context:
entries = await self._search_main_entry()
else:
entries = [SPO(subject=self.i_context, predicate="", object_="")]
for entry in entries:
await self._rebuild_main_sequence_view(entry)
while await self._merge_sequence_view(entry):
pass
await self.graph_db.save()
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_main_sequence_view(self, entry: SPO):
"""
Reconstruct the sequence diagram for the __main__ entry of the source code through reverse engineering.
Args:
entry (SPO): The SPO (Subject, Predicate, Object) object in the graph database that is related to the
subject `__name__:__main__`.
"""
filename = entry.subject.split(":", 1)[0]
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
classes = []
prefix = filename + ":"
for r in rows:
if prefix in r.subject:
classes.append(r)
await self._rebuild_use_case(r.subject)
participants = set()
class_details = []
class_views = []
for c in classes:
detail = await self._get_class_detail(c.subject)
if not detail:
continue
class_details.append(detail)
view = await self._get_uml_class_view(c.subject)
if view:
class_views.append(view)
actors = await self._get_participants(c.subject)
participants.update(set(actors))
use_case_blocks = []
for c in classes:
use_cases = await self._get_class_use_cases(c.subject)
use_case_blocks.append(use_cases)
prompt_blocks = ["## Use Cases\n" + "\n".join(use_case_blocks)]
block = "## Participants\n"
for p in participants:
block += f"- {p}\n"
prompt_blocks.append(block)
block = "## Mermaid Class Views\n```mermaid\n"
block += "\n\n".join([c.get_mermaid() for c in class_views])
block += "\n```\n"
prompt_blocks.append(block)
block = "## Source Code\n```python\n"
block += await self._get_source_code(filename)
block += "\n```\n"
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
"You are a python code to Mermaid Sequence Diagram translator in function detail.",
"Translate the given markdown text to a Mermaid Sequence Diagram.",
"Return the merged Mermaid sequence diagram in a markdown code block format.",
],
stream=False,
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.insert(
subject=entry.subject,
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
)
for c in classes:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
)
await self.graph_db.save()
async def _merge_sequence_view(self, entry: SPO) -> bool:
"""
Augments additional information to the provided SPO (Subject, Predicate, Object) entry in the sequence diagram.
Args:
entry (SPO): The SPO object representing the relationship in the graph database.
Returns:
bool: True if additional information has been augmented, otherwise False.
"""
new_participant = await self._search_new_participant(entry)
if not new_participant:
return False
await self._merge_participant(entry, new_participant)
return True
async def _search_main_entry(self) -> List:
"""
Asynchronously searches for the SPO object that is related to `__name__:__main__`.
Returns:
List: A list containing information about the main entry in the sequence diagram.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
tag = "__name__:__main__"
entries = []
for r in rows:
@ -39,24 +218,405 @@ class RebuildSequenceView(Action):
entries.append(r)
return entries
async def _rebuild_sequence_view(self, entry, graph_db):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_use_case(self, ns_class_name: str):
"""
Asynchronously reconstructs the use case for the provided namespace-prefixed class name.
Args:
ns_class_name (str): The namespace-prefixed class name for which the use case is to be reconstructed.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
if rows:
return
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
data = await self.llm.aask(
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
detail = await self._get_class_detail(ns_class_name)
if not detail:
return
participants = set()
participants.update(set(detail.compositions))
participants.update(set(detail.aggregations))
class_view = await self._get_uml_class_view(ns_class_name)
source_code = await self._get_source_code(ns_class_name)
# prompt_blocks = [
# "## Instruction\n"
# "You are a python code to UML 2.0 Use Case translator.\n"
# 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".\n'
# "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
# 'conflict with the information in "Mermaid Class Views".\n'
# 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
# "system interactions with the internal system.\n"
# ]
prompt_blocks = []
block = "## Participants\n"
for p in participants:
block += f"- {p}\n"
prompt_blocks.append(block)
block = "## Mermaid Class Views\n```mermaid\n"
block += class_view.get_mermaid()
block += "\n```\n"
prompt_blocks.append(block)
block = "## Source Code\n```python\n"
block += source_code
block += "\n```\n"
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)
# class _UseCase(BaseModel):
# description: str = Field(default="...", description="Describes about what the use case to do")
# inputs: List[str] = Field(default=["input name 1", "input name 2"],
# description="Lists the input names of the use case from external sources")
# outputs: List[str] = Field(default=["output name 1", "output name 2"],
# description="Lists the output names of the use case to external sources")
# actors: List[str] = Field(default=["actor name 1", "actor name 2"],
# description="Lists the participant actors of the use case")
# steps: List[str] = Field(default=["Step 1", "Step 2"],
# description="Lists the steps about how the use case works step by step")
# reason: str = Field(default="Because ...",
# description="Explaining under what circumstances would the external system execute this use case.")
#
#
# class _UseCaseList(BaseModel):
# description: str = Field(default="...",
# description="A summary explains what the whole source code want to do")
# use_cases: List[_UseCase] = Field(default=[
# {
# "description": "Describes about what the use case to do",
# "inputs": ["input name 1", "input name 2"],
# "outputs": ["output name 1", "output name 2"],
# "actors": ["actor name 1", "actor name 2"],
# "steps": ["Step 1", "Step 2"],
# "reason": "Because ..."
# }
# ], description="List all use cases.")
# relationship: List[str] = Field(default=["use case 1 ..."],
# description="Lists all the descriptions of relationship among these use cases")
# rsp = await ActionNode.from_pydantic(_UseCaseList).fill(context=prompt, llm=self.llm)
rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
"You are a python code to UML 2.0 Use Case translator.",
'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".',
"The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not "
'conflict with the information in "Mermaid Class Views".',
'The section under `if __name__ == "__main__":` of "Source Code" contains information about external '
"system interactions with the internal system.",
"Return a markdown JSON object with:\n"
'- a "description" key to explain what the whole source code want to do;\n'
'- a "use_cases" key list all use cases, each use case in the list should including a `description` '
"key describes about what the use case to do, a `inputs` key lists the input names of the use case "
"from external sources, a `outputs` key lists the output names of the use case to external sources, "
"a `actors` key lists the participant actors of the use case, a `steps` key lists the steps about how "
"the use case works step by step, a `reason` key explaining under what circumstances would the "
"external system execute this use case.\n"
'- a "relationship" key lists all the descriptions of relationship among these use cases.\n',
],
stream=False,
)
code_blocks = parse_json_code_block(rsp)
for block in code_blocks:
detail = ReverseUseCaseDetails.model_validate_json(block)
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
await self.graph_db.save()
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _rebuild_sequence_view(self, ns_class_name: str):
"""
Asynchronously reconstructs the sequence diagram for the provided namespace-prefixed class name.
Args:
ns_class_name (str): The namespace-prefixed class name for which the sequence diagram is to be reconstructed.
"""
await self._rebuild_use_case(ns_class_name)
prompts_blocks = []
use_case_markdown = await self._get_class_use_cases(ns_class_name)
if not use_case_markdown: # external class
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
await self.graph_db.save()
return
block = f"## Use Cases\n{use_case_markdown}"
prompts_blocks.append(block)
participants = await self._get_participants(ns_class_name)
block = "## Participants\n" + "\n".join([f"- {s}" for s in participants])
prompts_blocks.append(block)
view = await self._get_uml_class_view(ns_class_name)
block = "## Mermaid Class Views\n```mermaid\n"
block += view.get_mermaid()
block += "\n```\n"
prompts_blocks.append(block)
block = "## Source Code\n```python\n"
block += await self._get_source_code(ns_class_name)
block += "\n```\n"
prompts_blocks.append(block)
prompt = "\n---\n".join(prompts_blocks)
rsp = await self.llm.aask(
prompt,
system_msgs=[
"You are a Mermaid Sequence Diagram translator in function detail.",
"Translate the markdown text to a Mermaid Sequence Diagram.",
"Return a markdown mermaid code block.",
],
stream=False,
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.save()
async def _get_participants(self, ns_class_name: str) -> List[str]:
"""
Asynchronously returns the participants list of the sequence diagram for the provided namespace-prefixed SPO
object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve the participants list.
Returns:
List[str]: A list of participants in the sequence diagram.
"""
participants = set()
detail = await self._get_class_detail(ns_class_name)
if not detail:
return []
participants.update(set(detail.compositions))
participants.update(set(detail.aggregations))
return list(participants)
async def _get_class_use_cases(self, ns_class_name: str) -> str:
"""
Asynchronously assembles the context about the use case information of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve use case information.
Returns:
str: A string containing the assembled context about the use case information.
"""
block = ""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE)
for i, r in enumerate(rows):
detail = ReverseUseCaseDetails.model_validate_json(r.object_)
block += f"\n### {i + 1}. {detail.description}"
for j, use_case in enumerate(detail.use_cases):
block += f"\n#### {i + 1}.{j + 1}. {use_case.description}\n"
block += "\n##### Inputs\n" + "\n".join([f"- {s}" for s in use_case.inputs])
block += "\n##### Outputs\n" + "\n".join([f"- {s}" for s in use_case.outputs])
block += "\n##### Actors\n" + "\n".join([f"- {s}" for s in use_case.actors])
block += "\n##### Steps\n" + "\n".join([f"- {s}" for s in use_case.steps])
block += "\n#### Use Case Relationship\n" + "\n".join([f"- {s}" for s in detail.relationship])
return block + "\n"
async def _get_class_detail(self, ns_class_name: str) -> DotClassInfo | None:
"""
Asynchronously retrieves the dot format class details of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve class details.
Returns:
Union[DotClassInfo, None]: A DotClassInfo object representing the dot format class details,
or None if the details are not available.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
if not rows:
return None
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
return dot_class_info
async def _get_uml_class_view(self, ns_class_name: str) -> UMLClassView | None:
"""
Asynchronously retrieves the UML 2.0 format class details of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve UML class details.
Returns:
Union[UMLClassView, None]: A UMLClassView object representing the UML 2.0 format class details,
or None if the details are not available.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not rows:
return None
class_view = UMLClassView.model_validate_json(rows[0].object_)
return class_view
async def _get_source_code(self, ns_class_name: str) -> str:
"""
Asynchronously retrieves the source code of the namespace-prefixed SPO object.
Args:
ns_class_name (str): The namespace-prefixed class name for which to retrieve the source code.
Returns:
str: A string containing the source code of the specified namespace-prefixed class.
"""
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO)
filename = split_namespace(ns_class_name=ns_class_name)[0]
if not rows:
src_filename = RebuildSequenceView._get_full_filename(root=self.i_context, pathname=filename)
if not src_filename:
return ""
return await aread(filename=src_filename, encoding="utf-8")
code_block_info = CodeBlockInfo.model_validate_json(rows[0].object_)
return await read_file_block(
filename=filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno
)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
logger.info(data)
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
"""
Convert package name to the full path of the module.
Args:
root (Union[str, Path]): The root path or string representing the package.
pathname (Union[str, Path]): The pathname or string representing the module.
Returns:
Union[Path, None]: The full path of the module, or None if the path cannot be determined.
Examples:
If `root`(workdir) is "/User/xxx/github/MetaGPT/metagpt", and the `pathname` is
"metagpt/management/skill_manager.py", then the returned value will be
"/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py"
"""
if re.match(r"^/.+", pathname):
return pathname
files = list_files(root=root)
postfix = "/" + str(pathname)
for i in files:
if str(i).endswith(postfix):
return i
return None
@staticmethod
def parse_participant(mermaid_sequence_diagram: str) -> List[str]:
"""
Parses the provided Mermaid sequence diagram and returns the list of participants.
Args:
mermaid_sequence_diagram (str): The Mermaid sequence diagram string to be parsed.
Returns:
List[str]: A list of participants extracted from the sequence diagram.
"""
pattern = r"participant ([a-zA-Z\.0-9_]+)"
matches = re.findall(pattern, mermaid_sequence_diagram)
matches = [re.sub(r"[\\/'\"]+", "", i) for i in matches]
return matches
async def _search_new_participant(self, entry: SPO) -> str | None:
"""
Asynchronously retrieves a participant whose sequence diagram has not been augmented.
Args:
entry (SPO): The SPO object representing the relationship in the graph database.
Returns:
Union[str, None]: A participant whose sequence diagram has not been augmented, or None if not found.
"""
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
if not rows:
return None
sequence_view = rows[0].object_
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT)
merged_participants = []
for r in rows:
name = split_namespace(r.object_)[-1]
merged_participants.append(name)
participants = self.parse_participant(sequence_view)
for p in participants:
if p in merged_participants:
continue
return p
return None
@retry(
wait=wait_random_exponential(min=1, max=20),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _merge_participant(self, entry: SPO, class_name: str):
"""
Augments the sequence diagram of `class_name` to the sequence diagram of `entry`.
Args:
entry (SPO): The SPO object representing the base sequence diagram.
class_name (str): The class name whose sequence diagram is to be augmented.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
participants = []
for r in rows:
name = split_namespace(r.subject)[-1]
if name == class_name:
participants.append(r)
if len(participants) == 0: # external participants
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
await self.graph_db.save()
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
)
await self.graph_db.save()
return
participant = participants[0]
await self._rebuild_sequence_view(participant.subject)
sequence_views = await self.graph_db.select(
subject=participant.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW
)
if not sequence_views: # external class
return
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
prompt = f"```mermaid\n{sequence_views[0].object_}\n```\n---\n```mermaid\n{rows[0].object_}\n```"
rsp = await self.llm.aask(
prompt,
system_msgs=[
"You are a tool to merge sequence diagrams into one.",
"Participants with the same name are considered identical.",
"Return the merged Mermaid sequence diagram in a markdown code block format.",
],
stream=False,
)
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.insert(
subject=entry.subject,
predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER,
object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)),
)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
)
await self.graph_db.save()

View file

@ -1,16 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view_an.py
"""
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC2
CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
key="Program call flow",
expected_type=str,
instruction='Translate the "context" content into "format example" format.',
example=MMC2,
)

View file

@ -50,6 +50,7 @@ class ArgumentsParingAction(Action):
rsp = await self.llm.aask(
msg=prompt,
system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."],
stream=False,
)
logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}")
self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp)

View file

@ -92,7 +92,7 @@ class TalkAction(Action):
async def run(self, with_message=None, **kwargs) -> Message:
msg, format_msgs, system_msgs = self.aask_args
rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs)
rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs, stream=False)
self.rsp = Message(content=rsp, role="assistant", cause_by=self)
return self.rsp

View file

@ -47,6 +47,7 @@ class LLMConfig(YamlModel):
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
pricing_plan: Optional[str] = None # Cost Settlement Plan Parameters.
# For Cloud Service Provider like Baidu/ Alibaba
access_key: Optional[str] = None

View file

@ -104,6 +104,7 @@ CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary"
RESOURCES_FILE_REPO = "resources"
SD_OUTPUT_FILE_REPO = "resources/sd_output"
GRAPH_REPO_FILE_REPO = "docs/graph_repo"
VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db"
CLASS_VIEW_FILE_REPO = "docs/class_view"
YAPI_URL = "http://yapi.deepwisdomai.com/"

View file

@ -20,6 +20,7 @@ from langchain_community.document_loaders import (
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
@ -130,9 +131,12 @@ class IndexableDocument(Document):
if isinstance(data, pd.DataFrame):
validate_cols(content_col, data)
return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col)
else:
try:
content = data_path.read_text()
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
except Exception as e:
logger.debug(f"Load {str(data_path)} error: {e}")
content = ""
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
def _get_docs_and_metadatas_by_df(self) -> (list, list):
df = self.data

View file

@ -186,7 +186,7 @@ class BrainMemory(BaseModel):
summaries = [summary, command]
msg = "\n".join(summaries)
logger.debug(f"title ask:{msg}")
response = await llm.aask(msg=msg, system_msgs=[])
response = await llm.aask(msg=msg, system_msgs=[], stream=False)
logger.debug(f"title rsp: {response}")
return response
@ -201,11 +201,15 @@ class BrainMemory(BaseModel):
@staticmethod
async def _openai_is_related(text1, text2, llm, **kwargs):
command = (
f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there "
"any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
context = f"## Paragraph 1\n{text2}\n---\n## Paragraph 2\n{text1}\n"
rsp = await llm.aask(
msg=context,
system_msgs=[
"You are a tool capable of determining whether two paragraphs are semantically related."
'Return "TRUE" if "Paragraph 1" is semantically relevant to "Paragraph 2", otherwise return "FALSE".'
],
stream=False,
)
rsp = await llm.aask(msg=command, system_msgs=[])
result = True if "TRUE" in rsp else False
p2 = text2.replace("\n", "")
p1 = text1.replace("\n", "")
@ -223,12 +227,17 @@ class BrainMemory(BaseModel):
@staticmethod
async def _openai_rewrite(sentence: str, context: str, llm):
command = (
f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly "
f"supplement or rewrite the following text in brief and clear:\n{sentence}"
prompt = f"## Context\n{context}\n---\n## Sentence\n{sentence}\n"
rsp = await llm.aask(
msg=prompt,
system_msgs=[
'You are a tool augmenting the "Sentence" with information from the "Context".',
"Do not supplement the context with information that is not present, especially regarding the subject and object.",
"Return the augmented sentence.",
],
stream=False,
)
rsp = await llm.aask(msg=command, system_msgs=[])
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
logger.info(f"REWRITE:\nCommand: {prompt}\nRESULT: {rsp}\n")
return rsp
@staticmethod
@ -293,14 +302,14 @@ class BrainMemory(BaseModel):
"""Generate text summary"""
if len(text) < max_words:
return text
system_msgs = [
"You are a tool for summarizing and abstracting text.",
f"Return the summarized text to less than {max_words} words.",
]
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.debug(f"summary ask:{msg}")
response = await self.llm.aask(msg=msg, system_msgs=[])
logger.debug(f"summary rsp: {response}")
system_msgs.append("The generated summary should be in the same language as the original text.")
response = await self.llm.aask(msg=text, system_msgs=system_msgs, stream=False)
logger.debug(f"{text}\nsummary rsp: {response}")
return response
@staticmethod

View file

@ -6,8 +6,6 @@
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper
@ -27,6 +25,7 @@ class AzureOpenAILLM(OpenAILLM):
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
self.pricing_plan = self.config.pricing_plan
def _make_client_kwargs(self) -> dict:
kwargs = dict(

View file

@ -6,11 +6,14 @@
@File : base_llm.py
@Desc : mashenquan, 2023/8/22. + try catch
"""
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Dict, Optional, Union
from openai import AsyncOpenAI
from openai.types import CompletionUsage
from pydantic import BaseModel
from tenacity import (
after_log,
@ -25,6 +28,7 @@ from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.common import log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.exceptions import handle_exception
class BaseLLM(ABC):
@ -38,6 +42,7 @@ class BaseLLM(ABC):
aclient: Optional[Union[AsyncOpenAI]] = None
cost_manager: Optional[CostManager] = None
model: Optional[str] = None
pricing_plan: Optional[str] = None
@abstractmethod
def __init__(self, config: LLMConfig):
@ -218,6 +223,20 @@ class BaseLLM(ABC):
"""
return json.loads(self.get_choice_function(rsp)["arguments"], strict=False)
@handle_exception
def _update_costs(self, usage: CompletionUsage | Dict):
"""
Updates the costs based on the provided usage information.
"""
if self.config.calc_usage and usage and self.cost_manager:
if isinstance(usage, Dict):
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
else:
prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])

View file

@ -47,6 +47,7 @@ class GeminiLLM(BaseLLM):
self.__init_gemini(config)
self.config = config
self.model = "gemini-pro" # so far only one model
self.pricing_plan = self.config.pricing_plan or self.model
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: LLMConfig):

View file

@ -5,6 +5,8 @@
@File : metagpt_api.py
@Desc : MetaGPT LLM provider.
"""
from openai.types import CompletionUsage
from metagpt.configs.llm_config import LLMType
from metagpt.provider import OpenAILLM
from metagpt.provider.llm_provider_registry import register_provider
@ -12,4 +14,7 @@ from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMType.METAGPT)
class MetaGPTLLM(OpenAILLM):
pass
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
# The current billing is based on usage frequency. If there is a future billing logic based on the
# number of tokens, please refine the logic here accordingly.
return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)

View file

@ -26,11 +26,12 @@ class OllamaLLM(BaseLLM):
self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self._cost_manager = TokenCostManager()
self.cost_manager = TokenCostManager()
def __init_ollama(self, config: LLMConfig):
assert config.base_url, "ollama base url is required!"
self.model = config.model
self.pricing_plan = self.model
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}

View file

@ -6,6 +6,7 @@
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from __future__ import annotations
import json
import re
@ -30,7 +31,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
from metagpt.utils.cost_manager import CostManager, TokenCostManager
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -52,6 +53,7 @@ class OpenAILLM(BaseLLM):
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
self.pricing_plan = self.config.pricing_plan or self.model
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
@ -258,10 +260,9 @@ class OpenAILLM(BaseLLM):
if not self.config.calc_usage:
return usage
model = self.model if not isinstance(self.cost_manager, TokenCostManager) else "open-llm-model"
try:
usage.prompt_tokens = count_message_tokens(messages, model)
usage.completion_tokens = count_string_tokens(rsp, model)
usage.prompt_tokens = count_message_tokens(messages, self.pricing_plan)
usage.completion_tokens = count_string_tokens(rsp, self.pricing_plan)
except Exception as e:
logger.warning(f"usage calculation failed: {e}")

View file

@ -38,6 +38,7 @@ class ZhiPuAILLM(BaseLLM):
assert self.config.api_key
self.api_key = self.config.api_key
self.model = self.config.model # so far, it support glm-3-turbo、glm-4
self.pricing_plan = self.config.pricing_plan or self.model
self.llm = ZhiPuModelAPI(api_key=self.api_key)
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:

View file

@ -1,6 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Build a symbols repository from source code.
This script is designed to create a symbols repository from the provided source code.
@Time : 2023/11/17 17:58
@Author : alexanderwu
@File : repo_parser.py
@ -15,15 +19,26 @@ from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
from metagpt.utils.common import any_to_str, aread
from metagpt.utils.common import any_to_str, aread, remove_white_spaces
from metagpt.utils.exceptions import handle_exception
class RepoFileInfo(BaseModel):
"""
Repository data element that represents information about a file.
Attributes:
file (str): The name or path of the file.
classes (List): A list of class names present in the file.
functions (List): A list of function names present in the file.
globals (List): A list of global variable names present in the file.
page_info (List): A list of page-related information associated with the file.
"""
file: str
classes: List = Field(default_factory=list)
functions: List = Field(default_factory=list)
@ -32,6 +47,17 @@ class RepoFileInfo(BaseModel):
class CodeBlockInfo(BaseModel):
"""
Repository data element representing information about a code block.
Attributes:
lineno (int): The starting line number of the code block.
end_lineno (int): The ending line number of the code block.
type_name (str): The type or category of the code block.
tokens (List): A list of tokens present in the code block.
properties (Dict): A dictionary containing additional properties associated with the code block.
"""
lineno: int
end_lineno: int
type_name: str
@ -39,31 +65,395 @@ class CodeBlockInfo(BaseModel):
properties: Dict = Field(default_factory=dict)
class ClassInfo(BaseModel):
class DotClassAttribute(BaseModel):
"""
Repository data element representing a class attribute in dot format.
Attributes:
name (str): The name of the class attribute.
type_ (str): The type of the class attribute.
default_ (str): The default value of the class attribute.
description (str): A description of the class attribute.
compositions (List[str]): A list of compositions associated with the class attribute.
"""
name: str = ""
type_: str = ""
default_: str = ""
description: str
compositions: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotClassAttribute":
"""
Parses dot format text and returns a DotClassAttribute object.
Args:
v (str): Dot format text to be parsed.
Returns:
DotClassAttribute: An instance of the DotClassAttribute class representing the parsed data.
"""
val = ""
meet_colon = False
meet_equals = False
for c in v:
if c == ":":
meet_colon = True
elif c == "=":
meet_equals = True
if not meet_colon:
val += ":"
meet_colon = True
val += c
if not meet_colon:
val += ":"
if not meet_equals:
val += "="
cix = val.find(":")
eix = val.rfind("=")
name = val[0:cix].strip()
type_ = val[cix + 1 : eix]
default_ = val[eix + 1 :].strip()
type_ = remove_white_spaces(type_) # remove white space
if type_ == "NoneType":
type_ = ""
if "Literal[" in type_:
pre_l, literal, post_l = cls._split_literal(type_)
composition_val = pre_l + "Literal" + post_l # replace Literal[...] with Literal
type_ = pre_l + literal + post_l
else:
type_ = re.sub(r"['\"]+", "", type_) # remove '"
composition_val = type_
if default_ == "None":
default_ = ""
compositions = cls.parse_compositions(composition_val)
return cls(name=name, type_=type_, default_=default_, description=v, compositions=compositions)
@staticmethod
def parse_compositions(types_part) -> List[str]:
"""
Parses the type definition code block of source code and returns a list of compositions.
Args:
types_part: The type definition code block to be parsed.
Returns:
List[str]: A list of compositions extracted from the type definition code block.
"""
if not types_part:
return []
modified_string = re.sub(r"[\[\],\(\)]", "|", types_part)
types = modified_string.split("|")
filters = {
"str",
"frozenset",
"set",
"int",
"float",
"complex",
"bool",
"dict",
"list",
"Union",
"Dict",
"Set",
"Tuple",
"NoneType",
"None",
"Any",
"Optional",
"Iterator",
"Literal",
"List",
}
result = set()
for t in types:
t = re.sub(r"['\"]+", "", t.strip())
if t and t not in filters:
result.add(t)
return list(result)
@staticmethod
def _split_literal(v):
"""
Parses the literal definition code block and returns three parts: pre-part, literal-part, and post-part.
Args:
v: The literal definition code block to be parsed.
Returns:
Tuple[str, str, str]: A tuple containing the pre-part, literal-part, and post-part of the code block.
"""
tag = "Literal["
bix = v.find(tag)
eix = len(v) - 1
counter = 1
for i in range(bix + len(tag), len(v) - 1):
c = v[i]
if c == "[":
counter += 1
continue
if c == "]":
counter -= 1
if counter > 0:
continue
eix = i
break
pre_l = v[0:bix]
post_l = v[eix + 1 :]
pre_l = re.sub(r"['\"]", "", pre_l) # remove '"
pos_l = re.sub(r"['\"]", "", post_l) # remove '"
return pre_l, v[bix : eix + 1], pos_l
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class DotClassInfo(BaseModel):
"""
Repository data element representing information about a class in dot format.
Attributes:
name (str): The name of the class.
package (Optional[str]): The package to which the class belongs (optional).
attributes (Dict[str, DotClassAttribute]): A dictionary of attributes associated with the class.
methods (Dict[str, DotClassMethod]): A dictionary of methods associated with the class.
compositions (List[str]): A list of compositions associated with the class.
aggregations (List[str]): A list of aggregations associated with the class.
"""
name: str
package: Optional[str] = None
attributes: Dict[str, str] = Field(default_factory=dict)
methods: Dict[str, str] = Field(default_factory=dict)
attributes: Dict[str, DotClassAttribute] = Field(default_factory=dict)
methods: Dict[str, DotClassMethod] = Field(default_factory=dict)
compositions: List[str] = Field(default_factory=list)
aggregations: List[str] = Field(default_factory=list)
@field_validator("compositions", "aggregations", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class ClassRelationship(BaseModel):
class DotClassRelationship(BaseModel):
"""
Repository data element representing a relationship between two classes in dot format.
Attributes:
src (str): The source class of the relationship.
dest (str): The destination class of the relationship.
relationship (str): The type or nature of the relationship.
label (Optional[str]): An optional label associated with the relationship.
"""
src: str = ""
dest: str = ""
relationship: str = ""
label: Optional[str] = None
class DotReturn(BaseModel):
"""
Repository data element representing a function or method return type in dot format.
Attributes:
type_ (str): The type of the return.
description (str): A description of the return type.
compositions (List[str]): A list of compositions associated with the return type.
"""
type_: str = ""
description: str
compositions: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotReturn" | None:
"""
Parses the return type part of dot format text and returns a DotReturn object.
Args:
v (str): The dot format text containing the return type part to be parsed.
Returns:
DotReturn | None: An instance of the DotReturn class representing the parsed return type,
or None if parsing fails.
"""
if not v:
return DotReturn(description=v)
type_ = remove_white_spaces(v)
compositions = DotClassAttribute.parse_compositions(type_)
return cls(type_=type_, description=v, compositions=compositions)
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
"""
Auto-sorts a list attribute after making changes.
Args:
lst (List): The list attribute to be sorted.
Returns:
List: The sorted list.
"""
lst.sort()
return lst
class DotClassMethod(BaseModel):
name: str
args: List[DotClassAttribute] = Field(default_factory=list)
return_args: Optional[DotReturn] = None
description: str
aggregations: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotClassMethod":
"""
Parses a dot format method text and returns a DotClassMethod object.
Args:
v (str): The dot format text containing method information to be parsed.
Returns:
DotClassMethod: An instance of the DotClassMethod class representing the parsed method.
"""
bix = v.find("(")
eix = v.rfind(")")
rix = v.rfind(":")
if rix < 0 or rix < eix:
rix = eix
name_part = v[0:bix].strip()
args_part = v[bix + 1 : eix].strip()
return_args_part = v[rix + 1 :].strip()
name = cls._parse_name(name_part)
args = cls._parse_args(args_part)
return_args = DotReturn.parse(return_args_part)
aggregations = set()
for i in args:
aggregations.update(set(i.compositions))
aggregations.update(set(return_args.compositions))
return cls(name=name, args=args, description=v, return_args=return_args, aggregations=list(aggregations))
@staticmethod
def _parse_name(v: str) -> str:
"""
Parses the dot format method name part and returns the method name.
Args:
v (str): The dot format text containing the method name part to be parsed.
Returns:
str: The parsed method name.
"""
tags = [">", "</"]
if tags[0] in v:
bix = v.find(tags[0]) + len(tags[0])
eix = v.rfind(tags[1])
return v[bix:eix].strip()
return v.strip()
@staticmethod
def _parse_args(v: str) -> List[DotClassAttribute]:
"""
Parses the dot format method arguments part and returns the parsed arguments.
Args:
v (str): The dot format text containing the arguments part to be parsed.
Returns:
str: The parsed method arguments.
"""
if not v:
return []
parts = []
bix = 0
counter = 0
for i in range(0, len(v)):
c = v[i]
if c == "[":
counter += 1
continue
elif c == "]":
counter -= 1
continue
elif c == "," and counter == 0:
parts.append(v[bix:i].strip())
bix = i + 1
parts.append(v[bix:].strip())
attrs = []
for p in parts:
if p:
attr = DotClassAttribute.parse(p)
attrs.append(attr)
return attrs
class RepoParser(BaseModel):
"""
Tool to build a symbols repository from a project directory.
Attributes:
base_directory (Path): The base directory of the project.
"""
base_directory: Path = Field(default=None)
@classmethod
@handle_exception(exception_type=Exception, default_return=[])
def _parse_file(cls, file_path: Path) -> list:
"""Parse a Python file in the repository."""
"""
Parses a Python file in the repository.
Args:
file_path (Path): The path to the Python file to be parsed.
Returns:
list: A list containing the parsed symbols from the file.
"""
return ast.parse(file_path.read_text()).body
def extract_class_and_function_info(self, tree, file_path) -> RepoFileInfo:
"""Extract class, function, and global variable information from the AST."""
"""
Extracts class, function, and global variable information from the Abstract Syntax Tree (AST).
Args:
tree: The Abstract Syntax Tree (AST) of the Python file.
file_path: The path to the Python file.
Returns:
RepoFileInfo: A RepoFileInfo object containing the extracted information.
"""
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
@ -81,11 +471,17 @@ class RepoParser(BaseModel):
return file_info
def generate_symbols(self) -> List[RepoFileInfo]:
"""
Builds a symbol repository from '.py' and '.js' files in the project directory.
Returns:
List[RepoFileInfo]: A list of RepoFileInfo objects containing the extracted information.
"""
files_classes = []
directory = self.base_directory
matching_files = []
extensions = ["*.py", "*.js"]
extensions = ["*.py"]
for ext in extensions:
matching_files += directory.rglob(ext)
for path in matching_files:
@ -95,19 +491,38 @@ class RepoParser(BaseModel):
return files_classes
def generate_json_structure(self, output_path):
"""Generate a JSON file documenting the repository structure."""
def generate_json_structure(self, output_path: Path):
"""
Generates a JSON file documenting the repository structure.
Args:
output_path (Path): The path to the JSON file to be generated.
"""
files_classes = [i.model_dump() for i in self.generate_symbols()]
output_path.write_text(json.dumps(files_classes, indent=4))
def generate_dataframe_structure(self, output_path):
"""Generate a DataFrame documenting the repository structure and save as CSV."""
def generate_dataframe_structure(self, output_path: Path):
"""
Generates a DataFrame documenting the repository structure and saves it as a CSV file.
Args:
output_path (Path): The path to the CSV file to be generated.
"""
files_classes = [i.model_dump() for i in self.generate_symbols()]
df = pd.DataFrame(files_classes)
df.to_csv(output_path, index=False)
def generate_structure(self, output_path=None, mode="json") -> Path:
"""Generate the structure of the repository as a specified format."""
def generate_structure(self, output_path: str | Path = None, mode="json") -> Path:
"""
Generates the structure of the repository in a specified format.
Args:
output_path (str | Path): The path to the output file or directory. Default is None.
mode (str): The output format mode. Options: "json" (default), "csv", etc.
Returns:
Path: The path to the generated output file or directory.
"""
output_file = self.base_directory / f"{self.base_directory.name}-structure.{mode}"
output_path = Path(output_path) if output_path else output_file
@ -119,6 +534,16 @@ class RepoParser(BaseModel):
@staticmethod
def node_to_str(node) -> CodeBlockInfo | None:
"""
Parses and converts an Abstract Syntax Tree (AST) node to a CodeBlockInfo object.
Args:
node: The AST node to be converted.
Returns:
CodeBlockInfo | None: A CodeBlockInfo object representing the parsed AST node,
or None if the conversion fails.
"""
if isinstance(node, ast.Try):
return None
if any_to_str(node) == any_to_str(ast.Expr):
@ -159,9 +584,19 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_expr(node) -> List:
"""
Parses an expression Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing an expression.
Returns:
List: A list containing the parsed information from the expression node.
"""
funcs = {
any_to_str(ast.Constant): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
any_to_str(ast.Call): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value.func)],
any_to_str(ast.Tuple): lambda x: [any_to_str(x.value), RepoParser._parse_variable(x.value)],
}
func = funcs.get(any_to_str(node.value))
if func:
@ -170,12 +605,30 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_name(n):
"""
Gets the 'name' value of an Abstract Syntax Tree (AST) node.
Args:
n: The AST node.
Returns:
The 'name' value of the AST node.
"""
if n.asname:
return f"{n.name} as {n.asname}"
return n.name
@staticmethod
def _parse_if(n):
"""
Parses an 'if' statement Abstract Syntax Tree (AST) node.
Args:
n: The AST node representing an 'if' statement.
Returns:
None or Parsed information from the 'if' statement node.
"""
tokens = []
try:
if isinstance(n.test, ast.BoolOp):
@ -187,10 +640,14 @@ class RepoParser(BaseModel):
v = RepoParser._parse_variable(n.test.left)
if v:
tokens.append(v)
for item in n.test.comparators:
v = RepoParser._parse_variable(item)
if v:
tokens.append(v)
if isinstance(n.test, ast.Name):
v = RepoParser._parse_variable(n.test)
tokens.append(v)
if hasattr(n.test, "comparators"):
for item in n.test.comparators:
v = RepoParser._parse_variable(item)
if v:
tokens.append(v)
return tokens
except Exception as e:
logger.warning(f"Unsupported if: {n}, err:{e}")
@ -198,6 +655,15 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_if_compare(n):
"""
Parses an 'if' condition Abstract Syntax Tree (AST) node.
Args:
n: The AST node representing an 'if' condition.
Returns:
None or Parsed information from the 'if' condition node.
"""
if hasattr(n, "left"):
return RepoParser._parse_variable(n.left)
else:
@ -205,6 +671,15 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_variable(node):
"""
Parses a variable Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing a variable.
Returns:
None or Parsed information from the variable node.
"""
try:
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
@ -213,7 +688,7 @@ class RepoParser(BaseModel):
if hasattr(x.value, "id")
else f"{x.attr}",
any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
any_to_str(ast.Tuple): lambda x: "",
any_to_str(ast.Tuple): lambda x: [d.value for d in x.dims],
}
func = funcs.get(any_to_str(node))
if not func:
@ -224,9 +699,24 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_assign(node):
"""
Parses an assignment Abstract Syntax Tree (AST) node.
Args:
node: The AST node representing an assignment.
Returns:
None or Parsed information from the assignment node.
"""
return [RepoParser._parse_variable(t) for t in node.targets]
async def rebuild_class_views(self, path: str | Path = None):
"""
Executes `pylint` to reconstruct the dot format class view repository file.
Args:
path (str | Path): The path to the target directory or file. Default is None.
"""
if not path:
path = self.base_directory
path = Path(path)
@ -247,7 +737,17 @@ class RepoParser(BaseModel):
packages_pathname.unlink(missing_ok=True)
return class_views, relationship_views, package_root
async def _parse_classes(self, class_view_pathname):
@staticmethod
async def _parse_classes(class_view_pathname: Path) -> List[DotClassInfo]:
"""
Parses a dot format class view repository file.
Args:
class_view_pathname (Path): The path to the dot format class view repository file.
Returns:
List[DotClassInfo]: A list of DotClassInfo objects representing the parsed classes.
"""
class_views = []
if not class_view_pathname.exists():
return class_views
@ -258,22 +758,38 @@ class RepoParser(BaseModel):
if not package_name:
continue
class_name, members, functions = re.split(r"(?<!\\)\|", info)
class_info = ClassInfo(name=class_name)
class_info = DotClassInfo(name=class_name)
class_info.package = package_name
for m in members.split("\n"):
if not m:
continue
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
class_info.attributes[member_name] = m
attr = DotClassAttribute.parse(m)
class_info.attributes[attr.name] = attr
for i in attr.compositions:
if i not in class_info.compositions:
class_info.compositions.append(i)
for f in functions.split("\n"):
if not f:
continue
function_name, _ = f.split("(", 1)
class_info.methods[function_name] = f
method = DotClassMethod.parse(f)
class_info.methods[method.name] = method
for i in method.aggregations:
if i not in class_info.compositions and i not in class_info.aggregations:
class_info.aggregations.append(i)
class_views.append(class_info)
return class_views
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
@staticmethod
async def _parse_class_relationships(class_view_pathname: Path) -> List[DotClassRelationship]:
"""
Parses a dot format class view repository file.
Args:
class_view_pathname (Path): The path to the dot format class view repository file.
Returns:
List[DotClassRelationship]: A list of DotClassRelationship objects representing the parsed class relationships.
"""
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
@ -287,7 +803,16 @@ class RepoParser(BaseModel):
return relationship_views
@staticmethod
def _split_class_line(line):
def _split_class_line(line: str) -> (str, str):
"""
Parses a dot format line about class info and returns the class name part and class members part.
Args:
line (str): The dot format line containing class information.
Returns:
Tuple[str, str]: A tuple containing the class name part and class members part.
"""
part_splitor = '" ['
if part_splitor not in line:
return None, None
@ -305,14 +830,25 @@ class RepoParser(BaseModel):
return class_name, info
@staticmethod
def _split_relationship_line(line):
def _split_relationship_line(line: str) -> DotClassRelationship:
"""
Parses a dot format line about the relationship of two classes and returns 'Generalize', 'Composite',
or 'Aggregate'.
Args:
line (str): The dot format line containing relationship information.
Returns:
DotClassRelationship: The object of relationship representing either 'Generalize', 'Composite',
or 'Aggregate' relationship.
"""
splitters = [" -> ", " [", "];"]
idxs = []
for tag in splitters:
if tag not in line:
return None
idxs.append(line.find(tag))
ret = ClassRelationship()
ret = DotClassRelationship()
ret.src = line[0 : idxs[0]].strip('"')
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
@ -330,7 +866,16 @@ class RepoParser(BaseModel):
return ret
@staticmethod
def _get_label(line):
def _get_label(line: str) -> str:
"""
Parses a dot format line and returns the label information.
Args:
line (str): The dot format line containing label information.
Returns:
str: The label information parsed from the line.
"""
tag = 'label="'
if tag not in line:
return ""
@ -340,6 +885,15 @@ class RepoParser(BaseModel):
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
"""
Creates a mapping table between source code files' paths and module names.
Args:
path (str | Path): The path to the source code files or directory.
Returns:
Dict[str, str]: A dictionary mapping source code file paths to their corresponding module names.
"""
mappings = {
str(path).replace("/", "."): str(path),
}
@ -363,8 +917,21 @@ class RepoParser(BaseModel):
@staticmethod
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
) -> (List[ClassInfo], List[ClassRelationship], str):
class_views: List[DotClassInfo], relationship_views: List[DotClassRelationship], path: str | Path
) -> (List[DotClassInfo], List[DotClassRelationship], str):
"""
Augments namespaces to the path-prefixed classes and relationships.
Args:
class_views (List[DotClassInfo]): List of DotClassInfo objects representing class views.
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects representing
relationships.
path (str | Path): The path to the source code files or directory.
Returns:
Tuple[List[DotClassInfo], List[DotClassRelationship], str]: A tuple containing the augmented class views,
relationships, and the root path of the package.
"""
if not class_views:
return [], [], ""
c = class_views[0]
@ -383,16 +950,25 @@ class RepoParser(BaseModel):
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
for i in range(len(relationship_views)):
v = relationship_views[i]
for _, v in enumerate(relationship_views):
v.src = RepoParser._repair_ns(v.src, new_mappings)
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
relationship_views[i] = v
return class_views, relationship_views, root_path
return class_views, relationship_views, str(path)[: len(root_path)]
@staticmethod
def _repair_ns(package, mappings):
def _repair_ns(package: str, mappings: Dict[str, str]) -> str:
"""
Replaces the package-prefix with the namespace-prefix.
Args:
package (str): The package to be repaired.
mappings (Dict[str, str]): A dictionary mapping source code file paths to their corresponding packages.
Returns:
str: The repaired namespace.
"""
file_ns = package
ix = 0
while file_ns != "":
if file_ns not in mappings:
ix = file_ns.rfind(".")
@ -404,7 +980,17 @@ class RepoParser(BaseModel):
return ns
@staticmethod
def _find_root(full_key, package) -> str:
def _find_root(full_key: str, package: str) -> str:
"""
Returns the package root path based on the key, which is the full path, and the package information.
Args:
full_key (str): The full key representing the full path.
package (str): The package information.
Returns:
str: The package root path.
"""
left = full_key
while left != "":
if left in package:
@ -417,5 +1003,14 @@ class RepoParser(BaseModel):
return "." + full_key[0:ix]
def is_func(node):
def is_func(node) -> bool:
"""
Returns True if the given node represents a function.
Args:
node: The Abstract Syntax Tree (AST) node.
Returns:
bool: True if the node represents a function, False otherwise.
"""
return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))

View file

@ -65,7 +65,7 @@ class Assistant(Role):
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
rsp = await self.llm.aask(prompt, ["You are an action classifier"])
rsp = await self.llm.aask(prompt, ["You are an action classifier"], stream=False)
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
return await self._plan(rsp, last_talk=last_talk)

View file

@ -359,9 +359,17 @@ class Engineer(Role):
summarizations[ctx].append(filename)
for ctx, filenames in summarizations.items():
ctx.codes_filenames = filenames
self.summarize_todos.append(SummarizeCode(i_context=ctx, context=self.context, llm=self.llm))
new_summarize = SummarizeCode(i_context=ctx, context=self.context, llm=self.llm)
for i, act in enumerate(self.summarize_todos):
if act.i_context.task_filename == new_summarize.i_context.task_filename:
self.summarize_todos[i] = new_summarize
new_summarize = None
break
if new_summarize:
self.summarize_todos.append(new_summarize)
if self.summarize_todos:
self.set_todo(self.summarize_todos[0])
self.summarize_todos.pop(0)
async def _new_code_plan_and_change_action(self):
"""Create a WriteCodePlanAndChange action for subsequent to-do actions."""

View file

@ -46,6 +46,7 @@ from metagpt.const import (
TASK_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import DotClassInfo
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import (
@ -690,54 +691,64 @@ class CodePlanAndChangeContext(BaseModel):
# mermaid class view
class ClassMeta(BaseModel):
class UMLClassMeta(BaseModel):
name: str = ""
abstraction: bool = False
static: bool = False
visibility: str = ""
@staticmethod
def name_to_visibility(name: str) -> str:
if name == "__init__":
return "+"
if name.startswith("__"):
return "-"
elif name.startswith("_"):
return "#"
return "+"
class ClassAttribute(ClassMeta):
class UMLClassAttribute(UMLClassMeta):
value_type: str = ""
default_value: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
if self.value_type:
content += self.value_type + " "
content += self.name
content += self.value_type.replace(" ", "") + " "
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
content += name
if self.default_value:
content += "="
if self.value_type not in ["str", "string", "String"]:
content += self.default_value
else:
content += '"' + self.default_value.replace('"', "") + '"'
if self.abstraction:
content += "*"
if self.static:
content += "$"
# if self.abstraction:
# content += "*"
# if self.static:
# content += "$"
return content
class ClassMethod(ClassMeta):
args: List[ClassAttribute] = Field(default_factory=list)
class UMLClassMethod(UMLClassMeta):
args: List[UMLClassAttribute] = Field(default_factory=list)
return_type: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
if self.return_type:
content += ":" + self.return_type
if self.abstraction:
content += "*"
if self.static:
content += "$"
content += " " + self.return_type.replace(" ", "")
# if self.abstraction:
# content += "*"
# if self.static:
# content += "$"
return content
class ClassView(ClassMeta):
attributes: List[ClassAttribute] = Field(default_factory=list)
methods: List[ClassMethod] = Field(default_factory=list)
class UMLClassView(UMLClassMeta):
attributes: List[UMLClassAttribute] = Field(default_factory=list)
methods: List[UMLClassMethod] = Field(default_factory=list)
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
@ -747,3 +758,21 @@ class ClassView(ClassMeta):
content += v.get_mermaid(align=align + 1) + "\n"
content += "".join(["\t" for i in range(align)]) + "}\n"
return content
@classmethod
def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
visibility = UMLClassView.name_to_visibility(dot_class_info.name)
class_view = cls(name=dot_class_info.name, visibility=visibility)
for i in dot_class_info.attributes.values():
visibility = UMLClassAttribute.name_to_visibility(i.name)
attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
class_view.attributes.append(attr)
for i in dot_class_info.methods.values():
visibility = UMLClassMethod.name_to_visibility(i.name)
method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
for j in i.args:
arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
method.args.append(arg)
method.return_type = i.return_args.type_
class_view.methods.append(method)
return class_view

View file

@ -23,10 +23,10 @@ import platform
import re
import sys
import traceback
import typing
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, List, Tuple, Union
from typing import Any, Callable, List, Literal, Tuple, Union
from urllib.parse import quote, unquote
import aiofiles
import loguru
@ -433,23 +433,109 @@ def is_send_to(message: "Message", addresses: set):
def any_to_name(val):
"""
Convert a value to its name by extracting the last part of the dotted path.
:param val: The value to convert.
:return: The name of the value.
"""
return any_to_str(val).split(".")[-1]
def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def concat_namespace(*args, delimiter: str = ":") -> str:
"""Concatenate fields to create a unique namespace prefix.
Example:
>>> concat_namespace('prefix', 'field1', 'field2', delimiter=":")
'prefix:field1:field2'
"""
return delimiter.join(str(value) for value in args)
def split_namespace(ns_class_name: str) -> List[str]:
return ns_class_name.split(":")
def split_namespace(ns_class_name: str, delimiter: str = ":", maxsplit: int = 1) -> List[str]:
"""Split a namespace-prefixed name into its namespace-prefix and name parts.
Example:
>>> split_namespace('prefix:classname')
['prefix', 'classname']
>>> split_namespace('prefix:module:class', delimiter=":", maxsplit=2)
['prefix', 'module', 'class']
"""
return ns_class_name.split(delimiter, maxsplit=maxsplit)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
def auto_namespace(name: str, delimiter: str = ":") -> str:
"""Automatically handle namespace-prefixed names.
If the input name is empty, returns a default namespace prefix and name.
If the input name is not namespace-prefixed, adds a default namespace prefix.
Otherwise, returns the input name unchanged.
Example:
>>> auto_namespace('classname')
'?:classname'
>>> auto_namespace('prefix:classname')
'prefix:classname'
>>> auto_namespace('')
'?:?'
>>> auto_namespace('?:custom')
'?:custom'
"""
if not name:
return f"?{delimiter}?"
v = split_namespace(name, delimiter=delimiter)
if len(v) < 2:
return f"?{delimiter}{name}"
return name
def add_affix(text: str, affix: Literal["brace", "url", "none"] = "brace"):
"""Add affix to encapsulate data.
Example:
>>> add_affix("data", affix="brace")
'{data}'
>>> add_affix("example.com", affix="url")
'%7Bexample.com%7D'
>>> add_affix("text", affix="none")
'text'
"""
mappings = {
"brace": lambda x: "{" + x + "}",
"url": lambda x: quote("{" + x + "}"),
}
encoder = mappings.get(affix, lambda x: x)
return encoder(text)
def remove_affix(text, affix: Literal["brace", "url", "none"] = "brace"):
"""Remove affix to extract encapsulated data.
Args:
text (str): The input text with affix to be removed.
affix (str, optional): The type of affix used. Defaults to "brace".
Supported affix types: "brace" for removing curly braces, "url" for URL decoding within curly braces.
Returns:
str: The text with affix removed.
Example:
>>> remove_affix('{data}', affix="brace")
'data'
>>> remove_affix('%7Bexample.com%7D', affix="url")
'example.com'
>>> remove_affix('text', affix="none")
'text'
"""
mappings = {"brace": lambda x: x[1:-1], "url": lambda x: unquote(x)[1:-1]}
decoder = mappings.get(affix, lambda x: x)
return decoder(text)
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
@ -636,6 +722,54 @@ def list_files(root: str | Path) -> List[Path]:
return files
def parse_json_code_block(markdown_text: str) -> List[str]:
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
return [v.strip() for v in json_blocks]
def remove_white_spaces(v: str) -> str:
return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)
async def aread_bin(filename: str | Path) -> bytes:
"""Read binary file asynchronously.
Args:
filename (Union[str, Path]): The name or path of the file to be read.
Returns:
bytes: The content of the file as bytes.
Example:
>>> content = await aread_bin('example.txt')
b'This is the content of the file.'
>>> content = await aread_bin(Path('example.txt'))
b'This is the content of the file.'
"""
async with aiofiles.open(str(filename), mode="rb") as reader:
content = await reader.read()
return content
async def awrite_bin(filename: str | Path, data: bytes):
"""Write binary file asynchronously.
Args:
filename (Union[str, Path]): The name or path of the file to be written.
data (bytes): The binary data to be written to the file.
Example:
>>> await awrite_bin('output.bin', b'This is binary data.')
>>> await awrite_bin(Path('output.bin'), b'Another set of binary data.')
"""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(pathname), mode="wb") as writer:
await writer.write(data)
def is_coroutine_func(func: Callable) -> bool:
return inspect.iscoroutinefunction(func)

View file

@ -41,6 +41,8 @@ class CostManager(BaseModel):
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
if prompt_tokens + completion_tokens == 0 or not model:
return
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
if model not in self.token_costs:

View file

@ -4,7 +4,9 @@
@Time : 2023/12/19
@Author : mashenquan
@File : di_graph_repository.py
@Desc : Graph repository based on DiGraph
@Desc : Graph repository based on DiGraph.
This script defines a graph repository class based on a directed graph (DiGraph), providing functionalities
specific to handling directed relationships between entities.
"""
from __future__ import annotations
@ -19,20 +21,41 @@ from metagpt.utils.graph_repository import SPO, GraphRepository
class DiGraphRepository(GraphRepository):
"""Graph repository based on DiGraph."""
def __init__(self, name: str, **kwargs):
super().__init__(name=name, **kwargs)
self._repo = networkx.DiGraph()
async def insert(self, subject: str, predicate: str, object_: str):
"""Insert a new triple into the directed graph repository.
Args:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
Example:
await my_di_graph_repo.insert(subject="Node1", predicate="connects_to", object_="Node2")
# Adds a directed relationship: Node1 connects_to Node2
"""
self._repo.add_edge(subject, object_, predicate=predicate)
async def upsert(self, subject: str, predicate: str, object_: str):
pass
async def update(self, subject: str, predicate: str, object_: str):
pass
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
"""Retrieve triples from the directed graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
List[SPO]: A list of SPO objects representing the selected triples.
Example:
selected_triples = await my_di_graph_repo.select(subject="Node1", predicate="connects_to")
# Retrieves directed relationships where Node1 is the subject and the predicate is 'connects_to'.
"""
result = []
for s, o, p in self._repo.edges(data="predicate"):
if subject and subject != s:
@ -44,12 +67,41 @@ class DiGraphRepository(GraphRepository):
result.append(SPO(subject=s, predicate=p, object_=o))
return result
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
"""Delete triples from the directed graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
int: The number of triples deleted from the repository.
Example:
deleted_count = await my_di_graph_repo.delete(subject="Node1", predicate="connects_to")
# Deletes directed relationships where Node1 is the subject and the predicate is 'connects_to'.
"""
rows = await self.select(subject=subject, predicate=predicate, object_=object_)
if not rows:
return 0
for r in rows:
self._repo.remove_edge(r.subject, r.object_)
return len(rows)
def json(self) -> str:
"""Convert the directed graph repository to a JSON-formatted string."""
m = networkx.node_link_data(self._repo)
data = json.dumps(m)
return data
async def save(self, path: str | Path = None):
"""Save the directed graph repository to a JSON file.
Args:
path (Union[str, Path], optional): The directory path where the JSON file will be saved.
If not provided, the default path is taken from the 'root' key in the keyword arguments.
"""
data = self.json()
path = path or self._kwargs.get("root")
if not path.exists():
@ -58,12 +110,21 @@ class DiGraphRepository(GraphRepository):
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
"""Load a directed graph repository from a JSON file."""
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)
@staticmethod
async def load_from(pathname: str | Path) -> GraphRepository:
"""Create and load a directed graph repository from a JSON file.
Args:
pathname (Union[str, Path]): The path to the JSON file to be loaded.
Returns:
GraphRepository: A new instance of the graph repository loaded from the specified JSON file.
"""
pathname = Path(pathname)
name = pathname.with_suffix("").name
root = pathname.parent
@ -74,9 +135,16 @@ class DiGraphRepository(GraphRepository):
@property
def root(self) -> str:
"""Return the root directory path for the graph repository files."""
return self._kwargs.get("root")
@property
def pathname(self) -> Path:
"""Return the path and filename to the graph repository file."""
p = Path(self.root) / self.name
return p.with_suffix(".json")
@property
def repo(self):
"""Get the underlying directed graph repository."""
return self._repo

View file

@ -4,21 +4,28 @@
@Time : 2023/12/19
@Author : mashenquan
@File : graph_repository.py
@Desc : Superclass for graph repository.
@Desc : Superclass for graph repository. This script defines a superclass for a graph repository, providing a
foundation for specific implementations.
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace, split_namespace
class GraphKeyword:
"""Basic words for a Graph database.
This class defines a set of basic words commonly used in the context of a Graph database.
"""
IS = "is"
OF = "Of"
ON = "On"
@ -28,51 +35,149 @@ class GraphKeyword:
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_METHOD = "class_method"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_METHOD = "has_class_method"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_DETAIL = "has_detail"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
HAS_SEQUENCE_VIEW_VER = "has_sequence_view_ver"
HAS_CLASS_USE_CASE = "has_class_use_case"
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
HAS_PARTICIPANT = "has_participant"
class SPO(BaseModel):
"""Graph repository record type.
This class represents a record in a graph repository with three components:
- Subject: The subject of the triple.
- Predicate: The predicate describing the relationship between the subject and the object.
- Object: The object of the triple.
Attributes:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
Example:
spo_record = SPO(subject="Node1", predicate="connects_to", object_="Node2")
# Represents a triple: Node1 connects_to Node2
"""
subject: str
predicate: str
object_: str
class GraphRepository(ABC):
"""Abstract base class for a Graph Repository.
This class defines the interface for a graph repository, providing methods for inserting, selecting,
deleting, and saving graph data. Concrete implementations of this class must provide functionality
for these operations.
"""
def __init__(self, name: str, **kwargs):
self._repo_name = name
self._kwargs = kwargs
@abstractmethod
async def insert(self, subject: str, predicate: str, object_: str):
pass
"""Insert a new triple into the graph repository.
@abstractmethod
async def upsert(self, subject: str, predicate: str, object_: str):
pass
Args:
subject (str): The subject of the triple.
predicate (str): The predicate describing the relationship.
object_ (str): The object of the triple.
@abstractmethod
async def update(self, subject: str, predicate: str, object_: str):
Example:
await my_repository.insert(subject="Node1", predicate="connects_to", object_="Node2")
# Inserts a triple: Node1 connects_to Node2 into the graph repository.
"""
pass
@abstractmethod
async def select(self, subject: str = None, predicate: str = None, object_: str = None) -> List[SPO]:
"""Retrieve triples from the graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
List[SPO]: A list of SPO objects representing the selected triples.
Example:
selected_triples = await my_repository.select(subject="Node1", predicate="connects_to")
# Retrieves triples where Node1 is the subject and the predicate is 'connects_to'.
"""
pass
@abstractmethod
async def delete(self, subject: str = None, predicate: str = None, object_: str = None) -> int:
"""Delete triples from the graph repository based on specified criteria.
Args:
subject (str, optional): The subject of the triple to filter by.
predicate (str, optional): The predicate describing the relationship to filter by.
object_ (str, optional): The object of the triple to filter by.
Returns:
int: The number of triples deleted from the repository.
Example:
deleted_count = await my_repository.delete(subject="Node1", predicate="connects_to")
# Deletes triples where Node1 is the subject and the predicate is 'connects_to'.
"""
pass
@abstractmethod
async def save(self):
"""Save any changes made to the graph repository.
Example:
await my_repository.save()
# Persists any changes made to the graph repository.
"""
pass
@property
def name(self) -> str:
"""Get the name of the graph repository."""
return self._repo_name
@staticmethod
async def update_graph_db_with_file_info(graph_db: "GraphRepository", file_info: RepoFileInfo):
"""Insert information of RepoFileInfo into the specified graph repository.
This function updates the provided graph repository with information from the given RepoFileInfo object.
The function inserts triples related to various dimensions such as file type, class, class method, function,
global variable, and page info.
Triple Patterns:
- (?, is, [file type])
- (?, has class, ?)
- (?, is, [class])
- (?, has class method, ?)
- (?, has function, ?)
- (?, is, [function])
- (?, is, global variable)
- (?, has page info, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
file_info (RepoFileInfo): The RepoFileInfo object containing information to be inserted.
Example:
await update_graph_db_with_file_info(my_graph_repo, my_file_info)
# Updates 'my_graph_repo' with information from 'my_file_info'.
"""
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(file_info.file).suffix, GraphKeyword.NULL)
@ -95,13 +200,13 @@ class GraphRepository(ABC):
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
for f in file_info.functions:
# file -> function
@ -133,7 +238,34 @@ class GraphRepository(ABC):
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
"""Insert dot format class information into the specified graph repository.
This function updates the provided graph repository with class information from the given list of DotClassInfo objects.
The function inserts triples related to various aspects of class views, including source code, file type, class,
class property, class detail, method, composition, and aggregation.
Triple Patterns:
- (?, is, source code)
- (?, is, file type)
- (?, has class, ?)
- (?, is, class)
- (?, has class property, ?)
- (?, is, class property)
- (?, has detail, ?)
- (?, has method, ?)
- (?, is composite of, ?)
- (?, is aggregate of, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
class_views (List[DotClassInfo]): List of DotClassInfo objects containing class information to be inserted.
Example:
await update_graph_db_with_class_views(my_graph_repo, [class_info1, class_info2])
# Updates 'my_graph_repo' with class information from the provided list of DotClassInfo objects.
"""
for c in class_views:
filename, _ = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
@ -146,6 +278,7 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
@ -160,33 +293,61 @@ class GraphRepository(ABC):
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.HAS_DETAIL,
object_=vt.model_dump_json(),
)
for fn, desc in c.methods.items():
if "</I>" in desc and "<I>" not in desc:
logger.error(desc)
for fn, ft in c.methods.items():
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
predicate=GraphKeyword.HAS_DETAIL,
object_=ft.model_dump_json(),
)
for i in c.compositions:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
)
for i in c.aggregations:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
):
"""Insert class relationships and labels into the specified graph repository.
This function updates the provided graph repository with class relationship information from the given list
of DotClassRelationship objects. The function inserts triples representing relationships and labels between
classes.
Triple Patterns:
- (?, is relationship of, ?)
- (?, is relationship on, ?)
Args:
graph_db (GraphRepository): The graph repository object to be updated.
relationship_views (List[DotClassRelationship]): List of DotClassRelationship objects containing
class relationship information to be inserted.
Example:
await update_graph_db_with_class_relationship_views(my_graph_repo, [relationship1, relationship2])
# Updates 'my_graph_repo' with class relationship information from the provided list of DotClassRelationship objects.
"""
for r in relationship_views:
await graph_db.insert(
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
@ -198,3 +359,32 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)
@staticmethod
async def rebuild_composition_relationship(graph_db: "GraphRepository"):
"""Append namespace-prefixed information to relationship SPO (Subject-Predicate-Object) objects in the graph
repository.
This function updates the provided graph repository by appending namespace-prefixed information to existing
relationship SPO objects.
Args:
graph_db (GraphRepository): The graph repository object to be updated.
"""
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mapping = defaultdict(list)
for c in classes:
name = split_namespace(c.subject)[-1]
mapping[name].append(c.subject)
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
for r in rows:
ns, class_ = split_namespace(r.object_)
if ns != "?":
continue
val = mapping[class_]
if len(val) != 1:
continue
ns_name = val[0]
await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)

View file

@ -33,6 +33,7 @@ from metagpt.const import (
TASK_PDF_FILE_REPO,
TEST_CODES_FILE_REPO,
TEST_OUTPUTS_FILE_REPO,
VISUAL_GRAPH_REPO_FILE_REPO,
)
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -69,6 +70,7 @@ class ResourceFileRepositories(FileRepository):
code_summary: FileRepository
sd_output: FileRepository
code_plan_and_change: FileRepository
graph_repo: FileRepository
def __init__(self, git_repo):
super().__init__(git_repo=git_repo, relative_path=RESOURCES_FILE_REPO)
@ -82,6 +84,7 @@ class ResourceFileRepositories(FileRepository):
self.code_summary = git_repo.new_file_repository(relative_path=CODE_SUMMARIES_PDF_FILE_REPO)
self.sd_output = git_repo.new_file_repository(relative_path=SD_OUTPUT_FILE_REPO)
self.code_plan_and_change = git_repo.new_file_repository(relative_path=CODE_PLAN_AND_CHANGE_PDF_FILE_REPO)
self.graph_repo = git_repo.new_file_repository(relative_path=VISUAL_GRAPH_REPO_FILE_REPO)
class ProjectRepo(FileRepository):

View file

@ -140,7 +140,6 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,

View file

@ -0,0 +1,162 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : visualize_graph.py
@Desc : Visualization tool to visualize the class diagrams or sequence diagrams of the graph repository.
"""
from __future__ import annotations
import re
from abc import ABC
from pathlib import Path
from typing import List, Optional
from pydantic import BaseModel, Field
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.schema import UMLClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class _VisualClassView(BaseModel):
"""Protected class used by VisualGraphRepo internally.
Attributes:
package (str): The package associated with the class.
uml (Optional[UMLClassView]): Optional UMLClassView associated with the class.
generalizations (List[str]): List of generalizations for the class.
compositions (List[str]): List of compositions for the class.
aggregations (List[str]): List of aggregations for the class.
"""
package: str
uml: Optional[UMLClassView] = None
generalizations: List[str] = Field(default_factory=list)
compositions: List[str] = Field(default_factory=list)
aggregations: List[str] = Field(default_factory=list)
def get_mermaid(self, align: int = 1) -> str:
"""Creates a Markdown Mermaid class diagram text.
Args:
align (int): Indent count used for alignment.
Returns:
str: The Markdown text representing the Mermaid class diagram.
"""
if not self.uml:
return ""
prefix = "\t" * align
mermaid_txt = self.uml.get_mermaid(align=align)
for i in self.generalizations:
mermaid_txt += f"{prefix}{i} <|-- {self.name}\n"
for i in self.compositions:
mermaid_txt += f"{prefix}{i} *-- {self.name}\n"
for i in self.aggregations:
mermaid_txt += f"{prefix}{i} o-- {self.name}\n"
return mermaid_txt
@property
def name(self) -> str:
"""Returns the class name without the namespace prefix."""
return split_namespace(self.package)[-1]
class VisualGraphRepo(ABC):
"""Abstract base class for VisualGraphRepo."""
graph_db: GraphRepository
def __init__(self, graph_db):
self.graph_db = graph_db
class VisualDiGraphRepo(VisualGraphRepo):
"""Implementation of VisualGraphRepo for DiGraph graph repository.
This class extends VisualGraphRepo to provide specific functionality for a graph repository using DiGraph.
"""
@classmethod
async def load_from(cls, filename: str | Path):
"""Load a VisualDiGraphRepo instance from a file."""
graph_db = await DiGraphRepository.load_from(str(filename))
return cls(graph_db=graph_db)
async def get_mermaid_class_view(self) -> str:
"""
Returns a Markdown Mermaid class diagram code block object.
"""
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mermaid_txt = "classDiagram\n"
for r in rows:
v = await self._get_class_view(ns_class_name=r.subject)
mermaid_txt += v.get_mermaid()
return mermaid_txt
async def _get_class_view(self, ns_class_name: str) -> _VisualClassView:
"""Returns the Markdown Mermaid class diagram code block object for the specified class."""
rows = await self.graph_db.select(subject=ns_class_name)
class_view = _VisualClassView(package=ns_class_name)
for r in rows:
if r.predicate == GraphKeyword.HAS_CLASS_VIEW:
class_view.uml = UMLClassView.model_validate_json(r.object_)
elif r.predicate == GraphKeyword.IS + GENERALIZATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.generalizations.append(name)
elif r.predicate == GraphKeyword.IS + COMPOSITION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.compositions.append(name)
elif r.predicate == GraphKeyword.IS + AGGREGATION + GraphKeyword.OF:
name = split_namespace(r.object_)[-1]
name = self._refine_name(name)
if name:
class_view.aggregations.append(name)
return class_view
async def get_mermaid_sequence_views(self) -> List[(str, str)]:
"""Returns all Markdown sequence diagrams with their corresponding graph repository keys."""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
sequence_views.append((r.subject, r.object_))
return sequence_views
@staticmethod
def _refine_name(name: str) -> str:
"""Removes impurity content from the given name.
Example:
>>> _refine_name("int")
""
>>> _refine_name('"Class1"')
'Class1'
>>> _refine_name("pkg.Class1")
"Class1"
"""
name = re.sub(r'^[\'"\\\(\)]+|[\'"\\\(\)]+$', "", name)
if name in ["int", "float", "bool", "str", "list", "tuple", "set", "dict", "None"]:
return ""
if "." in name:
name = name.split(".")[-1]
return name
async def get_mermaid_sequence_view_versions(self) -> List[(str, str)]:
"""Returns all versioned Markdown sequence diagrams with their corresponding graph repository keys."""
sequence_views = []
rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER)
for r in rows:
sequence_views.append((r.subject, r.object_))
return sequence_views

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -24,6 +24,8 @@ async def test_rebuild(context):
context=context,
)
await action.run()
rows = await action.graph_db.select()
assert rows
assert context.repo.docs.graph_repo.changed_files
@ -46,6 +48,12 @@ def test_align_path(path, direction, diff, want):
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT/metagpt", "=", "."),
("/Users/x/github/MetaGPT", "/Users/x/github/MetaGPT/metagpt", "-", "metagpt"),
("/Users/x/github/MetaGPT/metagpt", "/Users/x/github/MetaGPT", "+", "metagpt"),
(
"/Users/x/github/MetaGPT-env/lib/python3.9/site-packages/moviepy",
"/Users/x/github/MetaGPT-env/lib/python3.9/site-packages/",
"+",
"moviepy",
),
],
)
def test_diff_path(path_root, package_root, want_direction, want_diff):

View file

@ -4,6 +4,7 @@
@Time : 2024/1/4
@Author : mashenquan
@File : test_rebuild_sequence_view.py
@Desc : Unit tests for reconstructing the sequence diagram from a source code project.
"""
from pathlib import Path
@ -14,21 +15,34 @@ from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
from metagpt.utils.common import aread
from metagpt.utils.git_repository import ChangeType
from metagpt.utils.graph_repository import SPO
@pytest.mark.asyncio
@pytest.mark.skip
async def test_rebuild(context):
async def test_rebuild(context, mocker):
# Mock
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.class_view.json")
graph_db_filename = Path(context.repo.workdir.name).with_suffix(".json")
await context.repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
context.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
context.git_repo.commit("commit1")
# mock_spo = SPO(
# subject="metagpt/startup.py:__name__:__main__",
# predicate="has_page_info",
# object_='{"lineno":78,"end_lineno":79,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
# )
mock_spo = SPO(
subject="metagpt/management/skill_manager.py:__name__:__main__",
predicate="has_page_info",
object_='{"lineno":113,"end_lineno":116,"type_name":"ast.If","tokens":["__name__","__main__"],"properties":{}}',
)
mocker.patch.object(RebuildSequenceView, "_search_main_entry", return_value=[mock_spo])
action = RebuildSequenceView(
name="RedBean",
i_context=str(Path(__file__).parent.parent.parent.parent / "metagpt"),
i_context=str(
Path(__file__).parent.parent.parent.parent / "metagpt/management/skill_manager.py:__name__:__main__"
),
llm=LLM(),
context=context,
)

View file

@ -12,6 +12,7 @@ import pytest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
from tests.mock.mock_llm import MockLLM
DESIGN_CONTENT = """
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
@ -173,10 +174,87 @@ class Snake:
"""
mock_rsp = """
```mermaid
classDiagram
class Game{
+int score
+int level
+Snake snake
+Food food
+start_game() void
+initialize_game() void
+game_loop() void
+update() void
+draw() void
+handle_events() void
+check_collision() void
+increase_score() void
+increase_level() void
+game_over() void
Game()
}
class Snake{
+list body
+tuple direction
+move() void
+change_direction(direction: str) void
+grow() void
+get_head() tuple
+get_body() list
Snake()
}
class Food{
+tuple position
+generate() void
+get_position() tuple
Food()
}
Game "1" -- "1" Snake: has
Game "1" -- "1" Food: has
```
```sequenceDiagram
participant M as Main
participant G as Game
participant S as Snake
participant F as Food
M->>G: start_game()
G->>G: initialize_game()
G->>G: game_loop()
G->>S: move()
G->>S: change_direction()
G->>S: grow()
G->>F: generate()
S->>S: move()
S->>S: change_direction()
S->>S: grow()
F->>F: generate()
```
## Summary
The code consists of the main game logic, including the Game, Snake, and Food classes. The game loop is responsible for updating and drawing the game elements, handling events, checking collisions, and managing the game state. The Snake class handles the movement, growth, and direction changes of the snake, while the Food class is responsible for generating and tracking the position of food items.
## TODOs
- Modify 'game.py' to add the implementation of obstacle handling and interaction with the game loop.
- Implement 'obstacle.py' to include the methods for spawning, moving, and disappearing of obstacles, as well as collision detection with the snake.
- Update 'main.py' to initialize the obstacle and incorporate it into the game loop.
- Update the mermaid call flow diagram to include the interaction with the obstacle.
```python
{
"files_to_modify": {
"game.py": "Add obstacle handling and interaction with the game loop",
"obstacle.py": "Implement obstacle class with necessary methods",
"main.py": "Initialize the obstacle and incorporate it into the game loop"
}
}
```
"""
@pytest.mark.skip
@pytest.mark.asyncio
async def test_summarize_code(context, git_dir):
async def test_summarize_code(context, mocker):
context.src_workspace = context.git_repo.workdir / "src"
await context.repo.docs.system_design.save(filename="1.json", content=DESIGN_CONTENT)
await context.repo.docs.task.save(filename="1.json", content=TASK_CONTENT)
@ -185,6 +263,7 @@ async def test_summarize_code(context, git_dir):
await context.repo.srcs.save(filename="game.py", content=GAME_PY)
await context.repo.srcs.save(filename="main.py", content=MAIN_PY)
await context.repo.srcs.save(filename="snake.py", content=SNAKE_PY)
mocker.patch.object(MockLLM, "_mock_rsp", return_value=mock_rsp)
all_files = context.repo.srcs.all_files
summarization_context = CodeSummarizeContext(

View file

@ -30,6 +30,17 @@ async def test_run(mocker, context):
language: str
agent_description: str
cause_by: str
agent_skills: list
agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
]
inputs = [
{
@ -48,6 +59,7 @@ async def test_run(mocker, context):
"language": "English",
"agent_description": "chatterbox",
"cause_by": any_to_str(TalkAction),
"agent_skills": [],
},
{
"memory": {
@ -65,24 +77,16 @@ async def test_run(mocker, context):
"language": "English",
"agent_description": "painter",
"cause_by": any_to_str(SkillAction),
"agent_skills": agent_skills,
},
]
agent_skills = [
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
{"id": 3, "name": "data_analysis", "type": "builtin", "config": {}, "enabled": True},
{"id": 5, "name": "crawler", "type": "builtin", "config": {"engine": "ddg"}, "enabled": True},
{"id": 6, "name": "knowledge", "type": "builtin", "config": {}, "enabled": True},
{"id": 6, "name": "web_search", "type": "builtin", "config": {}, "enabled": True},
]
for i in inputs:
seed = Input(**i)
role = Assistant(language="Chinese", context=context)
role.context.kwargs.language = seed.language
role.context.kwargs.agent_description = seed.agent_description
role.context.kwargs.agent_skills = agent_skills
role.context.kwargs.agent_skills = seed.agent_skills
role.memory = seed.memory # Restore historical conversation content.
while True:

View file

@ -1,9 +1,11 @@
from pathlib import Path
from pprint import pformat
import pytest
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.repo_parser import DotClassAttribute, DotClassMethod, DotReturn, RepoParser
def test_repo_parser():
@ -23,3 +25,140 @@ def test_error():
"""_parse_file should return empty list when file not existed"""
rsp = RepoParser._parse_file(Path("test_not_existed_file.py"))
assert rsp == []
@pytest.mark.parametrize(
("v", "name", "type_", "default_", "compositions"),
[
("children : dict[str, 'ActionNode']", "children", "dict[str,ActionNode]", "", ["ActionNode"]),
("context : str", "context", "str", "", []),
("example", "example", "", "", []),
("expected_type : Type", "expected_type", "Type", "", ["Type"]),
("args : Optional[Dict]", "args", "Optional[Dict]", "", []),
("rsp : Optional[Message] = Message.Default", "rsp", "Optional[Message]", "Message.Default", ["Message"]),
(
"browser : Literal['chrome', 'firefox', 'edge', 'ie']",
"browser",
"Literal['chrome','firefox','edge','ie']",
"",
[],
),
(
"browser : Dict[ Message, Literal['chrome', 'firefox', 'edge', 'ie'] ]",
"browser",
"Dict[Message,Literal['chrome','firefox','edge','ie']]",
"",
["Message"],
),
("attributes : List[ClassAttribute]", "attributes", "List[ClassAttribute]", "", ["ClassAttribute"]),
("attributes = []", "attributes", "", "[]", []),
(
"request_timeout: Optional[Union[float, Tuple[float, float]]]",
"request_timeout",
"Optional[Union[float,Tuple[float,float]]]",
"",
[],
),
],
)
def test_parse_member(v, name, type_, default_, compositions):
attr = DotClassAttribute.parse(v)
assert name == attr.name
assert type_ == attr.type_
assert default_ == attr.default_
assert compositions == attr.compositions
assert v == attr.description
json_data = attr.model_dump_json()
v = DotClassAttribute.model_validate_json(json_data)
assert v == attr
@pytest.mark.parametrize(
("line", "package_name", "info"),
[
(
'"metagpt.roles.architect.Architect" [color="black", fontcolor="black", label=<{Architect|constraints : str<br ALIGN="LEFT"/>goal : str<br ALIGN="LEFT"/>name : str<br ALIGN="LEFT"/>profile : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
"metagpt.roles.architect.Architect",
"Architect|constraints : str\ngoal : str\nname : str\nprofile : str\n|",
),
(
'"metagpt.actions.skill_action.ArgumentsParingAction" [color="black", fontcolor="black", label=<{ArgumentsParingAction|args : Optional[Dict]<br ALIGN="LEFT"/>ask : str<br ALIGN="LEFT"/>prompt<br ALIGN="LEFT"/>rsp : Optional[Message]<br ALIGN="LEFT"/>skill<br ALIGN="LEFT"/>|parse_arguments(skill_name, txt): dict<br ALIGN="LEFT"/>run(with_message): Message<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.actions.skill_action.ArgumentsParingAction",
"ArgumentsParingAction|args : Optional[Dict]\nask : str\nprompt\nrsp : Optional[Message]\nskill\n|parse_arguments(skill_name, txt): dict\nrun(with_message): Message\n",
),
(
'"metagpt.strategy.base.BaseEvaluator" [color="black", fontcolor="black", label=<{BaseEvaluator|<br ALIGN="LEFT"/>|<I>status_verify</I>()<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.strategy.base.BaseEvaluator",
"BaseEvaluator|\n|<I>status_verify</I>()\n",
),
(
'"metagpt.configs.browser_config.BrowserConfig" [color="black", fontcolor="black", label=<{BrowserConfig|browser : Literal[\'chrome\', \'firefox\', \'edge\', \'ie\']<br ALIGN="LEFT"/>driver : Literal[\'chromium\', \'firefox\', \'webkit\']<br ALIGN="LEFT"/>engine<br ALIGN="LEFT"/>path : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
"metagpt.configs.browser_config.BrowserConfig",
"BrowserConfig|browser : Literal['chrome', 'firefox', 'edge', 'ie']\ndriver : Literal['chromium', 'firefox', 'webkit']\nengine\npath : str\n|",
),
(
'"metagpt.tools.search_engine_serpapi.SerpAPIWrapper" [color="black", fontcolor="black", label=<{SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]<br ALIGN="LEFT"/>model_config<br ALIGN="LEFT"/>params : dict<br ALIGN="LEFT"/>search_engine : Optional[Any]<br ALIGN="LEFT"/>serpapi_api_key : Optional[str]<br ALIGN="LEFT"/>|check_serpapi_api_key(val: str)<br ALIGN="LEFT"/>get_params(query: str): Dict[str, str]<br ALIGN="LEFT"/>results(query: str, max_results: int): dict<br ALIGN="LEFT"/>run(query, max_results: int, as_string: bool): str<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.tools.search_engine_serpapi.SerpAPIWrapper",
"SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]\nmodel_config\nparams : dict\nsearch_engine : Optional[Any]\nserpapi_api_key : Optional[str]\n|check_serpapi_api_key(val: str)\nget_params(query: str): Dict[str, str]\nresults(query: str, max_results: int): dict\nrun(query, max_results: int, as_string: bool): str\n",
),
],
)
def test_split_class_line(line, package_name, info):
p, i = RepoParser._split_class_line(line)
assert p == package_name
assert i == info
@pytest.mark.parametrize(
("v", "name", "args", "return_args"),
[
(
"<I>arequest</I>(method, url, params, headers, files, stream: Literal[True], request_id: Optional[str], request_timeout: Optional[Union[float, Tuple[float, float]]]): Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
"arequest",
[
DotClassAttribute(name="method", description="method"),
DotClassAttribute(name="url", description="url"),
DotClassAttribute(name="params", description="params"),
DotClassAttribute(name="headers", description="headers"),
DotClassAttribute(name="files", description="files"),
DotClassAttribute(name="stream", type_="Literal[True]", description="stream: Literal[True]"),
DotClassAttribute(name="request_id", type_="Optional[str]", description="request_id: Optional[str]"),
DotClassAttribute(
name="request_timeout",
type_="Optional[Union[float,Tuple[float,float]]]",
description="request_timeout: Optional[Union[float, Tuple[float, float]]]",
),
],
DotReturn(
type_="Tuple[AsyncGenerator[OpenAIResponse,None],bool,str]",
compositions=["AsyncGenerator", "OpenAIResponse"],
description="Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
),
),
(
"<I>update</I>(subject: str, predicate: str, object_: str)",
"update",
[
DotClassAttribute(name="subject", type_="str", description="subject: str"),
DotClassAttribute(name="predicate", type_="str", description="predicate: str"),
DotClassAttribute(name="object_", type_="str", description="object_: str"),
],
DotReturn(description=""),
),
],
)
def test_parse_method(v, name, args, return_args):
method = DotClassMethod.parse(v)
assert method.name == name
assert method.args == args
assert method.return_args == return_args
assert method.description == v
json_data = method.model_dump_json()
v = DotClassMethod.model_validate_json(json_data)
assert v == method
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -18,9 +18,6 @@ from metagpt.actions.write_code import WriteCode
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
ClassAttribute,
ClassMethod,
ClassView,
CodeSummarizeContext,
Document,
Message,
@ -28,6 +25,9 @@ from metagpt.schema import (
Plan,
SystemMessage,
Task,
UMLClassAttribute,
UMLClassMethod,
UMLClassView,
UserMessage,
)
from metagpt.utils.common import any_to_str
@ -159,27 +159,26 @@ def test_CodeSummarizeContext(file_list, want):
def test_class_view():
attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
assert attr_a.get_mermaid(align=1) == "\t+int a=0*"
attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
assert attr_b.get_mermaid(align=0) == '#str b="0"$'
class_view = ClassView(name="A")
attr_a = UMLClassAttribute(name="a", value_type="int", default_value="0", visibility="+")
assert attr_a.get_mermaid(align=1) == "\t+int a=0"
attr_b = UMLClassAttribute(name="b", value_type="str", default_value="0", visibility="#")
assert attr_b.get_mermaid(align=0) == '#str b="0"'
class_view = UMLClassView(name="A")
class_view.attributes = [attr_a, attr_b]
method_a = ClassMethod(name="run", visibility="+", abstraction=True)
assert method_a.get_mermaid(align=1) == "\t+run()*"
method_b = ClassMethod(
method_a = UMLClassMethod(name="run", visibility="+")
assert method_a.get_mermaid(align=1) == "\t+run()"
method_b = UMLClassMethod(
name="_test",
visibility="#",
static=True,
args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")],
args=[UMLClassAttribute(name="a", value_type="str"), UMLClassAttribute(name="b", value_type="int")],
return_type="str",
)
assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$"
assert method_b.get_mermaid(align=0) == "#_test(str a,int b) str"
class_view.methods = [method_a, method_b]
assert (
class_view.get_mermaid(align=0)
== 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
== 'class A{\n\t+int a=0\n\t#str b="0"\n\t+run()\n\t#_test(str a,int b) str\n}\n'
)

View file

@ -178,7 +178,7 @@ class TestGetProjectRoot:
],
)
def test_split_namespace(self, val, want):
res = split_namespace(val)
res = split_namespace(val, maxsplit=-1)
assert res == want
def test_read_json_file(self):

View file

@ -13,6 +13,7 @@ import aiofiles
import pytest
from metagpt.config2 import Config
from metagpt.configs.s3_config import S3Config
from metagpt.utils.common import aread
from metagpt.utils.s3 import S3
@ -30,6 +31,14 @@ async def test_s3(mocker):
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mocker.patch.object(aioboto3.Session, "client", return_value=mock_client)
mock_config = mocker.Mock()
mock_config.s3 = S3Config(
access_key="mock_access_key",
secret_key="mock_secret_key",
endpoint="http://mock.endpoint",
bucket="mock_bucket",
)
mocker.patch.object(Config, "default", return_value=mock_config)
# Prerequisites
s3 = Config.default().s3

View file

@ -0,0 +1,45 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : test_visual_graph_repo.py
@Desc : Unit tests for testing and demonstrating the usage of VisualDiGraphRepo.
"""
import re
from pathlib import Path
import pytest
from metagpt.utils.common import remove_affix, split_namespace
from metagpt.utils.visual_graph_repo import VisualDiGraphRepo
@pytest.mark.asyncio
async def test_visual_di_graph_repo(context, mocker):
filename = Path(__file__).parent / "../../data/graph_db/networkx.sequence_view.json"
repo = await VisualDiGraphRepo.load_from(filename=filename)
class_view = await repo.get_mermaid_class_view()
assert class_view
await context.repo.resources.graph_repo.save(filename="class_view.md", content=f"```mermaid\n{class_view}\n```\n")
sequence_views = await repo.get_mermaid_sequence_views()
assert sequence_views
for ns, sqv in sequence_views:
filename = re.sub(r"[:/\\\.]+", "_", ns) + ".sequence_view.md"
sqv = sqv.strip(" `")
await context.repo.resources.graph_repo.save(filename=filename, content=f"```mermaid\n{sqv}\n```\n")
sequence_view_vers = await repo.get_mermaid_sequence_view_versions()
assert sequence_view_vers
for ns, sqv in sequence_view_vers:
ver, sqv = split_namespace(sqv)
filename = re.sub(r"[:/\\\.]+", "_", ns) + f".{ver}.sequence_view_ver.md"
sqv = remove_affix(sqv).strip(" `")
await context.repo.resources.graph_repo.save(filename=filename, content=f"```mermaid\n{sqv}\n```\n")
if __name__ == "__main__":
pytest.main([__file__, "-s"])