mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-30 14:35:17 +02:00
Merge branch 'main' into llm_mock
This commit is contained in:
commit
271ecc30a2
23 changed files with 575 additions and 114 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -172,6 +172,8 @@ tests/metagpt/utils/file_repo_git
|
|||
*.png
|
||||
htmlcov
|
||||
htmlcov.*
|
||||
*.dot
|
||||
*.pkl
|
||||
*-structure.csv
|
||||
*-structure.json
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ paths:
|
|||
post:
|
||||
summary: Generate greeting
|
||||
description: Generates a greeting message.
|
||||
operationId: hello.post_greeting
|
||||
operationId: openapi_v3_hello.post_greeting
|
||||
responses:
|
||||
200:
|
||||
description: greeting response
|
||||
|
|
|
|||
|
|
@ -9,60 +9,187 @@
|
|||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
|
||||
from metagpt.const import (
|
||||
AGGREGATION,
|
||||
COMPOSITION,
|
||||
DATA_API_DESIGN_FILE_REPO,
|
||||
GENERALIZATION,
|
||||
GRAPH_REPO_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import RepoParser
|
||||
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
|
||||
from metagpt.utils.common import split_namespace
|
||||
from metagpt.utils.di_graph_repository import DiGraphRepository
|
||||
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
|
||||
|
||||
|
||||
class RebuildClassView(Action):
|
||||
def __init__(self, name="", context=None, llm=None):
|
||||
super().__init__(name=name, context=context, llm=llm)
|
||||
|
||||
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
|
||||
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
|
||||
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
|
||||
repo_parser = RepoParser(base_directory=self.context)
|
||||
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
|
||||
repo_parser = RepoParser(base_directory=Path(self.context))
|
||||
class_views, relationship_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
|
||||
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
|
||||
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
|
||||
symbols = repo_parser.generate_symbols() # use ast
|
||||
for file_info in symbols:
|
||||
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
|
||||
await self._create_mermaid_class_view(graph_db=graph_db)
|
||||
await self._save(graph_db=graph_db)
|
||||
await self._create_mermaid_class_views(graph_db=graph_db)
|
||||
await graph_db.save()
|
||||
|
||||
async def _create_mermaid_class_view(self, graph_db):
|
||||
pass
|
||||
# dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
|
||||
# if not dataset:
|
||||
# logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
|
||||
# return
|
||||
# code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
|
||||
# src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
|
||||
# code_type = ""
|
||||
# dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
|
||||
# for spo in dataset:
|
||||
# if spo.object_ in ["javascript", "python"]:
|
||||
# code_type = spo.object_
|
||||
# break
|
||||
async def _create_mermaid_class_views(self, graph_db):
|
||||
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
pathname = path / CONFIG.git_repo.workdir.name
|
||||
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
|
||||
content = "classDiagram\n"
|
||||
logger.debug(content)
|
||||
await writer.write(content)
|
||||
# class names
|
||||
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
|
||||
class_distinct = set()
|
||||
relationship_distinct = set()
|
||||
for r in rows:
|
||||
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
|
||||
for r in rows:
|
||||
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
|
||||
|
||||
# try:
|
||||
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
|
||||
# class_view = node.instruct_content.model_dump()["Class View"]
|
||||
# except Exception as e:
|
||||
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
|
||||
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
|
||||
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
|
||||
@staticmethod
|
||||
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
|
||||
fields = split_namespace(ns_class_name)
|
||||
if len(fields) > 2:
|
||||
# Ignore sub-class
|
||||
return
|
||||
|
||||
async def _save(self, graph_db):
|
||||
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
|
||||
dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
|
||||
all_class_view = []
|
||||
for spo in dataset:
|
||||
title = f"---\ntitle: {spo.subject}\n---\n"
|
||||
filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
|
||||
await class_view_file_repo.save(filename=filename, content=title + spo.object_)
|
||||
all_class_view.append(spo.object_)
|
||||
await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))
|
||||
class_view = ClassView(name=fields[1])
|
||||
rows = await graph_db.select(subject=ns_class_name)
|
||||
for r in rows:
|
||||
name = split_namespace(r.object_)[-1]
|
||||
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
|
||||
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
|
||||
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
|
||||
attribute = ClassAttribute(
|
||||
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
|
||||
)
|
||||
class_view.attributes.append(attribute)
|
||||
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
|
||||
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
|
||||
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
|
||||
class_view.methods.append(method)
|
||||
|
||||
# update graph db
|
||||
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
|
||||
|
||||
content = class_view.get_mermaid(align=1)
|
||||
logger.debug(content)
|
||||
await file_writer.write(content)
|
||||
distinct.add(ns_class_name)
|
||||
|
||||
@staticmethod
|
||||
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
|
||||
s_fields = split_namespace(ns_class_name)
|
||||
if len(s_fields) > 2:
|
||||
# Ignore sub-class
|
||||
return
|
||||
|
||||
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
|
||||
mappings = {
|
||||
GENERALIZATION: " <|-- ",
|
||||
COMPOSITION: " *-- ",
|
||||
AGGREGATION: " o-- ",
|
||||
}
|
||||
content = ""
|
||||
for p, v in predicates.items():
|
||||
rows = await graph_db.select(subject=ns_class_name, predicate=p)
|
||||
for r in rows:
|
||||
o_fields = split_namespace(r.object_)
|
||||
if len(o_fields) > 2:
|
||||
# Ignore sub-class
|
||||
continue
|
||||
relationship = mappings.get(v, " .. ")
|
||||
link = f"{o_fields[1]}{relationship}{s_fields[1]}"
|
||||
distinct.add(link)
|
||||
content += f"\t{link}\n"
|
||||
|
||||
if content:
|
||||
logger.debug(content)
|
||||
await file_writer.write(content)
|
||||
|
||||
@staticmethod
|
||||
def _parse_name(name: str, language="python"):
|
||||
pattern = re.compile(r"<I>(.*?)<\/I>")
|
||||
result = re.search(pattern, name)
|
||||
|
||||
abstraction = ""
|
||||
if result:
|
||||
name = result.group(1)
|
||||
abstraction = "*"
|
||||
if name.startswith("__"):
|
||||
visibility = "-"
|
||||
elif name.startswith("_"):
|
||||
visibility = "#"
|
||||
else:
|
||||
visibility = "+"
|
||||
return name, visibility, abstraction
|
||||
|
||||
@staticmethod
|
||||
async def _parse_variable_type(ns_name, graph_db) -> str:
|
||||
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
|
||||
if not rows:
|
||||
return ""
|
||||
vals = rows[0].object_.replace("'", "").split(":")
|
||||
if len(vals) == 1:
|
||||
return ""
|
||||
val = vals[-1].strip()
|
||||
return "" if val == "NoneType" else val + " "
|
||||
|
||||
@staticmethod
|
||||
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
|
||||
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
|
||||
if not rows:
|
||||
return
|
||||
info = rows[0].object_.replace("'", "")
|
||||
|
||||
fs_tag = "("
|
||||
ix = info.find(fs_tag)
|
||||
fe_tag = "):"
|
||||
eix = info.rfind(fe_tag)
|
||||
if eix < 0:
|
||||
fe_tag = ")"
|
||||
eix = info.rfind(fe_tag)
|
||||
args_info = info[ix + len(fs_tag) : eix].strip()
|
||||
method.return_type = info[eix + len(fe_tag) :].strip()
|
||||
if method.return_type == "None":
|
||||
method.return_type = ""
|
||||
if "(" in method.return_type:
|
||||
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
|
||||
|
||||
# parse args
|
||||
if not args_info:
|
||||
return
|
||||
splitter_ixs = []
|
||||
cost = 0
|
||||
for i in range(len(args_info)):
|
||||
if args_info[i] == "[":
|
||||
cost += 1
|
||||
elif args_info[i] == "]":
|
||||
cost -= 1
|
||||
if args_info[i] == "," and cost == 0:
|
||||
splitter_ixs.append(i)
|
||||
splitter_ixs.append(len(args_info))
|
||||
args = []
|
||||
ix = 0
|
||||
for eix in splitter_ixs:
|
||||
args.append(args_info[ix:eix])
|
||||
ix = eix + 1
|
||||
for arg in args:
|
||||
parts = arg.strip().split(":")
|
||||
if len(parts) == 1:
|
||||
method.args.append(ClassAttribute(name=parts[0].strip()))
|
||||
continue
|
||||
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class WriteCode(Action):
|
|||
if not coding_context.code_doc:
|
||||
# avoid root_path pydantic ValidationError if use WriteCode alone
|
||||
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
|
||||
coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
|
||||
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
|
||||
coding_context.code_doc.content = code
|
||||
return coding_context
|
||||
|
||||
|
|
|
|||
|
|
@ -126,3 +126,8 @@ LLM_API_TIMEOUT = 300
|
|||
|
||||
# Message id
|
||||
IGNORED_MESSAGE_ID = "0"
|
||||
|
||||
# Class Relationship
|
||||
GENERALIZATION = "Generalize"
|
||||
COMPOSITION = "Composite"
|
||||
AGGREGATION = "Aggregate"
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@ import json
|
|||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiofiles
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_str, aread
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
|
|
@ -46,6 +46,13 @@ class ClassInfo(BaseModel):
|
|||
methods: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ClassRelationship(BaseModel):
|
||||
src: str = ""
|
||||
dest: str = ""
|
||||
relationship: str = ""
|
||||
label: Optional[str] = None
|
||||
|
||||
|
||||
class RepoParser(BaseModel):
|
||||
base_directory: Path = Field(default=None)
|
||||
|
||||
|
|
@ -60,7 +67,8 @@ class RepoParser(BaseModel):
|
|||
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
|
||||
for node in tree:
|
||||
info = RepoParser.node_to_str(node)
|
||||
file_info.page_info.append(info)
|
||||
if info:
|
||||
file_info.page_info.append(info)
|
||||
if isinstance(node, ast.ClassDef):
|
||||
class_methods = [m.name for m in node.body if is_func(m)]
|
||||
file_info.classes.append({"name": node.name, "methods": class_methods})
|
||||
|
|
@ -110,7 +118,9 @@ class RepoParser(BaseModel):
|
|||
return output_path
|
||||
|
||||
@staticmethod
|
||||
def node_to_str(node) -> (int, int, str, str | Tuple):
|
||||
def node_to_str(node) -> CodeBlockInfo | None:
|
||||
if isinstance(node, ast.Try):
|
||||
return None
|
||||
if any_to_str(node) == any_to_str(ast.Expr):
|
||||
return CodeBlockInfo(
|
||||
lineno=node.lineno,
|
||||
|
|
@ -129,6 +139,7 @@ class RepoParser(BaseModel):
|
|||
},
|
||||
any_to_str(ast.If): RepoParser._parse_if,
|
||||
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
|
||||
any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target),
|
||||
}
|
||||
func = mappings.get(any_to_str(node))
|
||||
if func:
|
||||
|
|
@ -143,7 +154,8 @@ class RepoParser(BaseModel):
|
|||
else:
|
||||
raise NotImplementedError(f"Not implement:{val}")
|
||||
return code_block
|
||||
raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
|
||||
logger.warning(f"Unsupported code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_expr(node) -> List:
|
||||
|
|
@ -164,22 +176,51 @@ class RepoParser(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _parse_if(n):
|
||||
tokens = [RepoParser._parse_variable(n.test.left)]
|
||||
for item in n.test.comparators:
|
||||
tokens.append(RepoParser._parse_variable(item))
|
||||
tokens = []
|
||||
try:
|
||||
if isinstance(n.test, ast.BoolOp):
|
||||
tokens = []
|
||||
for v in n.test.values:
|
||||
tokens.extend(RepoParser._parse_if_compare(v))
|
||||
return tokens
|
||||
if isinstance(n.test, ast.Compare):
|
||||
v = RepoParser._parse_variable(n.test.left)
|
||||
if v:
|
||||
tokens.append(v)
|
||||
for item in n.test.comparators:
|
||||
v = RepoParser._parse_variable(item)
|
||||
if v:
|
||||
tokens.append(v)
|
||||
return tokens
|
||||
except Exception as e:
|
||||
logger.warning(f"Unsupported if: {n}, err:{e}")
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def _parse_if_compare(n):
|
||||
if hasattr(n, "left"):
|
||||
return RepoParser._parse_variable(n.left)
|
||||
else:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _parse_variable(node):
|
||||
funcs = {
|
||||
any_to_str(ast.Constant): lambda x: x.value,
|
||||
any_to_str(ast.Name): lambda x: x.id,
|
||||
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
|
||||
}
|
||||
func = funcs.get(any_to_str(node))
|
||||
if not func:
|
||||
raise NotImplementedError(f"Not implement:{node}")
|
||||
return func(node)
|
||||
try:
|
||||
funcs = {
|
||||
any_to_str(ast.Constant): lambda x: x.value,
|
||||
any_to_str(ast.Name): lambda x: x.id,
|
||||
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}"
|
||||
if hasattr(x.value, "id")
|
||||
else f"{x.attr}",
|
||||
any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
|
||||
any_to_str(ast.Tuple): lambda x: "",
|
||||
}
|
||||
func = funcs.get(any_to_str(node))
|
||||
if not func:
|
||||
raise NotImplementedError(f"Not implement:{node}")
|
||||
return func(node)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unsupported variable:{node}, err:{e}")
|
||||
|
||||
@staticmethod
|
||||
def _parse_assign(node):
|
||||
|
|
@ -197,18 +238,21 @@ class RepoParser(BaseModel):
|
|||
raise ValueError(f"{result}")
|
||||
class_view_pathname = path / "classes.dot"
|
||||
class_views = await self._parse_classes(class_view_pathname)
|
||||
relationship_views = await self._parse_class_relationships(class_view_pathname)
|
||||
packages_pathname = path / "packages.dot"
|
||||
class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
|
||||
class_views, relationship_views = RepoParser._repair_namespaces(
|
||||
class_views=class_views, relationship_views=relationship_views, path=path
|
||||
)
|
||||
class_view_pathname.unlink(missing_ok=True)
|
||||
packages_pathname.unlink(missing_ok=True)
|
||||
return class_views
|
||||
return class_views, relationship_views
|
||||
|
||||
async def _parse_classes(self, class_view_pathname):
|
||||
class_views = []
|
||||
if not class_view_pathname.exists():
|
||||
return class_views
|
||||
async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
|
||||
lines = await reader.readlines()
|
||||
data = await aread(filename=class_view_pathname, encoding="utf-8")
|
||||
lines = data.split("\n")
|
||||
for line in lines:
|
||||
package_name, info = RepoParser._split_class_line(line)
|
||||
if not package_name:
|
||||
|
|
@ -229,6 +273,19 @@ class RepoParser(BaseModel):
|
|||
class_views.append(class_info)
|
||||
return class_views
|
||||
|
||||
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
|
||||
relationship_views = []
|
||||
if not class_view_pathname.exists():
|
||||
return relationship_views
|
||||
data = await aread(filename=class_view_pathname, encoding="utf-8")
|
||||
lines = data.split("\n")
|
||||
for line in lines:
|
||||
relationship = RepoParser._split_relationship_line(line)
|
||||
if not relationship:
|
||||
continue
|
||||
relationship_views.append(relationship)
|
||||
return relationship_views
|
||||
|
||||
@staticmethod
|
||||
def _split_class_line(line):
|
||||
part_splitor = '" ['
|
||||
|
|
@ -247,6 +304,40 @@ class RepoParser(BaseModel):
|
|||
info = re.sub(r"<br[^>]*>", "\n", info)
|
||||
return class_name, info
|
||||
|
||||
@staticmethod
|
||||
def _split_relationship_line(line):
|
||||
splitters = [" -> ", " [", "];"]
|
||||
idxs = []
|
||||
for tag in splitters:
|
||||
if tag not in line:
|
||||
return None
|
||||
idxs.append(line.find(tag))
|
||||
ret = ClassRelationship()
|
||||
ret.src = line[0 : idxs[0]].strip('"')
|
||||
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
|
||||
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
|
||||
mappings = {
|
||||
'arrowhead="empty"': GENERALIZATION,
|
||||
'arrowhead="diamond"': COMPOSITION,
|
||||
'arrowhead="odiamond"': AGGREGATION,
|
||||
}
|
||||
for k, v in mappings.items():
|
||||
if k in properties:
|
||||
ret.relationship = v
|
||||
if v != GENERALIZATION:
|
||||
ret.label = RepoParser._get_label(properties)
|
||||
break
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _get_label(line):
|
||||
tag = 'label="'
|
||||
if tag not in line:
|
||||
return ""
|
||||
ix = line.find(tag)
|
||||
eix = line.find('"', ix + len(tag))
|
||||
return line[ix + len(tag) : eix]
|
||||
|
||||
@staticmethod
|
||||
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
|
||||
mappings = {
|
||||
|
|
@ -271,7 +362,9 @@ class RepoParser(BaseModel):
|
|||
return mappings
|
||||
|
||||
@staticmethod
|
||||
def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
|
||||
def _repair_namespaces(
|
||||
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
|
||||
) -> (List[ClassInfo], List[ClassRelationship]):
|
||||
if not class_views:
|
||||
return []
|
||||
c = class_views[0]
|
||||
|
|
@ -290,7 +383,12 @@ class RepoParser(BaseModel):
|
|||
|
||||
for c in class_views:
|
||||
c.package = RepoParser._repair_ns(c.package, new_mappings)
|
||||
return class_views
|
||||
for i in range(len(relationship_views)):
|
||||
v = relationship_views[i]
|
||||
v.src = RepoParser._repair_ns(v.src, new_mappings)
|
||||
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
|
||||
relationship_views[i] = v
|
||||
return class_views, relationship_views
|
||||
|
||||
@staticmethod
|
||||
def _repair_ns(package, mappings):
|
||||
|
|
|
|||
|
|
@ -451,3 +451,63 @@ class CodeSummarizeContext(BaseModel):
|
|||
|
||||
class BugFixContext(BaseContext):
|
||||
filename: str = ""
|
||||
|
||||
|
||||
# mermaid class view
|
||||
class ClassMeta(BaseModel):
|
||||
name: str = ""
|
||||
abstraction: bool = False
|
||||
static: bool = False
|
||||
visibility: str = ""
|
||||
|
||||
|
||||
class ClassAttribute(ClassMeta):
|
||||
value_type: str = ""
|
||||
default_value: str = ""
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + self.visibility
|
||||
if self.value_type:
|
||||
content += self.value_type + " "
|
||||
content += self.name
|
||||
if self.default_value:
|
||||
content += "="
|
||||
if self.value_type not in ["str", "string", "String"]:
|
||||
content += self.default_value
|
||||
else:
|
||||
content += '"' + self.default_value.replace('"', "") + '"'
|
||||
if self.abstraction:
|
||||
content += "*"
|
||||
if self.static:
|
||||
content += "$"
|
||||
return content
|
||||
|
||||
|
||||
class ClassMethod(ClassMeta):
|
||||
args: List[ClassAttribute] = Field(default_factory=list)
|
||||
return_type: str = ""
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + self.visibility
|
||||
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
|
||||
if self.return_type:
|
||||
content += ":" + self.return_type
|
||||
if self.abstraction:
|
||||
content += "*"
|
||||
if self.static:
|
||||
content += "$"
|
||||
return content
|
||||
|
||||
|
||||
class ClassView(ClassMeta):
|
||||
attributes: List[ClassAttribute] = Field(default_factory=list)
|
||||
methods: List[ClassMethod] = Field(default_factory=list)
|
||||
|
||||
def get_mermaid(self, align=1) -> str:
|
||||
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
|
||||
for v in self.attributes:
|
||||
content += v.get_mermaid(align=align + 1) + "\n"
|
||||
for v in self.methods:
|
||||
content += v.get_mermaid(align=align + 1) + "\n"
|
||||
content += "".join(["\t" for i in range(align)]) + "}\n"
|
||||
return content
|
||||
|
|
|
|||
|
|
@ -5,6 +5,12 @@
|
|||
@Author : mashenquan
|
||||
@File : metagpt_oas3_api_svc.py
|
||||
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
|
||||
|
||||
curl -X 'POST' \
|
||||
'http://localhost:8080/openapi/greeting/dave' \
|
||||
-H 'accept: text/plain' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}'
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
|
@ -15,7 +21,7 @@ import connexion
|
|||
def oas_http_svc():
|
||||
"""Start the OAS 3.0 OpenAPI HTTP service"""
|
||||
print("http://localhost:8080/oas3/ui/")
|
||||
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
|
||||
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
|
||||
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
|
||||
app.add_api("metagpt_oas3_api.yaml")
|
||||
app.add_api("openapi.yaml")
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ async def post_greeting(name: str) -> str:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
|
||||
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
|
||||
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
|
||||
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
|
||||
app.run(port=8082)
|
||||
|
|
|
|||
|
|
@ -407,6 +407,10 @@ def concat_namespace(*args) -> str:
|
|||
return ":".join(str(value) for value in args)
|
||||
|
||||
|
||||
def split_namespace(ns_class_name: str) -> List[str]:
|
||||
return ns_class_name.split(":")
|
||||
|
||||
|
||||
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
|
||||
"""
|
||||
Generates a logging function to be used after a call is retried.
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@ import json
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import aiofiles
|
||||
import networkx
|
||||
|
||||
from metagpt.utils.common import aread, awrite
|
||||
from metagpt.utils.graph_repository import SPO, GraphRepository
|
||||
|
||||
|
||||
|
|
@ -55,12 +55,10 @@ class DiGraphRepository(GraphRepository):
|
|||
if not path.exists():
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
pathname = Path(path) / self.name
|
||||
async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
|
||||
await writer.write(data)
|
||||
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
|
||||
|
||||
async def load(self, pathname: str | Path):
|
||||
async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
|
||||
data = await reader.read(-1)
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
m = json.loads(data)
|
||||
self._repo = networkx.node_link_graph(m)
|
||||
|
||||
|
|
|
|||
|
|
@ -138,6 +138,8 @@ class FileRepository:
|
|||
files = self._git_repo.changed_files
|
||||
relative_files = {}
|
||||
for p, ct in files.items():
|
||||
if ct.value == "D": # deleted
|
||||
continue
|
||||
try:
|
||||
rf = Path(p).relative_to(self._relative_path)
|
||||
except ValueError:
|
||||
|
|
|
|||
|
|
@ -13,19 +13,25 @@ from typing import List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.repo_parser import ClassInfo, RepoFileInfo
|
||||
from metagpt.logs import logger
|
||||
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
|
||||
from metagpt.utils.common import concat_namespace
|
||||
|
||||
|
||||
class GraphKeyword:
|
||||
IS = "is"
|
||||
OF = "Of"
|
||||
ON = "On"
|
||||
CLASS = "class"
|
||||
FUNCTION = "function"
|
||||
HAS_FUNCTION = "has_function"
|
||||
SOURCE_CODE = "source_code"
|
||||
NULL = "<null>"
|
||||
GLOBAL_VARIABLE = "global_variable"
|
||||
CLASS_FUNCTION = "class_function"
|
||||
CLASS_PROPERTY = "class_property"
|
||||
HAS_CLASS_FUNCTION = "has_class_function"
|
||||
HAS_CLASS_PROPERTY = "has_class_property"
|
||||
HAS_CLASS = "has_class"
|
||||
HAS_PAGE_INFO = "has_page_info"
|
||||
HAS_CLASS_VIEW = "has_class_view"
|
||||
|
|
@ -73,11 +79,13 @@ class GraphRepository(ABC):
|
|||
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
|
||||
for c in file_info.classes:
|
||||
class_name = c.get("name", "")
|
||||
# file -> class
|
||||
await graph_db.insert(
|
||||
subject=file_info.file,
|
||||
predicate=GraphKeyword.HAS_CLASS,
|
||||
object_=concat_namespace(file_info.file, class_name),
|
||||
)
|
||||
# class detail
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name),
|
||||
predicate=GraphKeyword.IS,
|
||||
|
|
@ -85,12 +93,22 @@ class GraphRepository(ABC):
|
|||
)
|
||||
methods = c.get("methods", [])
|
||||
for fn in methods:
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name),
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
object_=concat_namespace(file_info.file, class_name, fn),
|
||||
)
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, class_name, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
object_=GraphKeyword.CLASS_FUNCTION,
|
||||
)
|
||||
for f in file_info.functions:
|
||||
# file -> function
|
||||
await graph_db.insert(
|
||||
subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
|
||||
)
|
||||
# function detail
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
|
||||
)
|
||||
|
|
@ -105,13 +123,13 @@ class GraphRepository(ABC):
|
|||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, *code_block.tokens),
|
||||
predicate=GraphKeyword.HAS_PAGE_INFO,
|
||||
object_=code_block.json(ensure_ascii=False),
|
||||
object_=code_block.model_dump_json(),
|
||||
)
|
||||
for k, v in code_block.properties.items():
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(file_info.file, k, v),
|
||||
predicate=GraphKeyword.HAS_PAGE_INFO,
|
||||
object_=code_block.json(ensure_ascii=False),
|
||||
object_=code_block.model_dump_json(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -129,6 +147,13 @@ class GraphRepository(ABC):
|
|||
object_=GraphKeyword.CLASS,
|
||||
)
|
||||
for vn, vt in c.attributes.items():
|
||||
# class -> property
|
||||
await graph_db.insert(
|
||||
subject=c.package,
|
||||
predicate=GraphKeyword.HAS_CLASS_PROPERTY,
|
||||
object_=concat_namespace(c.package, vn),
|
||||
)
|
||||
# property detail
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, vn),
|
||||
predicate=GraphKeyword.IS,
|
||||
|
|
@ -138,6 +163,15 @@ class GraphRepository(ABC):
|
|||
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
|
||||
)
|
||||
for fn, desc in c.methods.items():
|
||||
if "</I>" in desc and "<I>" not in desc:
|
||||
logger.error(desc)
|
||||
# class -> function
|
||||
await graph_db.insert(
|
||||
subject=c.package,
|
||||
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
|
||||
object_=concat_namespace(c.package, fn),
|
||||
)
|
||||
# function detail
|
||||
await graph_db.insert(
|
||||
subject=concat_namespace(c.package, fn),
|
||||
predicate=GraphKeyword.IS,
|
||||
|
|
@ -148,3 +182,19 @@ class GraphRepository(ABC):
|
|||
predicate=GraphKeyword.HAS_ARGS_DESC,
|
||||
object_=desc,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def update_graph_db_with_class_relationship_views(
|
||||
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
|
||||
):
|
||||
for r in relationship_views:
|
||||
await graph_db.insert(
|
||||
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
|
||||
)
|
||||
if not r.label:
|
||||
continue
|
||||
await graph_db.insert(
|
||||
subject=r.src,
|
||||
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
|
||||
object_=concat_namespace(r.dest, r.label),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -123,9 +124,9 @@ def loguru_caplog(caplog):
|
|||
|
||||
|
||||
# init & dispose git repo
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_teardown_git_repo(request):
|
||||
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
|
||||
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
|
||||
CONFIG.git_reinit = True
|
||||
|
||||
# Destroy git repo at the end of the test session.
|
||||
|
|
|
|||
|
|
@ -11,13 +11,19 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.rebuild_class_view import RebuildClassView
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild():
|
||||
action = RebuildClassView(name="RedBean", context=Path(__file__).parent.parent, llm=LLM())
|
||||
action = RebuildClassView(
|
||||
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
|
||||
)
|
||||
await action.run()
|
||||
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
|
||||
assert graph_file_repo.changed_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ from metagpt.config import CONFIG
|
|||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
ClassAttribute,
|
||||
ClassMethod,
|
||||
ClassView,
|
||||
CodeSummarizeContext,
|
||||
Document,
|
||||
Message,
|
||||
|
|
@ -156,5 +159,30 @@ def test_CodeSummarizeContext(file_list, want):
|
|||
assert want in m
|
||||
|
||||
|
||||
def test_class_view():
|
||||
attr_a = ClassAttribute(name="a", value_type="int", default_value="0", visibility="+", abstraction=True)
|
||||
assert attr_a.get_mermaid(align=1) == "\t+int a=0*"
|
||||
attr_b = ClassAttribute(name="b", value_type="str", default_value="0", visibility="#", static=True)
|
||||
assert attr_b.get_mermaid(align=0) == '#str b="0"$'
|
||||
class_view = ClassView(name="A")
|
||||
class_view.attributes = [attr_a, attr_b]
|
||||
|
||||
method_a = ClassMethod(name="run", visibility="+", abstraction=True)
|
||||
assert method_a.get_mermaid(align=1) == "\t+run()*"
|
||||
method_b = ClassMethod(
|
||||
name="_test",
|
||||
visibility="#",
|
||||
static=True,
|
||||
args=[ClassAttribute(name="a", value_type="str"), ClassAttribute(name="b", value_type="int")],
|
||||
return_type="str",
|
||||
)
|
||||
assert method_b.get_mermaid(align=0) == "#_test(str a,int b):str$"
|
||||
class_view.methods = [method_a, method_b]
|
||||
assert (
|
||||
class_view.get_mermaid(align=0)
|
||||
== 'class A{\n\t+int a=0*\n\t#str b="0"$\n\t+run()*\n\t#_test(str a,int b):str$\n}\n'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -24,13 +24,14 @@ async def test_oas2_svc():
|
|||
process = subprocess.Popen(["python", str(script_pathname)], cwd=str(workdir), env=env)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
url = "http://localhost:8080/openapi/greeting/dave"
|
||||
headers = {"accept": "text/plain", "Content-Type": "application/json"}
|
||||
data = {}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
assert response.text == "Hello dave\n"
|
||||
|
||||
process.terminate()
|
||||
try:
|
||||
url = "http://localhost:8080/openapi/greeting/dave"
|
||||
headers = {"accept": "text/plain", "Content-Type": "application/json"}
|
||||
data = {}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
assert response.text == "Hello dave\n"
|
||||
finally:
|
||||
process.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/12/26
|
||||
@Author : mashenquan
|
||||
@File : test_hello.py
|
||||
@File : test_openapi_v3_hello.py
|
||||
"""
|
||||
import asyncio
|
||||
import subprocess
|
||||
|
|
@ -24,13 +24,14 @@ async def test_hello():
|
|||
process = subprocess.Popen(["python", str(script_pathname)], cwd=workdir, env=env)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
url = "http://localhost:8082/openapi/greeting/dave"
|
||||
headers = {"accept": "text/plain", "Content-Type": "application/json"}
|
||||
data = {}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
assert response.text == "Hello dave\n"
|
||||
|
||||
process.terminate()
|
||||
try:
|
||||
url = "http://localhost:8082/openapi/greeting/dave"
|
||||
headers = {"accept": "text/plain", "Content-Type": "application/json"}
|
||||
data = {}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
assert response.text == "Hello dave\n"
|
||||
finally:
|
||||
process.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -36,6 +36,7 @@ from metagpt.utils.common import (
|
|||
read_file_block,
|
||||
read_json_file,
|
||||
require_python_version,
|
||||
split_namespace,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -163,6 +164,23 @@ class TestGetProjectRoot:
|
|||
assert concat_namespace("a", "b", "c", "e") == "a:b:c:e"
|
||||
assert concat_namespace("a", "b", "c", "e", "f") == "a:b:c:e:f"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("val", "want"),
|
||||
[
|
||||
(
|
||||
"tests/metagpt/test_role.py:test_react:Input:subscription",
|
||||
["tests/metagpt/test_role.py", "test_react", "Input", "subscription"],
|
||||
),
|
||||
(
|
||||
"tests/metagpt/test_role.py:test_react:Input:goal",
|
||||
["tests/metagpt/test_role.py", "test_react", "Input", "goal"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_namespace(self, val, want):
|
||||
res = split_namespace(val)
|
||||
assert res == want
|
||||
|
||||
def test_read_json_file(self):
|
||||
assert read_json_file(str(Path(__file__).parent / "../../data/ut_writer/yft_swaggerApi.json"), encoding="utf-8")
|
||||
with pytest.raises(FileNotFoundError):
|
||||
|
|
|
|||
|
|
@ -6,20 +6,34 @@
|
|||
@File : test_redis.py
|
||||
"""
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.redis import Redis
|
||||
|
||||
|
||||
async def async_mock_from_url(*args, **kwargs):
|
||||
mock_client = mock.AsyncMock()
|
||||
mock_client.set.return_value = None
|
||||
mock_client.get.side_effect = [b"test", b""]
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis():
|
||||
@mock.patch("aioredis.from_url", return_value=async_mock_from_url())
|
||||
async def test_redis(mock_from_url):
|
||||
# Mock
|
||||
# mock_client = mock.AsyncMock()
|
||||
# mock_client.set.return_value=None
|
||||
# mock_client.get.side_effect = [b'test', b'']
|
||||
# mock_from_url.return_value = mock_client
|
||||
|
||||
# Prerequisites
|
||||
assert CONFIG.REDIS_HOST and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
|
||||
assert CONFIG.REDIS_PORT and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
|
||||
# assert CONFIG.REDIS_USER
|
||||
assert CONFIG.REDIS_PASSWORD is not None and CONFIG.REDIS_PASSWORD != "YOUR_REDIS_PASSWORD"
|
||||
assert CONFIG.REDIS_DB is not None and CONFIG.REDIS_DB != "YOUR_REDIS_DB_INDEX, str, 0-based"
|
||||
CONFIG.REDIS_HOST = "MOCK_REDIS_HOST"
|
||||
CONFIG.REDIS_PORT = "MOCK_REDIS_PORT"
|
||||
CONFIG.REDIS_PASSWORD = "MOCK_REDIS_PASSWORD"
|
||||
CONFIG.REDIS_DB = 0
|
||||
|
||||
conn = Redis()
|
||||
assert not conn.is_valid
|
||||
|
|
|
|||
|
|
@ -2,20 +2,18 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of repair_llm_raw_output
|
||||
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.repair_llm_raw_output import (
|
||||
RepairType,
|
||||
extract_content_from_output,
|
||||
repair_invalid_json,
|
||||
repair_llm_raw_output,
|
||||
retry_parse_json_text,
|
||||
)
|
||||
|
||||
"""
|
||||
CONFIG.repair_llm_output should be True before retry_parse_json_text imported.
|
||||
so we move `from ... impot ...` into each `test_xx` to avoid `Module level import not at top of file` format warning.
|
||||
"""
|
||||
CONFIG.repair_llm_output = True
|
||||
|
||||
|
||||
def test_repair_case_sensitivity():
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output
|
||||
|
||||
raw_output = """{
|
||||
"Original requirements": "Write a 2048 game",
|
||||
"search Information": "",
|
||||
|
|
@ -36,6 +34,8 @@ def test_repair_case_sensitivity():
|
|||
|
||||
|
||||
def test_repair_special_character_missing():
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output
|
||||
|
||||
raw_output = """[CONTENT]
|
||||
"Anything UNCLEAR": "No unclear requirements or information."
|
||||
[CONTENT]"""
|
||||
|
|
@ -66,11 +66,12 @@ def test_repair_special_character_missing():
|
|||
target_output = '[CONTENT] {"a": "b"} [/CONTENT]'
|
||||
|
||||
output = repair_llm_raw_output(output=raw_output, req_keys=["[/CONTENT]"])
|
||||
print("output\n", output)
|
||||
assert output == target_output
|
||||
|
||||
|
||||
def test_required_key_pair_missing():
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output
|
||||
|
||||
raw_output = '[CONTENT] {"a": "b"}'
|
||||
target_output = '[CONTENT] {"a": "b"}\n[/CONTENT]'
|
||||
|
||||
|
|
@ -107,6 +108,8 @@ xxx
|
|||
|
||||
|
||||
def test_repair_json_format():
|
||||
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
|
||||
|
||||
raw_output = "{ xxx }]"
|
||||
target_output = "{ xxx }"
|
||||
|
||||
|
|
@ -127,6 +130,8 @@ def test_repair_json_format():
|
|||
|
||||
|
||||
def test_repair_invalid_json():
|
||||
from metagpt.utils.repair_llm_raw_output import repair_invalid_json
|
||||
|
||||
raw_output = """{
|
||||
"key": "value"
|
||||
},
|
||||
|
|
@ -169,6 +174,8 @@ value
|
|||
|
||||
|
||||
def test_retry_parse_json_text():
|
||||
from metagpt.utils.repair_llm_raw_output import retry_parse_json_text
|
||||
|
||||
invalid_json_text = """{
|
||||
"Original Requirements": "Create a 2048 game",
|
||||
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis"
|
||||
|
|
@ -205,6 +212,7 @@ def test_extract_content_from_output():
|
|||
xxx [CONTENT] xxx [CONTENT] xxxx [/CONTENT]
|
||||
xxx [CONTENT] xxxx [/CONTENT] xxx [CONTENT][/CONTENT] xxx [CONTENT][/CONTENT] # target pair is the last one
|
||||
"""
|
||||
from metagpt.utils.repair_llm_raw_output import extract_content_from_output
|
||||
|
||||
output = (
|
||||
'Sure! Here is the properly formatted JSON output based on the given context:\n\n[CONTENT]\n{\n"'
|
||||
|
|
|
|||
|
|
@ -9,20 +9,36 @@ import uuid
|
|||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3():
|
||||
@mock.patch("aioboto3.Session")
|
||||
async def test_s3(mock_session_class):
|
||||
# Set up the mock response
|
||||
data = await aread(__file__, "utf-8")
|
||||
mock_session_object = mock.Mock()
|
||||
reader_mock = mock.AsyncMock()
|
||||
reader_mock.read.side_effect = [data.encode("utf-8"), b"", data.encode("utf-8")]
|
||||
type(reader_mock).url = mock.PropertyMock(return_value="https://mock")
|
||||
mock_client = mock.AsyncMock()
|
||||
mock_client.put_object.return_value = None
|
||||
mock_client.get_object.return_value = {"Body": reader_mock}
|
||||
mock_client.__aenter__.return_value = mock_client
|
||||
mock_client.__aexit__.return_value = None
|
||||
mock_session_object.client.return_value = mock_client
|
||||
mock_session_class.return_value = mock_session_object
|
||||
|
||||
# Prerequisites
|
||||
assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
|
||||
assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
|
||||
assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
|
||||
# assert CONFIG.S3_SECURE: true # true/false
|
||||
assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
|
||||
# assert CONFIG.S3_ACCESS_KEY and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
|
||||
# assert CONFIG.S3_SECRET_KEY and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
|
||||
# assert CONFIG.S3_ENDPOINT_URL and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
|
||||
# assert CONFIG.S3_BUCKET and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
|
||||
|
||||
conn = S3()
|
||||
assert conn.is_valid
|
||||
|
|
@ -42,6 +58,7 @@ async def test_s3():
|
|||
assert "http" in res
|
||||
|
||||
# Mock session env
|
||||
type(reader_mock).url = mock.PropertyMock(return_value="")
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
|
||||
|
|
@ -54,6 +71,8 @@ async def test_s3():
|
|||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
await reader.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
13
tests/metagpt/utils/test_session.py
Normal file
13
tests/metagpt/utils/test_session.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
#!/usr/bin/env python3
|
||||
# _*_ coding: utf-8 _*_
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_nodeid(request):
|
||||
print(request.node.nodeid)
|
||||
assert request.node.nodeid
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue