fixbug: class view

fixbug: method type, feat: add compotition&aggregation
This commit is contained in:
莘权 马 2024-01-22 22:49:46 +08:00
parent 9e84e63529
commit 633c772529
7 changed files with 510 additions and 161 deletions

View file

@ -6,7 +6,7 @@
@File : rebuild_class_view.py
@Desc : Rebuild class view info
"""
import re
from pathlib import Path
from typing import Optional
@ -22,9 +22,9 @@ from metagpt.const import (
GRAPH_REPO_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
from metagpt.utils.common import split_namespace
from metagpt.repo_parser import DotClassInfo, RepoParser
from metagpt.schema import UMLClassAttribute, UMLClassMethod, UMLClassView
from metagpt.utils.common import concat_namespace, split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
@ -40,6 +40,7 @@ class RebuildClassView(Action):
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context))
await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views)
await GraphRepository.rebuild_composition_relationship(self.graph_db)
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
@ -81,24 +82,40 @@ class RebuildClassView(Action):
# Ignore sub-class
return ""
class_view = ClassView(name=fields[1])
rows = await self.graph_db.select(subject=ns_class_name)
for r in rows:
name = split_namespace(r.object_)[-1]
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
var_type = await self._parse_variable_type(r.object_)
attribute = ClassAttribute(
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
)
class_view.attributes.append(attribute)
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
await self._parse_function_args(method, r.object_)
class_view.methods.append(method)
rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL)
if not rows:
return ""
dot_class_info = DotClassInfo.model_validate_json(rows[0].object_)
visibility = UMLClassView.name_to_visibility(dot_class_info.name)
class_view = UMLClassView(name=dot_class_info.name, visibility=visibility)
for i in dot_class_info.attributes.values():
visibility = UMLClassAttribute.name_to_visibility(i.name)
attr = UMLClassAttribute(name=i.name, visibility=visibility, value_type=i.type_, default_value=i.default_)
class_view.attributes.append(attr)
for i in dot_class_info.methods.values():
visibility = UMLClassMethod.name_to_visibility(i.name)
method = UMLClassMethod(name=i.name, visibility=visibility, return_type=i.return_args.type_)
for j in i.args:
arg = UMLClassAttribute(name=j.name, value_type=j.type_, default_value=j.default_)
method.args.append(arg)
# update graph db
# update uml view
await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
# update uml isCompositeOf
for c in dot_class_info.compositions:
await self.graph_db.insert(
subject=ns_class_name,
predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF,
object_=concat_namespace("?", c),
)
# update uml isAggregateOf
for a in dot_class_info.aggregations:
await self.graph_db.insert(
subject=ns_class_name,
predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF,
object_=concat_namespace("?", a),
)
content = class_view.get_mermaid(align=1)
logger.debug(content)
@ -132,78 +149,6 @@ class RebuildClassView(Action):
return content, distinct
@staticmethod
def _parse_name(name: str, language="python"):
pattern = re.compile(r"<I>(.*?)</I>")
result = re.search(pattern, name)
abstraction = ""
if result:
name = result.group(1)
abstraction = "*"
if name.startswith("__"):
visibility = "-"
elif name.startswith("_"):
visibility = "#"
else:
visibility = "+"
return name, visibility, abstraction
async def _parse_variable_type(self, ns_name) -> str:
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
if not rows:
return ""
vals = rows[0].object_.replace("'", "").split(":")
if len(vals) == 1:
return ""
val = vals[-1].strip()
return "" if val == "NoneType" else val + " "
async def _parse_function_args(self, method: ClassMethod, ns_name: str):
rows = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
if not rows:
return
info = rows[0].object_.replace("'", "")
fs_tag = "("
ix = info.find(fs_tag)
fe_tag = "):"
eix = info.rfind(fe_tag)
if eix < 0:
fe_tag = ")"
eix = info.rfind(fe_tag)
args_info = info[ix + len(fs_tag) : eix].strip()
method.return_type = info[eix + len(fe_tag) :].strip()
if method.return_type == "None":
method.return_type = ""
if "(" in method.return_type:
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
# parse args
if not args_info:
return
splitter_ixs = []
cost = 0
for i in range(len(args_info)):
if args_info[i] == "[":
cost += 1
elif args_info[i] == "]":
cost -= 1
if args_info[i] == "," and cost == 0:
splitter_ixs.append(i)
splitter_ixs.append(len(args_info))
args = []
ix = 0
for eix in splitter_ixs:
args.append(args_info[ix:eix])
ix = eix + 1
for arg in args:
parts = arg.strip().split(":")
if len(parts) == 1:
method.args.append(ClassAttribute(name=parts[0].strip()))
continue
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
@staticmethod
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
if len(str(path_root)) > len(str(package_root)):

View file

@ -11,14 +11,14 @@ from __future__ import annotations
import json
import re
from pathlib import Path
from typing import List, Optional
from typing import Dict, List, Optional
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION, GRAPH_REPO_FILE_REPO
from metagpt.const import AGGREGATION, GENERALIZATION, GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import ClassView
from metagpt.utils.common import aread, general_after_log, list_files, split_namespace
@ -62,13 +62,21 @@ class RebuildSequenceView(Action):
merged_class_views = set()
while True:
participants = RebuildSequenceView.parse_participant(sequence_view)
class_views = await self._get_class_views(participants)
class_views, class_compositions = await self._get_class_views(participants)
for compositions in class_compositions.values():
for c in compositions:
ns, _ = split_namespace(c.object_)
if ns == "?":
continue
await self._parse_class_description(c.object_)
diff = set()
for cv in class_views:
if cv.subject in merged_class_views:
continue
await self._parse_class_description(cv.subject)
sequence_view = await self._merge_sequence_view(sequence_view, cv.subject)
sequence_view = await self._merge_sequence_view(
sequence_view, cv.subject, class_compositions.get(cv.subject, [])
)
diff.add(cv.subject)
merged_class_views.add(cv.subject)
@ -162,7 +170,7 @@ class RebuildSequenceView(Action):
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _merge_sequence_view(self, sequence_view, ns_class_name) -> str:
async def _merge_sequence_view(self, sequence_view, ns_class_name, compositions) -> str:
class_view_part = "```mermaid\n"
# add class
class_view_part += await self._class_view_to_mermaid(ns_class_name)
@ -177,14 +185,10 @@ class RebuildSequenceView(Action):
class_view_part += f"\n\t{me} *-- {name}"
class_view_part += await self._class_view_to_mermaid(ns_aggr_name)
# add composition relationship
rows = await self.graph_db.select(
subject=ns_class_name, predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF
)
compositions = [r.object_ for r in rows]
for ns_comp_name in compositions:
_, name = split_namespace(ns_comp_name)
for c in compositions:
_, name = split_namespace(c.object_)
class_view_part += f"\n\t{me} *-- {name}"
class_view_part += await self._class_view_to_mermaid(ns_comp_name)
class_view_part += await self._class_view_to_mermaid(c.object_)
class_view_part += "\n```"
@ -198,7 +202,7 @@ class RebuildSequenceView(Action):
system_msgs=[
"You are a tool to merge Mermaid class view information into the Mermaid sequence view.",
'Append as much information as possible from the "Mermaid Class View" to the sequence diagram.',
'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining what information have been merged and why.',
'Return a markdown JSON format with a "sequence_diagram" key containing the merged Mermaid sequence view, a "reason" key explaining in detail what information have been merged and why.',
],
)
@ -225,7 +229,7 @@ class RebuildSequenceView(Action):
matches = re.findall(pattern, mermaid_sequence_diagram)
return matches
async def _get_class_views(self, class_names) -> List[SPO]:
async def _get_class_views(self, class_names) -> (List[SPO], Dict[str, List[SPO]]):
rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
ns_class_names = {}
@ -235,12 +239,15 @@ class RebuildSequenceView(Action):
ns_class_names[r.subject] = class_name
class_views = []
class_compositions = {}
for ns_name in ns_class_names.keys():
views = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not views:
continue
class_views += views
return class_views
compositions = await self.graph_db.select(subject=ns_name, predicate=GraphKeyword.IS_COMPOSITE_OF)
class_compositions[ns_name] = compositions
return class_views, class_compositions
@staticmethod
def _desc_to_note(json_str) -> List[str]:
@ -249,8 +256,10 @@ class RebuildSequenceView(Action):
m = json.loads(json_str)
return [s.replace('"', '\\"').replace("\n", "\\n") for s in m.values()]
async def _class_view_to_mermaid(self, ns_class_name):
async def _class_view_to_mermaid(self, ns_class_name) -> str:
class_view_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW)
if not class_view_rows:
return ""
result = ClassView.model_validate_json(class_view_rows[0].object_).get_mermaid() if class_view_rows else ""
_, name = split_namespace(ns_class_name)
class_desc_rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_DESC)

View file

@ -15,7 +15,7 @@ from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
@ -39,20 +39,231 @@ class CodeBlockInfo(BaseModel):
properties: Dict = Field(default_factory=dict)
class ClassInfo(BaseModel):
class DotClassAttribute(BaseModel):
name: str = ""
type_: str = ""
default_: str = ""
description: str
compositions: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotClassAttribute":
val = ""
meet_colon = False
meet_equals = False
for c in v:
if c == ":":
meet_colon = True
elif c == "=":
meet_equals = True
if not meet_colon:
val += ":"
meet_colon = True
val += c
if not meet_colon:
val += ":"
if not meet_equals:
val += "="
cix = val.find(":")
eix = val.rfind("=")
name = val[0:cix].strip()
type_ = val[cix + 1 : eix]
default_ = val[eix + 1 :].strip()
type_ = cls.remove_white_spaces(type_) # remove white space
if type_ == "NoneType":
type_ = ""
if "Literal[" in type_:
pre_l, literal, post_l = cls._split_literal(type_)
composition_val = pre_l + "Literal" + post_l # replace Literal[...] with Literal
type_ = pre_l + literal + post_l
else:
type_ = re.sub(r"['\"]", "", type_) # remove '"
composition_val = type_
if default_ == "None":
default_ = ""
compositions = cls.parse_compositions(composition_val)
return cls(name=name, type_=type_, default_=default_, description=v, compositions=compositions)
@staticmethod
def remove_white_spaces(v: str):
return re.sub(r"(?<!['\"])\s|(?<=['\"])\s", "", v)
@staticmethod
def parse_compositions(types_part) -> List[str]:
if not types_part:
return []
modified_string = re.sub(r"[\[\],]", "|", types_part)
types = modified_string.split("|")
filters = {
"str",
"frozenset",
"set",
"int",
"float",
"complex",
"bool",
"dict",
"list",
"Union",
"Dict",
"Set",
"Tuple",
"NoneType",
"None",
"Any",
"Optional",
"Iterator",
"Literal",
"List",
}
result = set()
for t in types:
t = t.strip()
if t and t not in filters:
result.add(t)
return list(result)
@staticmethod
def _split_literal(v):
tag = "Literal["
bix = v.find(tag)
eix = len(v) - 1
counter = 1
for i in range(bix + len(tag), len(v) - 1):
c = v[i]
if c == "[":
counter += 1
continue
if c == "]":
counter -= 1
if counter > 0:
continue
eix = i
break
pre_l = v[0:bix]
post_l = v[eix + 1 :]
pre_l = re.sub(r"['\"]", "", pre_l) # remove '"
pos_l = re.sub(r"['\"]", "", post_l) # remove '"
return pre_l, v[bix : eix + 1], pos_l
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
lst.sort()
return lst
class DotClassInfo(BaseModel):
name: str
package: Optional[str] = None
attributes: Dict[str, str] = Field(default_factory=dict)
methods: Dict[str, str] = Field(default_factory=dict)
attributes: Dict[str, DotClassAttribute] = Field(default_factory=dict)
methods: Dict[str, DotClassMethod] = Field(default_factory=dict)
compositions: List[str] = Field(default_factory=list)
aggregations: List[str] = Field(default_factory=list)
@field_validator("compositions", "aggregations", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
lst.sort()
return lst
class ClassRelationship(BaseModel):
class DotClassRelationship(BaseModel):
src: str = ""
dest: str = ""
relationship: str = ""
label: Optional[str] = None
class DotReturn(BaseModel):
type_: str = ""
description: str
compositions: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotReturn" | None:
if not v:
return DotReturn(description=v)
type_ = DotClassAttribute.remove_white_spaces(v)
compositions = DotClassAttribute.parse_compositions(type_)
return cls(type_=type_, description=v, compositions=compositions)
@field_validator("compositions", mode="after")
@classmethod
def sort(cls, lst: List) -> List:
lst.sort()
return lst
class DotClassMethod(BaseModel):
name: str
args: List[DotClassAttribute] = Field(default_factory=list)
return_args: Optional[DotReturn] = None
description: str
aggregations: List[str] = Field(default_factory=list)
@classmethod
def parse(cls, v: str) -> "DotClassMethod":
bix = v.find("(")
eix = v.rfind(")")
rix = v.rfind(":")
if rix < 0 or rix < eix:
rix = eix
name_part = v[0:bix].strip()
args_part = v[bix + 1 : eix].strip()
return_args_part = v[rix + 1 :].strip()
name = cls._parse_name(name_part)
args = cls._parse_args(args_part)
return_args = DotReturn.parse(return_args_part)
aggregations = set()
for i in args:
aggregations.update(set(i.compositions))
aggregations.update(set(return_args.compositions))
return cls(name=name, args=args, description=v, return_args=return_args, aggregations=list(aggregations))
@staticmethod
def _parse_name(v: str) -> str:
tags = [">", "</"]
if tags[0] in v:
bix = v.find(tags[0]) + len(tags[0])
eix = v.rfind(tags[1])
return v[bix:eix].strip()
return v.strip()
@staticmethod
def _parse_args(v: str) -> List[DotClassAttribute]:
if not v:
return []
parts = []
bix = 0
counter = 0
for i in range(0, len(v)):
c = v[i]
if c == "[":
counter += 1
continue
elif c == "]":
counter -= 1
continue
elif c == "," and counter == 0:
parts.append(v[bix:i].strip())
bix = i + 1
parts.append(v[bix:].strip())
attrs = []
for p in parts:
if p:
attr = DotClassAttribute.parse(p)
attrs.append(attr)
return attrs
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@ -258,22 +469,28 @@ class RepoParser(BaseModel):
if not package_name:
continue
class_name, members, functions = re.split(r"(?<!\\)\|", info)
class_info = ClassInfo(name=class_name)
class_info = DotClassInfo(name=class_name)
class_info.package = package_name
for m in members.split("\n"):
if not m:
continue
member_name = m.split(":", 1)[0].strip() if ":" in m else m.strip()
class_info.attributes[member_name] = m
attr = DotClassAttribute.parse(m)
class_info.attributes[attr.name] = attr
for i in attr.compositions:
if i not in class_info.compositions:
class_info.compositions.append(i)
for f in functions.split("\n"):
if not f:
continue
function_name, _ = f.split("(", 1)
class_info.methods[function_name] = f
method = DotClassMethod.parse(f)
class_info.methods[method.name] = method
for i in method.aggregations:
if i not in class_info.compositions and i not in class_info.aggregations:
class_info.aggregations.append(i)
class_views.append(class_info)
return class_views
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
async def _parse_class_relationships(self, class_view_pathname) -> List[DotClassRelationship]:
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
@ -312,7 +529,7 @@ class RepoParser(BaseModel):
if tag not in line:
return None
idxs.append(line.find(tag))
ret = ClassRelationship()
ret = DotClassRelationship()
ret.src = line[0 : idxs[0]].strip('"')
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
@ -363,8 +580,8 @@ class RepoParser(BaseModel):
@staticmethod
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
) -> (List[ClassInfo], List[ClassRelationship], str):
class_views: List[DotClassInfo], relationship_views: List[DotClassRelationship], path: str | Path
) -> (List[DotClassInfo], List[DotClassRelationship], str):
if not class_views:
return [], [], ""
c = class_views[0]

View file

@ -471,44 +471,54 @@ class BugFixContext(BaseContext):
# mermaid class view
class ClassMeta(BaseModel):
class UMLClassMeta(BaseModel):
name: str = ""
abstraction: bool = False
static: bool = False
visibility: str = ""
@staticmethod
def name_to_visibility(name: str) -> str:
if name == "__init__":
return "+"
if name.startswith("__"):
return "-"
elif name.startswith("_"):
return "#"
return "+"
class ClassAttribute(ClassMeta):
class UMLClassAttribute(UMLClassMeta):
value_type: str = ""
default_value: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
if self.value_type:
content += self.value_type + " "
content += self.name
content += self.value_type.replace(" ", "") + " "
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
content += name
if self.default_value:
content += "="
if self.value_type not in ["str", "string", "String"]:
content += self.default_value
else:
content += '"' + self.default_value.replace('"', "") + '"'
if self.abstraction:
content += "*"
if self.static:
content += "$"
# if self.abstraction:
# content += "*"
# if self.static:
# content += "$"
return content
class ClassMethod(ClassMeta):
args: List[ClassAttribute] = Field(default_factory=list)
class UMLClassMethod(UMLClassMeta):
args: List[UMLClassAttribute] = Field(default_factory=list)
return_type: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
name = self.name.split(":", 1)[1] if ":" in self.name else self.name
content += name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
if self.return_type:
content += ":" + self.return_type
content += " " + self.return_type.replace(" ", "")
if self.abstraction:
content += "*"
if self.static:
@ -516,9 +526,9 @@ class ClassMethod(ClassMeta):
return content
class ClassView(ClassMeta):
attributes: List[ClassAttribute] = Field(default_factory=list)
methods: List[ClassMethod] = Field(default_factory=list)
class UMLClassView(UMLClassMeta):
attributes: List[UMLClassAttribute] = Field(default_factory=list)
methods: List[UMLClassMethod] = Field(default_factory=list)
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"

View file

@ -8,14 +8,14 @@
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import List
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
from metagpt.repo_parser import DotClassInfo, DotClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace, split_namespace
class GraphKeyword:
@ -28,17 +28,17 @@ class GraphKeyword:
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_METHOD = "class_method"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_METHOD = "has_class_method"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_CLASS_DESC = "has_class_desc"
HAS_DETAIL = "has_detail"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
HAS_SEQUENCE_VIEW = "has_sequence_view"
HAS_ARGS_DESC = "has_args_desc"
HAS_TYPE_DESC = "has_type_desc"
IS_COMPOSITE_OF = "is_composite_of"
IS_AGGREGATE_OF = "is_aggregate_of"
class SPO(BaseModel):
@ -96,13 +96,13 @@ class GraphRepository(ABC):
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
for f in file_info.functions:
# file -> function
@ -134,7 +134,7 @@ class GraphRepository(ABC):
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[DotClassInfo]):
for c in class_views:
filename, _ = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
@ -147,6 +147,7 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
await graph_db.insert(subject=c.package, predicate=GraphKeyword.HAS_DETAIL, object_=c.model_dump_json())
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
@ -161,32 +162,40 @@ class GraphRepository(ABC):
object_=GraphKeyword.CLASS_PROPERTY,
)
await graph_db.insert(
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.HAS_DETAIL,
object_=vt.model_dump_json(),
)
for fn, desc in c.methods.items():
if "</I>" in desc and "<I>" not in desc:
logger.error(desc)
for fn, ft in c.methods.items():
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
predicate=GraphKeyword.HAS_CLASS_METHOD,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
object_=GraphKeyword.CLASS_METHOD,
)
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
predicate=GraphKeyword.HAS_DETAIL,
object_=ft.model_dump_json(),
)
for i in c.compositions:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_COMPOSITE_OF, object_=concat_namespace("?", i)
)
for i in c.aggregations:
await graph_db.insert(
subject=c.package, predicate=GraphKeyword.IS_AGGREGATE_OF, object_=concat_namespace("?", i)
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
graph_db: "GraphRepository", relationship_views: List[DotClassRelationship]
):
for r in relationship_views:
await graph_db.insert(
@ -199,3 +208,23 @@ class GraphRepository(ABC):
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)
@staticmethod
async def rebuild_composition_relationship(graph_db: "GraphRepository"):
classes = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
mapping = defaultdict(list)
for c in classes:
_, name = split_namespace(c.subject)
mapping[name].append(c.subject)
rows = await graph_db.select(predicate=GraphKeyword.IS_COMPOSITE_OF)
for r in rows:
ns, class_ = split_namespace(r.object_)
if ns != "?":
continue
val = mapping[class_]
if len(val) != 1:
continue
ns_name = val[0]
await graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await graph_db.insert(subject=r.subject, predicate=r.predicate, object_=ns_name)

File diff suppressed because one or more lines are too long

View file

@ -1,9 +1,11 @@
from pathlib import Path
from pprint import pformat
import pytest
from metagpt.const import METAGPT_ROOT
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.repo_parser import DotClassAttribute, DotClassMethod, DotReturn, RepoParser
def test_repo_parser():
@ -23,3 +25,140 @@ def test_error():
"""_parse_file should return empty list when file not existed"""
rsp = RepoParser._parse_file(Path("test_not_existed_file.py"))
assert rsp == []
@pytest.mark.parametrize(
("v", "name", "type_", "default_", "compositions"),
[
("children : dict[str, 'ActionNode']", "children", "dict[str,ActionNode]", "", ["ActionNode"]),
("context : str", "context", "str", "", []),
("example", "example", "", "", []),
("expected_type : Type", "expected_type", "Type", "", ["Type"]),
("args : Optional[Dict]", "args", "Optional[Dict]", "", []),
("rsp : Optional[Message] = Message.Default", "rsp", "Optional[Message]", "Message.Default", ["Message"]),
(
"browser : Literal['chrome', 'firefox', 'edge', 'ie']",
"browser",
"Literal['chrome','firefox','edge','ie']",
"",
[],
),
(
"browser : Dict[ Message, Literal['chrome', 'firefox', 'edge', 'ie'] ]",
"browser",
"Dict[Message,Literal['chrome','firefox','edge','ie']]",
"",
["Message"],
),
("attributes : List[ClassAttribute]", "attributes", "List[ClassAttribute]", "", ["ClassAttribute"]),
("attributes = []", "attributes", "", "[]", []),
(
"request_timeout: Optional[Union[float, Tuple[float, float]]]",
"request_timeout",
"Optional[Union[float,Tuple[float,float]]]",
"",
[],
),
],
)
def test_parse_member(v, name, type_, default_, compositions):
attr = DotClassAttribute.parse(v)
assert name == attr.name
assert type_ == attr.type_
assert default_ == attr.default_
assert compositions == attr.compositions
assert v == attr.description
json_data = attr.model_dump_json()
v = DotClassAttribute.model_validate_json(json_data)
assert v == attr
@pytest.mark.parametrize(
("line", "package_name", "info"),
[
(
'"metagpt.roles.architect.Architect" [color="black", fontcolor="black", label=<{Architect|constraints : str<br ALIGN="LEFT"/>goal : str<br ALIGN="LEFT"/>name : str<br ALIGN="LEFT"/>profile : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
"metagpt.roles.architect.Architect",
"Architect|constraints : str\ngoal : str\nname : str\nprofile : str\n|",
),
(
'"metagpt.actions.skill_action.ArgumentsParingAction" [color="black", fontcolor="black", label=<{ArgumentsParingAction|args : Optional[Dict]<br ALIGN="LEFT"/>ask : str<br ALIGN="LEFT"/>prompt<br ALIGN="LEFT"/>rsp : Optional[Message]<br ALIGN="LEFT"/>skill<br ALIGN="LEFT"/>|parse_arguments(skill_name, txt): dict<br ALIGN="LEFT"/>run(with_message): Message<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.actions.skill_action.ArgumentsParingAction",
"ArgumentsParingAction|args : Optional[Dict]\nask : str\nprompt\nrsp : Optional[Message]\nskill\n|parse_arguments(skill_name, txt): dict\nrun(with_message): Message\n",
),
(
'"metagpt.strategy.base.BaseEvaluator" [color="black", fontcolor="black", label=<{BaseEvaluator|<br ALIGN="LEFT"/>|<I>status_verify</I>()<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.strategy.base.BaseEvaluator",
"BaseEvaluator|\n|<I>status_verify</I>()\n",
),
(
'"metagpt.configs.browser_config.BrowserConfig" [color="black", fontcolor="black", label=<{BrowserConfig|browser : Literal[\'chrome\', \'firefox\', \'edge\', \'ie\']<br ALIGN="LEFT"/>driver : Literal[\'chromium\', \'firefox\', \'webkit\']<br ALIGN="LEFT"/>engine<br ALIGN="LEFT"/>path : str<br ALIGN="LEFT"/>|}>, shape="record", style="solid"];',
"metagpt.configs.browser_config.BrowserConfig",
"BrowserConfig|browser : Literal['chrome', 'firefox', 'edge', 'ie']\ndriver : Literal['chromium', 'firefox', 'webkit']\nengine\npath : str\n|",
),
(
'"metagpt.tools.search_engine_serpapi.SerpAPIWrapper" [color="black", fontcolor="black", label=<{SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]<br ALIGN="LEFT"/>model_config<br ALIGN="LEFT"/>params : dict<br ALIGN="LEFT"/>search_engine : Optional[Any]<br ALIGN="LEFT"/>serpapi_api_key : Optional[str]<br ALIGN="LEFT"/>|check_serpapi_api_key(val: str)<br ALIGN="LEFT"/>get_params(query: str): Dict[str, str]<br ALIGN="LEFT"/>results(query: str, max_results: int): dict<br ALIGN="LEFT"/>run(query, max_results: int, as_string: bool): str<br ALIGN="LEFT"/>}>, shape="record", style="solid"];',
"metagpt.tools.search_engine_serpapi.SerpAPIWrapper",
"SerpAPIWrapper|aiosession : Optional[aiohttp.ClientSession]\nmodel_config\nparams : dict\nsearch_engine : Optional[Any]\nserpapi_api_key : Optional[str]\n|check_serpapi_api_key(val: str)\nget_params(query: str): Dict[str, str]\nresults(query: str, max_results: int): dict\nrun(query, max_results: int, as_string: bool): str\n",
),
],
)
def test_split_class_line(line, package_name, info):
p, i = RepoParser._split_class_line(line)
assert p == package_name
assert i == info
@pytest.mark.parametrize(
("v", "name", "args", "return_args"),
[
(
"<I>arequest</I>(method, url, params, headers, files, stream: Literal[True], request_id: Optional[str], request_timeout: Optional[Union[float, Tuple[float, float]]]): Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
"arequest",
[
DotClassAttribute(name="method", description="method"),
DotClassAttribute(name="url", description="url"),
DotClassAttribute(name="params", description="params"),
DotClassAttribute(name="headers", description="headers"),
DotClassAttribute(name="files", description="files"),
DotClassAttribute(name="stream", type_="Literal[True]", description="stream: Literal[True]"),
DotClassAttribute(name="request_id", type_="Optional[str]", description="request_id: Optional[str]"),
DotClassAttribute(
name="request_timeout",
type_="Optional[Union[float,Tuple[float,float]]]",
description="request_timeout: Optional[Union[float, Tuple[float, float]]]",
),
],
DotReturn(
type_="Tuple[AsyncGenerator[OpenAIResponse,None],bool,str]",
compositions=["AsyncGenerator", "OpenAIResponse"],
description="Tuple[AsyncGenerator[OpenAIResponse, None], bool, str]",
),
),
(
"<I>update</I>(subject: str, predicate: str, object_: str)",
"update",
[
DotClassAttribute(name="subject", type_="str", description="subject: str"),
DotClassAttribute(name="predicate", type_="str", description="predicate: str"),
DotClassAttribute(name="object_", type_="str", description="object_: str"),
],
DotReturn(description=""),
),
],
)
def test_parse_method(v, name, args, return_args):
method = DotClassMethod.parse(v)
assert method.name == name
assert method.args == args
assert method.return_args == return_args
assert method.description == v
json_data = method.model_dump_json()
v = DotClassMethod.model_validate_json(json_data)
assert v == method
if __name__ == "__main__":
pytest.main([__file__, "-s"])