mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 20:03:28 +02:00
feat: Action Node + exclude parameter
refactor: awrite
This commit is contained in:
parent
0adabfe53f
commit
8bf7d3186a
7 changed files with 58 additions and 53 deletions
|
|
@ -117,19 +117,20 @@ class ActionNode:
|
|||
obj.add_children(nodes)
|
||||
return obj
|
||||
|
||||
def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""获得子ActionNode的字典,以key索引"""
|
||||
return {k: (v.expected_type, ...) for k, v in self.children.items()}
|
||||
exclude = exclude or []
|
||||
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
|
||||
|
||||
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get self key: type mapping"""
|
||||
return {self.key: (self.expected_type, ...)}
|
||||
|
||||
def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]:
|
||||
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
|
||||
"""get key: type mapping under mode"""
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
return self.get_children_mapping()
|
||||
return self.get_self_mapping()
|
||||
return self.get_children_mapping(exclude=exclude)
|
||||
return {} if exclude and self.key in exclude else self.get_self_mapping()
|
||||
|
||||
@classmethod
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
|
|
@ -154,13 +155,13 @@ class ActionNode:
|
|||
new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
|
||||
return new_class
|
||||
|
||||
def create_children_class(self):
|
||||
def create_children_class(self, exclude=None):
|
||||
"""使用object内有的字段直接生成model_class"""
|
||||
class_name = f"{self.key}_AN"
|
||||
mapping = self.get_children_mapping()
|
||||
mapping = self.get_children_mapping(exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def to_dict(self, format_func=None, mode="auto") -> Dict:
|
||||
def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
|
||||
"""将当前节点与子节点都按照node: format的格式组织成字典"""
|
||||
|
||||
# 如果没有提供格式化函数,使用默认的格式化方式
|
||||
|
|
@ -180,7 +181,10 @@ class ActionNode:
|
|||
return node_dict
|
||||
|
||||
# 遍历子节点并递归调用 to_dict 方法
|
||||
exclude = exclude or []
|
||||
for _, child_node in self.children.items():
|
||||
if child_node.key in exclude:
|
||||
continue
|
||||
node_dict.update(child_node.to_dict(format_func))
|
||||
|
||||
return node_dict
|
||||
|
|
@ -201,25 +205,25 @@ class ActionNode:
|
|||
else: # markdown
|
||||
return f"[{tag}]\n" + text + f"\n[/{tag}]"
|
||||
|
||||
def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str:
|
||||
nodes = self.to_dict(format_func=format_func, mode=mode)
|
||||
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
|
||||
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
|
||||
text = self.compile_to(nodes, schema, kv_sep)
|
||||
return self.tagging(text, schema, tag)
|
||||
|
||||
def compile_instruction(self, schema="markdown", mode="children", tag="") -> str:
|
||||
def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
|
||||
"""compile to raw/json/markdown template with all/root/children nodes"""
|
||||
format_func = lambda i: f"{i.expected_type} # {i.instruction}"
|
||||
return self._compile_f(schema, mode, tag, format_func, kv_sep=": ")
|
||||
return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)
|
||||
|
||||
def compile_example(self, schema="json", mode="children", tag="") -> str:
|
||||
def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
|
||||
"""compile to raw/json/markdown examples with all/root/children nodes"""
|
||||
|
||||
# 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example
|
||||
# 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str
|
||||
format_func = lambda i: i.example
|
||||
return self._compile_f(schema, mode, tag, format_func, kv_sep="\n")
|
||||
return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)
|
||||
|
||||
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str:
|
||||
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str:
|
||||
"""
|
||||
mode: all/root/children
|
||||
mode="children": 编译所有子节点为一个统一模板,包括instruction与example
|
||||
|
|
@ -235,8 +239,8 @@ class ActionNode:
|
|||
|
||||
# FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线",
|
||||
# compile example暂时不支持markdown
|
||||
instruction = self.compile_instruction(schema="markdown", mode=mode)
|
||||
example = self.compile_example(schema=schema, tag=TAG, mode=mode)
|
||||
instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
|
||||
example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
|
||||
# nodes = ", ".join(self.to_dict(mode=mode).keys())
|
||||
constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
|
||||
constraint = "\n".join(constraints)
|
||||
|
|
@ -291,11 +295,11 @@ class ActionNode:
|
|||
def set_context(self, context):
|
||||
self.set_recursive("context", context)
|
||||
|
||||
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode)
|
||||
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
|
||||
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
|
||||
|
||||
if schema != "raw":
|
||||
mapping = self.get_mapping(mode)
|
||||
mapping = self.get_mapping(mode, exclude=exclude)
|
||||
class_name = f"{self.key}_AN"
|
||||
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
|
||||
self.content = content
|
||||
|
|
@ -306,7 +310,7 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout):
|
||||
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
|
|
@ -323,6 +327,7 @@ class ActionNode:
|
|||
- simple: run only once
|
||||
- complex: run each node
|
||||
:param timeout: Timeout for llm invocation.
|
||||
:param exclude: The keys of ActionNode to exclude.
|
||||
:return: self
|
||||
"""
|
||||
self.set_llm(llm)
|
||||
|
|
@ -331,12 +336,14 @@ class ActionNode:
|
|||
schema = self.schema
|
||||
|
||||
if strgy == "simple":
|
||||
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout)
|
||||
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
elif strgy == "complex":
|
||||
# 这里隐式假设了拥有children
|
||||
tmp = {}
|
||||
for _, i in self.children.items():
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout)
|
||||
if exclude and i.key in exclude:
|
||||
continue
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
tmp.update(child.instruct_content.dict())
|
||||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
|
|
|
|||
|
|
@ -23,10 +23,10 @@ from metagpt.actions import Action, ActionOutput
|
|||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
from metagpt.actions.write_prd_an import (
|
||||
PROJECT_NAME,
|
||||
WP_IS_RELATIVE_NODE,
|
||||
WP_ISSUE_TYPE_NODE,
|
||||
WRITE_PRD_NODE,
|
||||
WRITE_PRD_NODE_NO_NAME,
|
||||
)
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
|
|
@ -124,8 +124,8 @@ class WritePRD(Action):
|
|||
# logger.info(rsp)
|
||||
project_name = CONFIG.project_name if CONFIG.project_name else ""
|
||||
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
|
||||
write_prd_node = WRITE_PRD_NODE if not project_name else WRITE_PRD_NODE_NO_NAME
|
||||
node = await write_prd_node.fill(context=context, llm=self.llm) # schema=schema
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
|
||||
await self._rename_workspace(node)
|
||||
return node
|
||||
|
||||
|
|
|
|||
|
|
@ -141,6 +141,7 @@ NODES = [
|
|||
LANGUAGE,
|
||||
PROGRAMMING_LANGUAGE,
|
||||
ORIGINAL_REQUIREMENTS,
|
||||
PROJECT_NAME,
|
||||
PRODUCT_GOALS,
|
||||
USER_STORIES,
|
||||
COMPETITIVE_ANALYSIS,
|
||||
|
|
@ -151,8 +152,7 @@ NODES = [
|
|||
ANYTHING_UNCLEAR,
|
||||
]
|
||||
|
||||
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES + [PROJECT_NAME])
|
||||
WRITE_PRD_NODE_NO_NAME = ActionNode.from_children("WritePRD", NODES)
|
||||
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES)
|
||||
WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON])
|
||||
WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON])
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,8 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
|
||||
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
|
||||
from metagpt.utils.common import awrite
|
||||
|
||||
ICL_SAMPLE = """Interface definition:
|
||||
```text
|
||||
|
|
@ -255,20 +254,14 @@ class UTGenerator:
|
|||
|
||||
return doc
|
||||
|
||||
async def _store(self, data, base, folder, fname):
|
||||
"""Store data in a file."""
|
||||
file_path = self.get_file_path(Path(base) / folder, fname)
|
||||
async with aiofiles.open(file_path, mode="w", encoding="utf-8") as file:
|
||||
await file.write(data)
|
||||
|
||||
async def ask_gpt_and_save(self, question: str, tag: str, fname: str):
|
||||
"""Generate questions and store both questions and answers"""
|
||||
messages = [self.icl_sample, question]
|
||||
result = await self.gpt_msgs_to_code(messages=messages)
|
||||
|
||||
await self._store(question, self.questions_path, tag, f"{fname}.txt")
|
||||
await awrite(Path(self.questions_path) / tag / f"{fname}.txt", question)
|
||||
data = result.get("code", "") if result else ""
|
||||
await self._store(data, self.ut_py_path, tag, f"{fname}.py")
|
||||
await awrite(Path(self.ut_py_path) / tag / f"{fname}.py", data)
|
||||
|
||||
async def _generate_ut(self, tag, paths):
|
||||
"""Process the structure under a data path
|
||||
|
|
@ -291,15 +284,3 @@ class UTGenerator:
|
|||
result = await GPTAPI().aask_code(messages=messages)
|
||||
|
||||
return result
|
||||
|
||||
def get_file_path(self, base: Path, fname: str):
|
||||
"""Save different file paths
|
||||
|
||||
Args:
|
||||
base (str): Path
|
||||
fname (str): File name
|
||||
"""
|
||||
path = Path(base)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
file_path = path / fname
|
||||
return str(file_path)
|
||||
|
|
|
|||
|
|
@ -537,6 +537,14 @@ async def aread(file_path: str) -> str:
|
|||
return content
|
||||
|
||||
|
||||
async def awrite(filename: str | Path, data: str):
|
||||
"""Write file asynchronously."""
|
||||
pathname = Path(filename)
|
||||
pathname.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer:
|
||||
await writer.write(data)
|
||||
|
||||
|
||||
async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
|
||||
if not Path(filename).exists():
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import asyncio
|
|||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.text_to_embedding import text_to_embedding
|
||||
from metagpt.tools.openai_text_to_embedding import ResultEmbedding
|
||||
|
||||
|
||||
async def mock_text_to_embedding():
|
||||
|
|
@ -23,8 +22,7 @@ async def mock_text_to_embedding():
|
|||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
data = await text_to_embedding(seed.input)
|
||||
v = ResultEmbedding(**data)
|
||||
v = await text_to_embedding(seed.input)
|
||||
assert len(v.data) > 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
import importlib
|
||||
import os
|
||||
import platform
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Set
|
||||
|
||||
|
|
@ -25,6 +26,8 @@ from metagpt.utils.common import (
|
|||
OutputParser,
|
||||
any_to_str,
|
||||
any_to_str_set,
|
||||
aread,
|
||||
awrite,
|
||||
check_cmd_exists,
|
||||
concat_namespace,
|
||||
import_class_inst,
|
||||
|
|
@ -170,6 +173,14 @@ class TestGetProjectRoot:
|
|||
async def test_read_file_block(self):
|
||||
assert await read_file_block(filename=__file__, lineno=6, end_lineno=6) == "@File : test_common.py\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write(self):
|
||||
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
|
||||
await awrite(pathname, "ABC")
|
||||
data = await aread(pathname)
|
||||
assert data == "ABC"
|
||||
pathname.unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue