mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
fixbug: class view
fixbug: method type, feat: add compotition&aggregation
This commit is contained in:
parent
9e84e63529
commit
633c772529
7 changed files with 510 additions and 161 deletions
|
|
@ -6,7 +6,7 @@
|
|||
@File : rebuild_class_view.py
|
||||
@Desc : Rebuild class view info
|
||||
"""
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -22,9 +22,9 @@ from metagpt.const import (
|
|||
GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
|
||||
from metagpt.utils.common import split_namespace
|
||||
from metagpt.repo_parser import DotClassInfo, RepoParser
|
||||
from metagpt.schema import UMLClassAttribute, UMLClassMethod, UMLClassView
|
||||
from metagpt.utils.common import concat_namespace, split_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
|
@ -40,6 +40,7 @@ class RebuildClassView(Action):
|
|||
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
|
||||
await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
|
||||
await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
|
||||
await GraphRepository.rebuild_composition_relationship(self.graph_db)
|
||||
# use ast
|
||||
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
|
||||
symbols = repo_parser.generate_symbols()
|
||||
|
|
@ -81,24 +82,40 @@ class RebuildClassView(Action):
|
|||
# Ignore sub-class
|
||||
return ""
|
||||
|
||||
class_view = ClassView(name=fields[1])
|
||||
rows = await self.graph_db.select(subject=ns_class_name)
|
||||
for r in rows:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
|
||||
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
|
||||
var_type = await self._parse_variable_type(r.object_)
|
||||
attribute = ClassAttribute(
|
||||
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
|
||||
)
|
||||
class_view.attributes.append(attribute)
|
||||
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
|
||||
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
|
||||
await self._parse_function_args(method, r.object_)
|
||||
class_view.methods.append(method)
|
||||
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
|
||||
if not rows:
|
||||
return ""
|
||||
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
|
||||
visibility = UMLClassView.name_to_visibility(dot_class_info.name)
|
||||
class_view = UMLClassView(name=dot_class_info.name, visibility=visibility)
|
||||
for i in dot_class_info.attributes.values():
|
||||
visibility = UMLClassAttribute.name_to_visibility(i.name)
|
||||
attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
|
||||
class_view.attributes.append(attr)
|
||||
for i in dot_class_info.methods.values():
|
||||
visibility = UMLClassMethod.name_to_visibility(i.name)
|
||||
method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
|
||||
for j in i.args:
|
||||
arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
|
||||
method.args.append(arg)
|
||||
|
||||
# update graph db
|
||||
# update uml view
|
||||
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
# update uml isCompositeOf
|
||||
for c in dot_class_info.compositions:
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name,
|
||||
predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF,
|
||||
object_=concat_namespace("?", c),
|
||||
)
|
||||
|
||||
# update uml isAggregateOf
|
||||
for a in dot_class_info.aggregations:
|
||||
await self.graph_db.insert(
|
||||
subject=ns_class_name,
|
||||
predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF,
|
||||
object_=concat_namespace("?", a),
|
||||
)
|
||||
|
||||
content = class_view.get_mermaid(align=1)
|
||||
logger.debug(content)
|
||||
|
|
@ -132,78 +149,6 @@ class RebuildClassView(Action):
|
|||
|
||||
return content, distinct
|
||||
|
||||
@staticmethod
|
||||
def _parse_name(name: str, language="python"):
|
||||
pattern = re.compile(r"<I>(.*?)</I>")
|
||||
result = re.search(pattern, name)
|
||||
|
||||
abstraction = ""
|
||||
if result:
|
||||
name = result.group(1)
|
||||
abstraction = "*"
|
||||
if name.startswith("__"):
|
||||
visibility = "-"
|
||||
elif name.startswith("_"):
|
||||
visibility = "#"
|
||||
else:
|
||||
visibility = "+"
|
||||
return name, visibility, abstraction
|
||||
|
||||
async def _parse_variable_type(self, ns_name) -> str:
|
||||
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
|
||||
if not rows:
|
||||
return ""
|
||||
vals = rows[0].object_.replace("'", "").split(":")
|
||||
if len(vals) == 1:
|
||||
return ""
|
||||
val = vals[-1].strip()
|
||||
return "" if val == "NoneType" else val + " "
|
||||
|
||||
async def _parse_function_args(self, method: ClassMethod, ns_name: str):
|
||||
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
|
||||
if not rows:
|
||||
return
|
||||
info = rows[0].object_.replace("'", "")
|
||||
|
||||
fs_tag = "("
|
||||
ix = info.find(fs_tag)
|
||||
fe_tag = "):"
|
||||
eix = info.rfind(fe_tag)
|
||||
if eix < 0:
|
||||
fe_tag = ")"
|
||||
eix = info.rfind(fe_tag)
|
||||
args_info = info[ix + len(fs_tag) : eix].strip()
|
||||
method.return_type = info[eix + len(fe_tag) :].strip()
|
||||
if method.return_type == "None":
|
||||
method.return_type = ""
|
||||
if "(" in method.return_type:
|
||||
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
|
||||
|
||||
# parse args
|
||||
if not args_info:
|
||||
return
|
||||
splitter_ixs = []
|
||||
cost = 0
|
||||
for i in range(len(args_info)):
|
||||
if args_info[i] == "[":
|
||||
cost += 1
|
||||
elif args_info[i] == "]":
|
||||
cost -= 1
|
||||
if args_info[i] == "," and cost == 0:
|
||||
splitter_ixs.append(i)
|
||||
splitter_ixs.append(len(args_info))
|
||||
args = []
|
||||
ix = 0
|
||||
for eix in splitter_ixs:
|
||||
args.append(args_info[ix:eix])
|
||||
ix = eix + 1
|
||||
for arg in args:
|
||||
parts = arg.strip().split(":")
|
||||
if len(parts) == 1:
|
||||
method.args.append(ClassAttribute(name=parts[0].strip()))
|
||||
continue
|
||||
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
|
||||
|
||||
@staticmethod
|
||||
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
|
||||
if len(str(path_root)) > len(str(package_root)):
|
||||
|
|
|
|||
|
|
@ -11,14 +11,14 @@ from __future__ import annotations
|
|||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION, GRAPH_REPO_FILE_REPO
|
||||
from metagpt.const import AGGREGATION, GENERALIZATION, GRAPH_REPO_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import ClassView
|
||||
from metagpt.utils.common import aread, general_after_log, list_files, split_namespace
|
||||
|
|
@ -62,13 +62,21 @@ class RebuildSequenceView(Action):
|
|||
merged_class_views = set()
|
||||
while True:
|
||||
participants = RebuildSequenceView.parse_participant(sequence_view)
|
||||
class_views = await self._get_class_views(participants)
|
||||
class_views, class_compositions = await self._get_class_views(participants)
|
||||
for compositions in class_compositions.values():
|
||||
for c in compositions:
|
||||
ns, _ = split_namespace(c.object_)
|
||||
if ns == "?":
|
||||
continue
|
||||
await self._parse_class_description(c.object_)
|
||||
diff = set()
|
||||
for cv in class_views:
|
||||
if cv.subject in merged_class_views:
|
||||
continue
|
||||
await self._parse_class_description(cv.subject)
|
||||
sequence_view = await self._merge_sequence_view(sequence_view, cv.subject)
|
||||
sequence_view = await self._merge_sequence_view(
|
||||
sequence_view, cv.subject, class_compositions.get(cv.subject, [])
|
||||
)
|
||||
diff.add(cv.subject)
|
||||
merged_class_views.add(cv.subject)
|
||||
|
||||
|
|
@ -162,7 +170,7 @@ class RebuildSequenceView(Action):
|
|||
stop=stop_after_attempt(6),
|
||||
after=general_after_log(logger),
|
||||
)
|
||||
async def _merge_sequence_view(self, sequence_view, ns_class_name) -> str:
|
||||
async def _merge_sequence_view(self, sequence_view, ns_class_name, compositions) -> str:
|
||||
class_view_part = "```mermaid\n"
|
||||
# add class
|
||||
class_view_part += await self._class_view_to_mermaid(ns_class_name)
|
||||
|
|
@ -177,14 +185,10 @@ class RebuildSequenceView(Action):
|
|||
class_view_part += f"\n\t{me} *-- {name}"
|
||||
class_view_part += await self._class_view_to_mermaid(ns_aggr_name)
|
||||
# add composition relationship
|
||||
rows = await self.graph_db.select(
|
||||
subject=ns_class_name, predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF
|
||||
)
|
||||
compositions = [r.object_ for r in rows]
|
||||
for ns_comp_name in compositions:
|
||||
_, name = split_namespace(ns_comp_name)
|
||||
for c in compositions:
|
||||
_, name = split_namespace(c.object_)
|
||||
class_view_part += f"\n\t{me} *-- {name}"
|
||||
class_view_part += await self._class_view_to_mermaid(ns_comp_name)
|
||||
class_view_part += await self._class_view_to_mermaid(c.object_)
|
||||
|
||||
class_view_part += "\n```"
|
||||
|
||||
|
|
@ -198,7 +202,7 @@ class RebuildSequenceView(Action):
|
|||
system_msgs=[
|
||||
"You are a tool to merge Mermaid class view information into the Mermaid sequence view.",
|
||||
'Append as much information as possible from the "Mermaid Class View" to the sequence diagram.',
|
||||
'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining what information have been merged and why.',
|
||||
'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining in detail what information have been merged and why.',
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -225,7 +229,7 @@ class RebuildSequenceView(Action):
|
|||
matches = re.findall(pattern, mermaid_sequence_diagram)
|
||||
return matches
|
||||
|
||||
async def _get_class_views(self, class_names) -> List[SPO]:
|
||||
async def _get_class_views(self, class_names) -> (List[SPO], Dict[str, List[SPO]]):
|
||||
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
|
||||
ns_class_names = {}
|
||||
|
|
@ -235,12 +239,15 @@ class RebuildSequenceView(Action):
|
|||
ns_class_names[r.subject] = class_name
|
||||
|
||||
class_views = []
|
||||
class_compositions = {}
|
||||
for ns_name in ns_class_names.keys():
|
||||
views = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
|
||||
if not views:
|
||||
continue
|
||||
class_views += views
|
||||
return class_views
|
||||
compositions = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.IS_COMPOSITE_OF)
|
||||
class_compositions[ns_name] = compositions
|
||||
return class_views, class_compositions
|
||||
|
||||
@staticmethod
|
||||
def _desc_to_note(json_str) -> List[str]:
|
||||
|
|
@ -249,8 +256,10 @@ class RebuildSequenceView(Action):
|
|||
m = json.loads(json_str)
|
||||
return [s.replace('"', '\\"').replace("\n", "\\n") for s in m.values()]
|
||||
|
||||
async def _class_view_to_mermaid(self, ns_class_name):
|
||||
async def _class_view_to_mermaid(self, ns_class_name) -> str:
|
||||
class_view_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
|
||||
if not class_view_rows:
|
||||
return ""
|
||||
result = ClassView.model_validate_json(class_view_rows[0].object_).get_mermaid() if class_view_rows else ""
|
||||
_, name = split_namespace(ns_class_name)
|
||||
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from pathlib import Path
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -39,20 +39,231 @@ class CodeBlockInfo(BaseModel):
|
|||
properties: Dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ClassInfo(BaseModel):
|
||||
class DotClassAttribute(BaseModel):
|
||||
name: str = ""
|
||||
type_: str = ""
|
||||
default_: str = ""
|
||||
description: str
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v: str) -> "DotClassAttribute":
|
||||
val = ""
|
||||
meet_colon = False
|
||||
meet_equals = False
|
||||
for c in v:
|
||||
if c == ":":
|
||||
meet_colon = True
|
||||
elif c == "=":
|
||||
meet_equals = True
|
||||
if not meet_colon:
|
||||
val += ":"
|
||||
meet_colon = True
|
||||
val += c
|
||||
if not meet_colon:
|
||||
val += ":"
|
||||
if not meet_equals:
|
||||
val += "="
|
||||
|
||||
cix = val.find(":")
|
||||
eix = val.rfind("=")
|
||||
name = val[0:cix].strip()
|
||||
type_ = val[cix + 1 : eix]
|
||||
default_ = val[eix + 1 :].strip()
|
||||
|
||||
type_ = cls.remove_white_spaces(type_) # remove white space
|
||||
if type_ == "NoneType":
|
||||
type_ = ""
|
||||
if "Literal[" in type_:
|
||||
pre_l, literal, post_l = cls._split_literal(type_)
|
||||
composition_val = pre_l + "Literal" + post_l # replace Literal[...] with Literal
|
||||
type_ = pre_l + literal + post_l
|
||||
else:
|
||||
type_ = re.sub(r"['\"]", "", type_) # remove '"
|
||||
composition_val = type_
|
||||
|
||||
if default_ == "None":
|
||||
default_ = ""
|
||||
compositions = cls.parse_compositions(composition_val)
|
||||
return cls(name=name, type_=type_, default_=default_, description=v, compositions=compositions)
|
||||
|
||||
@staticmethod
|
||||
def remove_white_spaces(v: str):
|
||||
return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)
|
||||
|
||||
@staticmethod
|
||||
def parse_compositions(types_part) -> List[str]:
|
||||
if not types_part:
|
||||
return []
|
||||
modified_string = re.sub(r"[\[\],]", "|", types_part)
|
||||
types = modified_string.split("|")
|
||||
filters = {
|
||||
"str",
|
||||
"frozenset",
|
||||
"set",
|
||||
"int",
|
||||
"float",
|
||||
"complex",
|
||||
"bool",
|
||||
"dict",
|
||||
"list",
|
||||
"Union",
|
||||
"Dict",
|
||||
"Set",
|
||||
"Tuple",
|
||||
"NoneType",
|
||||
"None",
|
||||
"Any",
|
||||
"Optional",
|
||||
"Iterator",
|
||||
"Literal",
|
||||
"List",
|
||||
}
|
||||
result = set()
|
||||
for t in types:
|
||||
t = t.strip()
|
||||
if t and t not in filters:
|
||||
result.add(t)
|
||||
return list(result)
|
||||
|
||||
@staticmethod
|
||||
def _split_literal(v):
|
||||
tag = "Literal["
|
||||
bix = v.find(tag)
|
||||
eix = len(v) - 1
|
||||
counter = 1
|
||||
for i in range(bix + len(tag), len(v) - 1):
|
||||
c = v[i]
|
||||
if c == "[":
|
||||
counter += 1
|
||||
continue
|
||||
if c == "]":
|
||||
counter -= 1
|
||||
if counter > 0:
|
||||
continue
|
||||
eix = i
|
||||
break
|
||||
pre_l = v[0:bix]
|
||||
post_l = v[eix + 1 :]
|
||||
pre_l = re.sub(r"['\"]", "", pre_l) # remove '"
|
||||
pos_l = re.sub(r"['\"]", "", post_l) # remove '"
|
||||
|
||||
return pre_l, v[bix : eix + 1], pos_l
|
||||
|
||||
@field_validator("compositions", mode="after")
|
||||
@classmethod
|
||||
def sort(cls, lst: List) -> List:
|
||||
lst.sort()
|
||||
return lst
|
||||
|
||||
|
||||
class DotClassInfo(BaseModel):
|
||||
name: str
|
||||
package: Optional[str] = None
|
||||
attributes: Dict[str, str] = Field(default_factory=dict)
|
||||
methods: Dict[str, str] = Field(default_factory=dict)
|
||||
attributes: Dict[str, DotClassAttribute] = Field(default_factory=dict)
|
||||
methods: Dict[str, DotClassMethod] = Field(default_factory=dict)
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
aggregations: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("compositions", "aggregations", mode="after")
|
||||
@classmethod
|
||||
def sort(cls, lst: List) -> List:
|
||||
lst.sort()
|
||||
return lst
|
||||
|
||||
|
||||
class ClassRelationship(BaseModel):
|
||||
class DotClassRelationship(BaseModel):
|
||||
src: str = ""
|
||||
dest: str = ""
|
||||
relationship: str = ""
|
||||
label: Optional[str] = None
|
||||
|
||||
|
||||
class DotReturn(BaseModel):
|
||||
type_: str = ""
|
||||
description: str
|
||||
compositions: List[str] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v: str) -> "DotReturn" | None:
|
||||
if not v:
|
||||
return DotReturn(description=v)
|
||||
type_ = DotClassAttribute.remove_white_spaces(v)
|
||||
compositions = DotClassAttribute.parse_compositions(type_)
|
||||
return cls(type_=type_, description=v, compositions=compositions)
|
||||
|
||||
@field_validator("compositions", mode="after")
|
||||
@classmethod
|
||||
def sort(cls, lst: List) -> List:
|
||||
lst.sort()
|
||||
return lst
|
||||
|
||||
|
||||
class DotClassMethod(BaseModel):
|
||||
name: str
|
||||
args: List[DotClassAttribute] = Field(default_factory=list)
|
||||
return_args: Optional[DotReturn] = None
|
||||
description: str
|
||||
aggregations: List[str] = Field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, v: str) -> "DotClassMethod":
|
||||
bix = v.find("(")
|
||||
eix = v.rfind(")")
|
||||
rix = v.rfind(":")
|
||||
if rix < 0 or rix < eix:
|
||||
rix = eix
|
||||
name_part = v[0:bix].strip()
|
||||
args_part = v[bix + 1 : eix].strip()
|
||||
return_args_part = v[rix + 1 :].strip()
|
||||
|
||||
name = cls._parse_name(name_part)
|
||||
args = cls._parse_args(args_part)
|
||||
return_args = DotReturn.parse(return_args_part)
|
||||
aggregations = set()
|
||||
for i in args:
|
||||
aggregations.update(set(i.compositions))
|
||||
aggregations.update(set(return_args.compositions))
|
||||
|
||||
return cls(name=name, args=args, description=v, return_args=return_args, aggregations=list(aggregations))
|
||||
|
||||
@staticmethod
|
||||
def _parse_name(v: str) -> str:
|
||||
tags = [">", "</"]
|
||||
if tags[0] in v:
|
||||
bix = v.find(tags[0]) + len(tags[0])
|
||||
eix = v.rfind(tags[1])
|
||||
return v[bix:eix].strip()
|
||||
return v.strip()
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(v: str) -> List[DotClassAttribute]:
|
||||
if not v:
|
||||
return []
|
||||
parts = []
|
||||
bix = 0
|
||||
counter = 0
|
||||
for i in range(0, len(v)):
|
||||
c = v[i]
|
||||
if c == "[":
|
||||
counter += 1
|
||||
continue
|
||||
elif c == "]":
|
||||
counter -= 1
|
||||
continue
|
||||
elif c == "," and counter == 0:
|
||||
parts.append(v[bix:i].strip())
|
||||
bix = i + 1
|
||||
parts.append(v[bix:].strip())
|
||||
|
||||
attrs = []
|
||||
for p in parts:
|
||||
if p:
|
||||
attr = DotClassAttribute.parse(p)
|
||||
attrs.append(attr)
|
||||
return attrs
|
||||
|
||||
|
||||
class RepoParser(BaseModel):
|
||||
base_directory: Path = Field(default=None)
|
||||
|
||||
|
|
@ -258,22 +469,28 @@ class RepoParser(BaseModel):
|
|||
if not package_name:
|
||||
continue
|
||||
class_name, members, functions = re.split(r"(?<!\\)\|", info)
|
||||
class_info = ClassInfo(name=class_name)
|
||||
class_info = DotClassInfo(name=class_name)
|
||||
class_info.package = package_name
|
||||
for m in members.split("\n"):
|
||||
if not m:
|
||||
continue
|
||||
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
|
||||
class_info.attributes[member_name] = m
|
||||
attr = DotClassAttribute.parse(m)
|
||||
class_info.attributes[attr.name] = attr
|
||||
for i in attr.compositions:
|
||||
if i not in class_info.compositions:
|
||||
class_info.compositions.append(i)
|
||||
for f in functions.split("\n"):
|
||||
if not f:
|
||||
continue
|
||||
function_name, _ = f.split("(", 1)
|
||||
class_info.methods[function_name] = f
|
||||
method = DotClassMethod.parse(f)
|
||||
class_info.methods[method.name] = method
|
||||
for i in method.aggregations:
|
||||
if i not in class_info.compositions and i not in class_info.aggregations:
|
||||
class_info.aggregations.append(i)
|
||||
class_views.append(class_info)
|
||||
return class_views
|
||||
|
||||
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
|
||||
async def _parse_class_relationships(self, class_view_pathname) -> List[DotClassRelationship]:
|
||||
relationship_views = []
|
||||
if not class_view_pathname.exists():
|
||||
return relationship_views
|
||||
|
|
@ -312,7 +529,7 @@ class RepoParser(BaseModel):
|
|||
if tag not in line:
|
||||
return None
|
||||
idxs.append(line.find(tag))
|
||||
ret = ClassRelationship()
|
||||
ret = DotClassRelationship()
|
||||
ret.src = line[0 : idxs[0]].strip('"')
|
||||
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
|
||||
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
|
||||
|
|
@ -363,8 +580,8 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _repair_namespaces(
|
||||
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
|
||||
) -> (List[ClassInfo], List[ClassRelationship], str):
|
||||
class_views: List[DotClassInfo], relationship_views: List[DotClassRelationship], path: str | Path
|
||||
) -> (List[DotClassInfo], List[DotClassRelationship], str):
|
||||
if not class_views:
|
||||
return [], [], ""
|
||||
c = class_views[0]
|
||||
|
|
|
|||
|
|
@ -471,44 +471,54 @@ class BugFixContext(BaseContext):
|
|||
|
||||
|
||||
# mermaid class view
|
||||
class ClassMeta(BaseModel):
|
||||
class UMLClassMeta(BaseModel):
|
||||
name: str = ""
|
||||
abstraction: bool = False
|
||||
static: bool = False
|
||||
visibility: str = ""
|
||||
|
||||
@staticmethod
|
||||
def name_to_visibility(name: str) -> str:
|
||||
if name == "__init__":
|
||||
return "+"
|
||||
if name.startswith("__"):
|
||||
return "-"
|
||||
elif name.startswith("_"):
|
||||
return "#"
|
||||
return "+"
|
||||
|
||||
class ClassAttribute(ClassMeta):
|
||||
|
||||
class UMLClassAttribute(UMLClassMeta):
|
||||
value_type: str = ""
|
||||
default_value: str = ""
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + self.visibility
|
||||
if self.value_type:
|
||||
content += self.value_type + " "
|
||||
content += self.name
|
||||
content += self.value_type.replace(" ", "") + " "
|
||||
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
|
||||
content += name
|
||||
if self.default_value:
|
||||
content += "="
|
||||
if self.value_type not in ["str", "string", "String"]:
|
||||
content += self.default_value
|
||||
else:
|
||||
content += '"' + self.default_value.replace('"', "") + '"'
|
||||
if self.abstraction:
|
||||
content += "*"
|
||||
if self.static:
|
||||
content += "$"
|
||||
# if self.abstraction:
|
||||
# content += "*"
|
||||
# if self.static:
|
||||
# content += "$"
|
||||
return content
|
||||
|
||||
|
||||
class ClassMethod(ClassMeta):
|
||||
args: List[ClassAttribute] = Field(default_factory=list)
|
||||
class UMLClassMethod(UMLClassMeta):
|
||||
args: List[UMLClassAttribute] = Field(default_factory=list)
|
||||
return_type: str = ""
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + self.visibility
|
||||
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
|
||||
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
|
||||
content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
|
||||
if self.return_type:
|
||||
content += ":" + self.return_type
|
||||
content += " " + self.return_type.replace(" ", "")
|
||||
if self.abstraction:
|
||||
content += "*"
|
||||
if self.static:
|
||||
|
|
@ -516,9 +526,9 @@ class ClassMethod(ClassMeta):
|
|||
return content
|
||||
|
||||
|
||||
class ClassView(ClassMeta):
|
||||
attributes: List[ClassAttribute] = Field(default_factory=list)
|
||||
methods: List[ClassMethod] = Field(default_factory=list)
|
||||
class UMLClassView(UMLClassMeta):
|
||||
attributes: List[UMLClassAttribute] = Field(default_factory=list)
|
||||
methods: List[UMLClassMethod] = Field(default_factory=list)
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@
|
|||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace
|
||||
from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace, split_namespace
|
||||
|
||||
|
||||
class GraphKeyword:
|
||||
|
|
@ -28,17 +28,17 @@ class GraphKeyword:
|
|||
SOURCE_CODE = "source_code"
|
||||
NULL = "<null>"
|
||||
GLOBAL_VARIABLE = "global_variable"
|
||||
CLASS_FUNCTION = "class_function"
|
||||
CLASS_METHOD = "class_method"
|
||||
CLASS_PROPERTY = "class_property"
|
||||
HAS_CLASS_FUNCTION = "has_class_function"
|
||||
HAS_CLASS_METHOD = "has_class_method"
|
||||
HAS_CLASS_PROPERTY = "has_class_property"
|
||||
HAS_CLASS = "has_class"
|
||||
HAS_CLASS_DESC = "has_class_desc"
|
||||
HAS_DETAIL = "has_detail"
|
||||
HAS_PAGE_INFO = "has_page_info"
|
||||
HAS_CLASS_VIEW = "has_class_view"
|
||||
HAS_SEQUENCE_VIEW = "has_sequence_view"
|
||||
HAS_ARGS_DESC = "has_args_desc"
|
||||
HAS_TYPE_DESC = "has_type_desc"
|
||||
IS_COMPOSITE_OF = "is_composite_of"
|
||||
IS_AGGREGATE_OF = "is_aggregate_of"
|
||||
|
||||
|
||||
class SPO(BaseModel):
|
||||
|
|
@ -96,13 +96,13 @@ class GraphRepository(ABC):
|
|||
for fn in methods:
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name),
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
predicate=GraphKeyword.HAS_CLASS_METHOD,
|
||||
object_=concat_namespace(file_info.file, class_name, fn),
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
object_=GraphKeyword.CLASS_METHOD,
|
||||
)
|
||||
for f in file_info.functions:
|
||||
# file -> function
|
||||
|
|
@ -134,7 +134,7 @@ class GraphRepository(ABC):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
|
||||
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
|
||||
for c in class_views:
|
||||
filename, _ = c.package.split(":", 1)
|
||||
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
|
||||
|
|
@ -147,6 +147,7 @@ class GraphRepository(ABC):
|
|||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS,
|
||||
)
|
||||
await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
|
||||
for vn, vt in c.attributes.items():
|
||||
# class -> property
|
||||
await graph_db.insert(
|
||||
|
|
@ -161,32 +162,40 @@ class GraphRepository(ABC):
|
|||
object_=GraphKeyword.CLASS_PROPERTY,
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
|
||||
subject=concat_namespace(c.package, vn),
|
||||
predicate=GraphKeyword.HAS_DETAIL,
|
||||
object_=vt.model_dump_json(),
|
||||
)
|
||||
for fn, desc in c.methods.items():
|
||||
if "</I>" in desc and "<I>" not in desc:
|
||||
logger.error(desc)
|
||||
for fn, ft in c.methods.items():
|
||||
# class -> function
|
||||
await graph_db.insert(
|
||||
subject=c.package,
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
predicate=GraphKeyword.HAS_CLASS_METHOD,
|
||||
object_=concat_namespace(c.package, fn),
|
||||
)
|
||||
# function detail
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
object_=GraphKeyword.CLASS_METHOD,
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, fn),
|
||||
predicate=GraphKeyword.HAS_ARGS_DESC,
|
||||
object_=desc,
|
||||
predicate=GraphKeyword.HAS_DETAIL,
|
||||
object_=ft.model_dump_json(),
|
||||
)
|
||||
for i in c.compositions:
|
||||
await graph_db.insert(
|
||||
subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
|
||||
)
|
||||
for i in c.aggregations:
|
||||
await graph_db.insert(
|
||||
subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_class_relationship_views(
|
||||
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
|
||||
graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
|
||||
):
|
||||
for r in relationship_views:
|
||||
await graph_db.insert(
|
||||
|
|
@ -199,3 +208,23 @@ class GraphRepository(ABC):
|
|||
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
|
||||
object_=concat_namespace(r.dest, r.label),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def rebuild_composition_relationship(graph_db: "GraphRepository"):
|
||||
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
mapping = defaultdict(list)
|
||||
for c in classes:
|
||||
_, name = split_namespace(c.subject)
|
||||
mapping[name].append(c.subject)
|
||||
|
||||
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
|
||||
for r in rows:
|
||||
ns, class_ = split_namespace(r.object_)
|
||||
if ns != "?":
|
||||
continue
|
||||
val = mapping[class_]
|
||||
if len(val) != 1:
|
||||
continue
|
||||
ns_name = val[0]
|
||||
await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
|
||||
await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -1,9 +1,11 @@
|
|||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.repo_parser import DotClassAttribute, DotClassMethod, DotReturn, RepoParser
|
||||
|
||||
|
||||
def test_repo_parser():
|
||||
|
|
@ -23,3 +25,140 @@ def test_error():
|
|||
"""_parse_file should return empty list when file not existed"""
|
||||
rsp = RepoParser._parse_file(Path("test_not_existed_file.py"))
|
||||
assert rsp == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("v", "name", "type_", "default_", "compositions"),
|
||||
[
|
||||
("children : dict[str, 'ActionNode']", "children", "dict[str,ActionNode]", "", ["ActionNode"]),
|
||||
("context : str", "context", "str", "", []),
|
||||
("example", "example", "", "", []),
|
||||
("expected_type : Type", "expected_type", "Type", "", ["Type"]),
|
||||
("args : Optional[Dict]", "args", "Optional[Dict]", "", []),
|
||||
("rsp : Optional[Message] = Message.Default", "rsp", "Optional[Message]", "Message.Default", ["Message"]),
|
||||
(
|
||||
"browser : Literal['chrome', 'firefox', 'edge', 'ie']",
|
||||
"browser",
|
||||
"Literal['chrome','firefox','edge','ie']",
|
||||
"",
|
||||
[],
|
||||
),
|
||||
(
|
||||
"browser : Dict[ Message, Literal['chrome', 'firefox', 'edge', 'ie'] ]",
|
||||
"browser",
|
||||
"Dict[Message,Literal['chrome','firefox','edge','ie']]",
|
||||
"",
|
||||
["Message"],
|
||||
),
|
||||
("attributes : List[ClassAttribute]", "attributes", "List[ClassAttribute]", "", ["ClassAttribute"]),
|
||||
("attributes = []", "attributes", "", "[]", []),
|
||||
(
|
||||
"request_timeout: Optional[Union[float, Tuple[float, float]]]",
|
||||
"request_timeout",
|
||||
"Optional[Union[float,Tuple[float,float]]]",
|
||||
"",
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_member(v, name, type_, default_, compositions):
|
||||
attr = DotClassAttribute.parse(v)
|
||||
assert name == attr.name
|
||||
assert type_ == attr.type_
|
||||
assert default_ == attr.default_
|
||||
assert compositions == attr.compositions
|
||||
assert v == attr.description
|
||||
|
||||
json_data = attr.model_dump_json()
|
||||
v = DotClassAttribute.model_validate_json(json_data)
|
||||
assert v == attr
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("line", "package_name", "info"),
|
||||
[
|
||||
(
|
||||
'"metagpt.roles.architect.Architect" [color="black", fontcolor="black", label=<{Architect|constraints : str<br ALIGN="LEFT"/>goal : str<br ALIGN="LEFT"/>name : str<br ALIGN="LEFT"/>profile : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
|
||||
"metagpt.roles.architect.Architect",
|
||||
"Architect|constraints : str\ngoal : str\nname : str\nprofile : str\n|",
|
||||
),
|
||||
(
|
||||
'"metagpt.actions.skill_action.ArgumentsParingAction" [color="black", fontcolor="black", label=<{ArgumentsParingAction|args : Optional[Dict]<br ALIGN="LEFT"/>ask : str<br ALIGN="LEFT"/>prompt<br ALIGN="LEFT"/>rsp : Optional[Message]<br ALIGN="LEFT"/>skill<br ALIGN="LEFT"/>|parse_arguments(skill_name, txt): dict<br ALIGN="LEFT"/>run(with_message): Message<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
|
||||
"metagpt.actions.skill_action.ArgumentsParingAction",
|
||||
"ArgumentsParingAction|args : Optional[Dict]\nask : str\nprompt\nrsp : Optional[Message]\nskill\n|parse_arguments(skill_name, txt): dict\nrun(with_message): Message\n",
|
||||
),
|
||||
(
|
||||
'"metagpt.strategy.base.BaseEvaluator" [color="black", fontcolor="black", label=<{BaseEvaluator|<br ALIGN="LEFT"/>|<I>status_verify</I>()<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
|
||||
"metagpt.strategy.base.BaseEvaluator",
|
||||
"BaseEvaluator|\n|<I>status_verify</I>()\n",
|
||||
),
|
||||
(
|
||||
'"metagpt.configs.browser_config.BrowserConfig" [color="black", fontcolor="black", label=<{BrowserConfig|browser : Literal[\'chrome\', \'firefox\', \'edge\', \'ie\']<br ALIGN="LEFT"/>driver : Literal[\'chromium\', \'firefox\', \'webkit\']<br ALIGN="LEFT"/>engine<br ALIGN="LEFT"/>path : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
|
||||
"metagpt.configs.browser_config.BrowserConfig",
|
||||
"BrowserConfig|browser : Literal['chrome', 'firefox', 'edge', 'ie']\ndriver : Literal['chromium', 'firefox', 'webkit']\nengine\npath : str\n|",
|
||||
),
|
||||
(
|
||||
'"metagpt.tools.search_engine_serpapi.SerpAPIWrapper" [color="black", fontcolor="black", label=<{SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]<br ALIGN="LEFT"/>model_config<br ALIGN="LEFT"/>params : dict<br ALIGN="LEFT"/>search_engine : Optional[Any]<br ALIGN="LEFT"/>serpapi_api_key : Optional[str]<br ALIGN="LEFT"/>|check_serpapi_api_key(val: str)<br ALIGN="LEFT"/>get_params(query: str): Dict[str, str]<br ALIGN="LEFT"/>results(query: str, max_results: int): dict<br ALIGN="LEFT"/>run(query, max_results: int, as_string: bool): str<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
|
||||
"metagpt.tools.search_engine_serpapi.SerpAPIWrapper",
|
||||
"SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]\nmodel_config\nparams : dict\nsearch_engine : Optional[Any]\nserpapi_api_key : Optional[str]\n|check_serpapi_api_key(val: str)\nget_params(query: str): Dict[str, str]\nresults(query: str, max_results: int): dict\nrun(query, max_results: int, as_string: bool): str\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_class_line(line, package_name, info):
|
||||
p, i = RepoParser._split_class_line(line)
|
||||
assert p == package_name
|
||||
assert i == info
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("v", "name", "args", "return_args"),
|
||||
[
|
||||
(
|
||||
"<I>arequest</I>(method, url, params, headers, files, stream: Literal[True], request_id: Optional[str], request_timeout: Optional[Union[float, Tuple[float, float]]]): Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
|
||||
"arequest",
|
||||
[
|
||||
DotClassAttribute(name="method", description="method"),
|
||||
DotClassAttribute(name="url", description="url"),
|
||||
DotClassAttribute(name="params", description="params"),
|
||||
DotClassAttribute(name="headers", description="headers"),
|
||||
DotClassAttribute(name="files", description="files"),
|
||||
DotClassAttribute(name="stream", type_="Literal[True]", description="stream: Literal[True]"),
|
||||
DotClassAttribute(name="request_id", type_="Optional[str]", description="request_id: Optional[str]"),
|
||||
DotClassAttribute(
|
||||
name="request_timeout",
|
||||
type_="Optional[Union[float,Tuple[float,float]]]",
|
||||
description="request_timeout: Optional[Union[float, Tuple[float, float]]]",
|
||||
),
|
||||
],
|
||||
DotReturn(
|
||||
type_="Tuple[AsyncGenerator[OpenAIResponse,None],bool,str]",
|
||||
compositions=["AsyncGenerator", "OpenAIResponse"],
|
||||
description="Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
|
||||
),
|
||||
),
|
||||
(
|
||||
"<I>update</I>(subject: str, predicate: str, object_: str)",
|
||||
"update",
|
||||
[
|
||||
DotClassAttribute(name="subject", type_="str", description="subject: str"),
|
||||
DotClassAttribute(name="predicate", type_="str", description="predicate: str"),
|
||||
DotClassAttribute(name="object_", type_="str", description="object_: str"),
|
||||
],
|
||||
DotReturn(description=""),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_parse_method(v, name, args, return_args):
|
||||
method = DotClassMethod.parse(v)
|
||||
assert method.name == name
|
||||
assert method.args == args
|
||||
assert method.return_args == return_args
|
||||
assert method.description == v
|
||||
|
||||
json_data = method.model_dump_json()
|
||||
v = DotClassMethod.model_validate_json(json_data)
|
||||
assert v == method
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue