Merge branch 'main' into llm_mock

This commit is contained in:
yzlin 2024-01-04 17:19:09 +08:00
commit 271ecc30a2
23 changed files with 575 additions and 114 deletions

2
.gitignore vendored
View file

@ -172,6 +172,8 @@ tests/metagpt/utils/file_repo_git
*.png
htmlcov
htmlcov.*
*.dot
*.pkl
*-structure.csv
*-structure.json

View file

@ -11,7 +11,7 @@ paths:
post:
summary: Generate greeting
description: Generates a greeting message.
operationId: hello.post_greeting
operationId: openapi_v3_hello.post_greeting
responses:
200:
description: greeting response

View file

@ -9,60 +9,187 @@
import re
from pathlib import Path
import aiofiles
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
from metagpt.const import (
AGGREGATION,
COMPOSITION,
DATA_API_DESIGN_FILE_REPO,
GENERALIZATION,
GRAPH_REPO_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
def __init__(self, name="", context=None, llm=None):
super().__init__(name=name, context=context, llm=llm)
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=self.context)
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
repo_parser = RepoParser(base_directory=Path(self.context))
class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
symbols = repo_parser.generate_symbols() # use ast
for file_info in symbols:
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_view(graph_db=graph_db)
await self._save(graph_db=graph_db)
await self._create_mermaid_class_views(graph_db=graph_db)
await graph_db.save()
async def _create_mermaid_class_view(self, graph_db):
pass
# dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
# if not dataset:
# logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
# return
# code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
# src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
# code_type = ""
# dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
# for spo in dataset:
# if spo.object_ in ["javascript", "python"]:
# code_type = spo.object_
# break
async def _create_mermaid_class_views(self, graph_db):
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / CONFIG.git_repo.workdir.name
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)
await writer.write(content)
# class names
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
class_distinct = set()
relationship_distinct = set()
for r in rows:
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
for r in rows:
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
# try:
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
# class_view = node.instruct_content.model_dump()["Class View"]
# except Exception as e:
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
fields = split_namespace(ns_class_name)
if len(fields) > 2:
# Ignore sub-class
return
async def _save(self, graph_db):
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
all_class_view = []
for spo in dataset:
title = f"---\ntitle: {spo.subject}\n---\n"
filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
await class_view_file_repo.save(filename=filename, content=title + spo.object_)
all_class_view.append(spo.object_)
await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))
class_view = ClassView(name=fields[1])
rows = await graph_db.select(subject=ns_class_name)
for r in rows:
name = split_namespace(r.object_)[-1]
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
attribute = ClassAttribute(
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
)
class_view.attributes.append(attribute)
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
class_view.methods.append(method)
# update graph db
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
content = class_view.get_mermaid(align=1)
logger.debug(content)
await file_writer.write(content)
distinct.add(ns_class_name)
@staticmethod
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
s_fields = split_namespace(ns_class_name)
if len(s_fields) > 2:
# Ignore sub-class
return
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
mappings = {
GENERALIZATION: " <|-- ",
COMPOSITION: " *-- ",
AGGREGATION: " o-- ",
}
content = ""
for p, v in predicates.items():
rows = await graph_db.select(subject=ns_class_name, predicate=p)
for r in rows:
o_fields = split_namespace(r.object_)
if len(o_fields) > 2:
# Ignore sub-class
continue
relationship = mappings.get(v, " .. ")
link = f"{o_fields[1]}{relationship}{s_fields[1]}"
distinct.add(link)
content += f"\t{link}\n"
if content:
logger.debug(content)
await file_writer.write(content)
@staticmethod
def _parse_name(name: str, language="python"):
pattern = re.compile(r"<I>(.*?)<\/I>")
result = re.search(pattern, name)
abstraction = ""
if result:
name = result.group(1)
abstraction = "*"
if name.startswith("__"):
visibility = "-"
elif name.startswith("_"):
visibility = "#"
else:
visibility = "+"
return name, visibility, abstraction
@staticmethod
async def _parse_variable_type(ns_name, graph_db) -> str:
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
if not rows:
return ""
vals = rows[0].object_.replace("'", "").split(":")
if len(vals) == 1:
return ""
val = vals[-1].strip()
return "" if val == "NoneType" else val + " "
@staticmethod
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
if not rows:
return
info = rows[0].object_.replace("'", "")
fs_tag = "("
ix = info.find(fs_tag)
fe_tag = "):"
eix = info.rfind(fe_tag)
if eix < 0:
fe_tag = ")"
eix = info.rfind(fe_tag)
args_info = info[ix + len(fs_tag) : eix].strip()
method.return_type = info[eix + len(fe_tag) :].strip()
if method.return_type == "None":
method.return_type = ""
if "(" in method.return_type:
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
# parse args
if not args_info:
return
splitter_ixs = []
cost = 0
for i in range(len(args_info)):
if args_info[i] == "[":
cost += 1
elif args_info[i] == "]":
cost -= 1
if args_info[i] == "," and cost == 0:
splitter_ixs.append(i)
splitter_ixs.append(len(args_info))
args = []
ix = 0
for eix in splitter_ixs:
args.append(args_info[ix:eix])
ix = eix + 1
for arg in args:
parts = arg.strip().split(":")
if len(parts) == 1:
method.args.append(ClassAttribute(name=parts[0].strip()))
continue
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))

View file

@ -130,7 +130,7 @@ class WriteCode(Action):
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
coding_context.code_doc.content = code
return coding_context

View file

@ -126,3 +126,8 @@ LLM_API_TIMEOUT = 300
# Message id
IGNORED_MESSAGE_ID = "0"
# Class Relationship
GENERALIZATION = "Generalize"
COMPOSITION = "Composite"
AGGREGATION = "Aggregate"

View file

@ -12,14 +12,14 @@ import json
import re
import subprocess
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
from metagpt.utils.common import any_to_str
from metagpt.utils.common import any_to_str, aread
from metagpt.utils.exceptions import handle_exception
@ -46,6 +46,13 @@ class ClassInfo(BaseModel):
methods: Dict[str, str] = Field(default_factory=dict)
class ClassRelationship(BaseModel):
src: str = ""
dest: str = ""
relationship: str = ""
label: Optional[str] = None
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@ -60,7 +67,8 @@ class RepoParser(BaseModel):
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
file_info.page_info.append(info)
if info:
file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
file_info.classes.append({"name": node.name, "methods": class_methods})
@ -110,7 +118,9 @@ class RepoParser(BaseModel):
return output_path
@staticmethod
def node_to_str(node) -> (int, int, str, str | Tuple):
def node_to_str(node) -> CodeBlockInfo | None:
if isinstance(node, ast.Try):
return None
if any_to_str(node) == any_to_str(ast.Expr):
return CodeBlockInfo(
lineno=node.lineno,
@ -129,6 +139,7 @@ class RepoParser(BaseModel):
},
any_to_str(ast.If): RepoParser._parse_if,
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target),
}
func = mappings.get(any_to_str(node))
if func:
@ -143,7 +154,8 @@ class RepoParser(BaseModel):
else:
raise NotImplementedError(f"Not implement:{val}")
return code_block
raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
logger.warning(f"Unsupported code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
return None
@staticmethod
def _parse_expr(node) -> List:
@ -164,22 +176,51 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_if(n):
tokens = [RepoParser._parse_variable(n.test.left)]
for item in n.test.comparators:
tokens.append(RepoParser._parse_variable(item))
tokens = []
try:
if isinstance(n.test, ast.BoolOp):
tokens = []
for v in n.test.values:
tokens.extend(RepoParser._parse_if_compare(v))
return tokens
if isinstance(n.test, ast.Compare):
v = RepoParser._parse_variable(n.test.left)
if v:
tokens.append(v)
for item in n.test.comparators:
v = RepoParser._parse_variable(item)
if v:
tokens.append(v)
return tokens
except Exception as e:
logger.warning(f"Unsupported if: {n}, err:{e}")
return tokens
@staticmethod
def _parse_if_compare(n):
if hasattr(n, "left"):
return RepoParser._parse_variable(n.left)
else:
return []
@staticmethod
def _parse_variable(node):
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
try:
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}"
if hasattr(x.value, "id")
else f"{x.attr}",
any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
any_to_str(ast.Tuple): lambda x: "",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
except Exception as e:
logger.warning(f"Unsupported variable:{node}, err:{e}")
@staticmethod
def _parse_assign(node):
@ -197,18 +238,21 @@ class RepoParser(BaseModel):
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
class_views, relationship_views = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
return class_views
return class_views, relationship_views
async def _parse_classes(self, class_view_pathname):
class_views = []
if not class_view_pathname.exists():
return class_views
async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
lines = await reader.readlines()
data = await aread(filename=class_view_pathname, encoding="utf-8")
lines = data.split("\n")
for line in lines:
package_name, info = RepoParser._split_class_line(line)
if not package_name:
@ -229,6 +273,19 @@ class RepoParser(BaseModel):
class_views.append(class_info)
return class_views
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
data = await aread(filename=class_view_pathname, encoding="utf-8")
lines = data.split("\n")
for line in lines:
relationship = RepoParser._split_relationship_line(line)
if not relationship:
continue
relationship_views.append(relationship)
return relationship_views
@staticmethod
def _split_class_line(line):
part_splitor = '" ['
@ -247,6 +304,40 @@ class RepoParser(BaseModel):
info = re.sub(r"<br[^>]*>", "\n", info)
return class_name, info
@staticmethod
def _split_relationship_line(line):
splitters = [" -> ", " [", "];"]
idxs = []
for tag in splitters:
if tag not in line:
return None
idxs.append(line.find(tag))
ret = ClassRelationship()
ret.src = line[0 : idxs[0]].strip('"')
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
mappings = {
'arrowhead="empty"': GENERALIZATION,
'arrowhead="diamond"': COMPOSITION,
'arrowhead="odiamond"': AGGREGATION,
}
for k, v in mappings.items():
if k in properties:
ret.relationship = v
if v != GENERALIZATION:
ret.label = RepoParser._get_label(properties)
break
return ret
@staticmethod
def _get_label(line):
tag = 'label="'
if tag not in line:
return ""
ix = line.find(tag)
eix = line.find('"', ix + len(tag))
return line[ix + len(tag) : eix]
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
mappings = {
@ -271,7 +362,9 @@ class RepoParser(BaseModel):
return mappings
@staticmethod
def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
) -> (List[ClassInfo], List[ClassRelationship]):
if not class_views:
return []
c = class_views[0]
@ -290,7 +383,12 @@ class RepoParser(BaseModel):
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
return class_views
for i in range(len(relationship_views)):
v = relationship_views[i]
v.src = RepoParser._repair_ns(v.src, new_mappings)
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
relationship_views[i] = v
return class_views, relationship_views
@staticmethod
def _repair_ns(package, mappings):

View file

@ -451,3 +451,63 @@ class CodeSummarizeContext(BaseModel):
class BugFixContext(BaseContext):
filename: str = ""
# mermaid class view
class ClassMeta(BaseModel):
name: str = ""
abstraction: bool = False
static: bool = False
visibility: str = ""
class ClassAttribute(ClassMeta):
value_type: str = ""
default_value: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
if self.value_type:
content += self.value_type + " "
content += self.name
if self.default_value:
content += "="
if self.value_type not in ["str", "string", "String"]:
content += self.default_value
else:
content += '"' + self.default_value.replace('"', "") + '"'
if self.abstraction:
content += "*"
if self.static:
content += "$"
return content
class ClassMethod(ClassMeta):
args: List[ClassAttribute] = Field(default_factory=list)
return_type: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
if self.return_type:
content += ":" + self.return_type
if self.abstraction:
content += "*"
if self.static:
content += "$"
return content
class ClassView(ClassMeta):
attributes: List[ClassAttribute] = Field(default_factory=list)
methods: List[ClassMethod] = Field(default_factory=list)
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
for v in self.attributes:
content += v.get_mermaid(align=align + 1) + "\n"
for v in self.methods:
content += v.get_mermaid(align=align + 1) + "\n"
content += "".join(["\t" for i in range(align)]) + "}\n"
return content

View file

@ -5,6 +5,12 @@
@Author : mashenquan
@File : metagpt_oas3_api_svc.py
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
curl -X 'POST' \
'http://localhost:8080/openapi/greeting/dave' \
-H 'accept: text/plain' \
-H 'Content-Type: application/json' \
-d '{}'
"""
from pathlib import Path
@ -15,7 +21,7 @@ import connexion
def oas_http_svc():
"""Start the OAS 3.0 OpenAPI HTTP service"""
print("http://localhost:8080/oas3/ui/")
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
app.add_api("metagpt_oas3_api.yaml")
app.add_api("openapi.yaml")

View file

@ -23,7 +23,7 @@ async def post_greeting(name: str) -> str:
if __name__ == "__main__":
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
app.run(port=8082)

View file

@ -407,6 +407,10 @@ def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def split_namespace(ns_class_name: str) -> List[str]:
return ns_class_name.split(":")
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.

View file

@ -12,9 +12,9 @@ import json
from pathlib import Path
from typing import List
import aiofiles
import networkx
from metagpt.utils.common import aread, awrite
from metagpt.utils.graph_repository import SPO, GraphRepository
@ -55,12 +55,10 @@ class DiGraphRepository(GraphRepository):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
await writer.write(data)
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
data = await reader.read(-1)
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)

View file

@ -138,6 +138,8 @@ class FileRepository:
files = self._git_repo.changed_files
relative_files = {}
for p, ct in files.items():
if ct.value == "D": # deleted
continue
try:
rf = Path(p).relative_to(self._relative_path)
except ValueError:

View file

@ -13,19 +13,25 @@ from typing import List
from pydantic import BaseModel
from metagpt.repo_parser import ClassInfo, RepoFileInfo
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
OF = "Of"
ON = "On"
CLASS = "class"
FUNCTION = "function"
HAS_FUNCTION = "has_function"
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
@ -73,11 +79,13 @@ class GraphRepository(ABC):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
# file -> class
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
# class detail
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
@ -85,12 +93,22 @@ class GraphRepository(ABC):
)
methods = c.get("methods", [])
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
# file -> function
await graph_db.insert(
subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
)
# function detail
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
@ -105,13 +123,13 @@ class GraphRepository(ABC):
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
object_=code_block.model_dump_json(),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
object_=code_block.model_dump_json(),
)
@staticmethod
@ -129,6 +147,13 @@ class GraphRepository(ABC):
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_PROPERTY,
object_=concat_namespace(c.package, vn),
)
# property detail
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
@ -138,6 +163,15 @@ class GraphRepository(ABC):
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
if "</I>" in desc and "<I>" not in desc:
logger.error(desc)
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
@ -148,3 +182,19 @@ class GraphRepository(ABC):
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
):
for r in relationship_views:
await graph_db.insert(
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
)
if not r.label:
continue
await graph_db.insert(
subject=r.src,
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)

View file

@ -11,6 +11,7 @@ import json
import logging
import os
import re
import uuid
import pytest
@ -123,9 +124,9 @@ def loguru_caplog(caplog):
# init & dispose git repo
@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown_git_repo(request):
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
CONFIG.git_reinit = True
# Destroy git repo at the end of the test session.

View file

@ -11,13 +11,19 @@ from pathlib import Path
import pytest
from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.config import CONFIG
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.llm import LLM
@pytest.mark.asyncio
async def test_rebuild():
action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
action = RebuildClassView(
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
)
await action.run()
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
assert graph_file_repo.changed_files
if __name__ == "__main__":

View file

@ -19,6 +19,9 @@ from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
ClassAttribute,
ClassMethod,
ClassView,
CodeSummarizeContext,
Document,
Message,
@ -156,5 +159,30 @@ def test_CodeSummarizeContext(file_list, want):
assert want in m
def test_class_view():
attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
assert attr_a.get_mermaid(align=1) == "\t+int a=0*"
attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
assert attr_b.get_mermaid(align=0) == '#str b="0"$'
class_view = ClassView(name="A")
class_view.attributes = [attr_a, attr_b]
method_a = ClassMethod(name="run", visibility="+", abstraction=True)
assert method_a.get_mermaid(align=1) == "\t+run()*"
method_b = ClassMethod(
name="_test",
visibility="#",
static=True,
args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")],
return_type="str",
)
assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$"
class_view.methods = [method_a, method_b]
assert (
class_view.get_mermaid(align=0)
== 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -24,13 +24,14 @@ async def test_oas2_svc():
process = subprocess.Popen(["python", str(script_pathname)], cwd=str(workdir), env=env)
await asyncio.sleep(5)
url = "http://localhost:8080/openapi/greeting/dave"
headers = {"accept": "text/plain", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
assert response.text == "Hello dave\n"
process.terminate()
try:
url = "http://localhost:8080/openapi/greeting/dave"
headers = {"accept": "text/plain", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
assert response.text == "Hello dave\n"
finally:
process.terminate()
if __name__ == "__main__":

View file

@ -3,7 +3,7 @@
"""
@Time : 2023/12/26
@Author : mashenquan
@File : test_hello.py
@File : test_openapi_v3_hello.py
"""
import asyncio
import subprocess
@ -24,13 +24,14 @@ async def test_hello():
process = subprocess.Popen(["python", str(script_pathname)], cwd=workdir, env=env)
await asyncio.sleep(5)
url = "http://localhost:8082/openapi/greeting/dave"
headers = {"accept": "text/plain", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
assert response.text == "Hello dave\n"
process.terminate()
try:
url = "http://localhost:8082/openapi/greeting/dave"
headers = {"accept": "text/plain", "Content-Type": "application/json"}
data = {}
response = requests.post(url, headers=headers, json=data)
assert response.text == "Hello dave\n"
finally:
process.terminate()
if __name__ == "__main__":

View file

@ -36,6 +36,7 @@ from metagpt.utils.common import (
read_file_block,
read_json_file,
require_python_version,
split_namespace,
)
@ -163,6 +164,23 @@ class TestGetProjectRoot:
assert concat_namespace("a", "b", "c", "e") == "a:b:c:e"
assert concat_namespace("a", "b", "c", "e", "f") == "a:b:c:e:f"
@pytest.mark.parametrize(
("val", "want"),
[
(
"tests/metagpt/test_role.py:test_react:Input:subscription",
["tests/metagpt/test_role.py", "test_react", "Input", "subscription"],
),
(
"tests/metagpt/test_role.py:test_react:Input:goal",
["tests/metagpt/test_role.py", "test_react", "Input", "goal"],
),
],
)
def test_split_namespace(self, val, want):
res = split_namespace(val)
assert res == want
def test_read_json_file(self):
assert read_json_file(str(Path(__file__).parent / "../../data/ut_writer/yft_swaggerApi.json"), encoding="utf-8")
with pytest.raises(FileNotFoundError):

View file

@ -6,20 +6,34 @@
@File : test_redis.py
"""
import mock
import pytest
from metagpt.config import CONFIG
from metagpt.utils.redis import Redis
async def async_mock_from_url(*args, **kwargs):
mock_client = mock.AsyncMock()
mock_client.set.return_value = None
mock_client.get.side_effect = [b"test", b""]
return mock_client
@pytest.mark.asyncio
async def test_redis():
@mock.patch("aioredis.from_url", return_value=async_mock_from_url())
async def test_redis(mock_from_url):
# Mock
# mock_client = mock.AsyncMock()
# mock_client.set.return_value=None
# mock_client.get.side_effect = [b'test', b'']
# mock_from_url.return_value = mock_client
# Prerequisites
assert CONFIG.REDIS_HOST and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
assert CONFIG.REDIS_PORT and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
# assert CONFIG.REDIS_USER
assert CONFIG.REDIS_PASSWORD is not None and CONFIG.REDIS_PASSWORD != "YOUR_REDIS_PASSWORD"
assert CONFIG.REDIS_DB is not None and CONFIG.REDIS_DB != "YOUR_REDIS_DB_INDEX, str, 0-based"
CONFIG.REDIS_HOST = "MOCK_REDIS_HOST"
CONFIG.REDIS_PORT = "MOCK_REDIS_PORT"
CONFIG.REDIS_PASSWORD = "MOCK_REDIS_PASSWORD"
CONFIG.REDIS_DB = 0
conn = Redis()
assert not conn.is_valid

View file

@ -2,20 +2,18 @@
# -*- coding: utf-8 -*-
# @Desc : unittest of repair_llm_raw_output
from metagpt.config import CONFIG
from metagpt.utils.repair_llm_raw_output import (
RepairType,
extract_content_from_output,
repair_invalid_json,
repair_llm_raw_output,
retry_parse_json_text,
)
"""
CONFIG.repair_llm_output should be True before retry_parse_json_text imported.
so we move `from ... impot ...` into each `test_xx` to avoid `Module level import not at top of file` format warning.
"""
CONFIG.repair_llm_output = True
def test_repair_case_sensitivity():
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output
raw_output = """{
"Original requirements": "Write a 2048 game",
"search Information": "",
@ -36,6 +34,8 @@ def test_repair_case_sensitivity():
def test_repair_special_character_missing():
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output
raw_output = """[CONTENT]
"Anything UNCLEAR": "No unclear requirements or information."
[CONTENT]"""
@ -66,11 +66,12 @@ def test_repair_special_character_missing():
target_output = '[CONTENT] {"a": "b"} [/CONTENT]'
output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"])
print("output\n", output)
assert output == target_output
def test_required_key_pair_missing():
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output
raw_output = '[CONTENT] {"a": "b"}'
target_output = '[CONTENT] {"a": "b"}\n[/CONTENT]'
@ -107,6 +108,8 @@ xxx
def test_repair_json_format():
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
raw_output = "{ xxx }]"
target_output = "{ xxx }"
@ -127,6 +130,8 @@ def test_repair_json_format():
def test_repair_invalid_json():
from metagpt.utils.repair_llm_raw_output import repair_invalid_json
raw_output = """{
"key": "value"
},
@ -169,6 +174,8 @@ value
def test_retry_parse_json_text():
from metagpt.utils.repair_llm_raw_output import retry_parse_json_text
invalid_json_text = """{
"Original Requirements": "Create a 2048 game",
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis"
@ -205,6 +212,7 @@ def test_extract_content_from_output():
xxx [CONTENT] xxx [CONTENT] xxxx [/CONTENT]
xxx [CONTENT] xxxx [/CONTENT] xxx [CONTENT][/CONTENT] xxx [CONTENT][/CONTENT] # target pair is the last one
"""
from metagpt.utils.repair_llm_raw_output import extract_content_from_output
output = (
'Sure! Here is the properly formatted JSON output based on the given context:\n\n[CONTENT]\n{\n"'

View file

@ -9,20 +9,36 @@ import uuid
from pathlib import Path
import aiofiles
import mock
import pytest
from metagpt.config import CONFIG
from metagpt.utils.common import aread
from metagpt.utils.s3 import S3
@pytest.mark.asyncio
async def test_s3():
@mock.patch("aioboto3.Session")
async def test_s3(mock_session_class):
# Set up the mock response
data = await aread(__file__, "utf-8")
mock_session_object = mock.Mock()
reader_mock = mock.AsyncMock()
reader_mock.read.side_effect = [data.encode("utf-8"), b"", data.encode("utf-8")]
type(reader_mock).url = mock.PropertyMock(return_value="https://mock")
mock_client = mock.AsyncMock()
mock_client.put_object.return_value = None
mock_client.get_object.return_value = {"Body": reader_mock}
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_session_object.client.return_value = mock_client
mock_session_class.return_value = mock_session_object
# Prerequisites
assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
# assert CONFIG.S3_SECURE: true # true/false
assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
# assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
# assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
# assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
# assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
conn = S3()
assert conn.is_valid
@ -42,6 +58,7 @@ async def test_s3():
assert "http" in res
# Mock session env
type(reader_mock).url = mock.PropertyMock(return_value="")
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
@ -54,6 +71,8 @@ async def test_s3():
finally:
CONFIG.set_context(old_options)
await reader.close()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,13 @@
#!/usr/bin/env python3
# _*_ coding: utf-8 _*_
import pytest
def test_nodeid(request):
print(request.node.nodeid)
assert request.node.nodeid
if __name__ == "__main__":
pytest.main([__file__, "-s"])