From 8bf7d3186a003052fae6c71c84871cb6dccf8e8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 27 Dec 2023 22:46:39 +0800 Subject: [PATCH] feat: Action Node + exclude parameter refactor: awrite --- metagpt/actions/action_node.py | 53 +++++++++++-------- metagpt/actions/write_prd.py | 6 +-- metagpt/actions/write_prd_an.py | 4 +- metagpt/tools/ut_writer.py | 25 ++------- metagpt/utils/common.py | 8 +++ tests/metagpt/learn/test_text_to_embedding.py | 4 +- tests/metagpt/utils/test_common.py | 11 ++++ 7 files changed, 58 insertions(+), 53 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index b554f15dd..9534e91c5 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -117,19 +117,20 @@ class ActionNode: obj.add_children(nodes) return obj - def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]: + def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" - return {k: (v.expected_type, ...) for k, v in self.children.items()} + exclude = exclude or [] + return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude} def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} - def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]: + def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]: """get key: type mapping under mode""" if mode == "children" or (mode == "auto" and self.children): - return self.get_children_mapping() - return self.get_self_mapping() + return self.get_children_mapping(exclude=exclude) + return {} if exclude and self.key in exclude else self.get_self_mapping() @classmethod def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): @@ -154,13 +155,13 @@ class ActionNode: new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) return new_class - def create_children_class(self): + def create_children_class(self, exclude=None): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" - mapping = self.get_children_mapping() + mapping = self.get_children_mapping(exclude=exclude) return self.create_model_class(class_name, mapping) - def to_dict(self, format_func=None, mode="auto") -> Dict: + def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict: """将当前节点与子节点都按照node: format的格式组织成字典""" # 如果没有提供格式化函数,使用默认的格式化方式 @@ -180,7 +181,10 @@ class ActionNode: return node_dict # 遍历子节点并递归调用 to_dict 方法 + exclude = exclude or [] for _, child_node in self.children.items(): + if child_node.key in exclude: + continue node_dict.update(child_node.to_dict(format_func)) return node_dict @@ -201,25 +205,25 @@ class ActionNode: else: # markdown return f"[{tag}]\n" + text + f"\n[/{tag}]" - def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str: - nodes = self.to_dict(format_func=format_func, mode=mode) + def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str: + nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude) text = self.compile_to(nodes, schema, kv_sep) return self.tagging(text, schema, tag) - def compile_instruction(self, schema="markdown", mode="children", tag="") -> str: + def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str: """compile to raw/json/markdown template with all/root/children nodes""" format_func = lambda i: f"{i.expected_type} # {i.instruction}" - return self._compile_f(schema, mode, tag, format_func, kv_sep=": ") + return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude) - def compile_example(self, schema="json", mode="children", tag="") -> str: + def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str: """compile to raw/json/markdown examples with all/root/children nodes""" # 这里不能使用f-string,因为转译为str后再json.dumps会额外加上引号,无法作为有效的example # 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list,而是str format_func = lambda i: i.example - return self._compile_f(schema, mode, tag, format_func, kv_sep="\n") + return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude) - def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str: + def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str: """ mode: all/root/children mode="children": 编译所有子节点为一个统一模板,包括instruction与example @@ -235,8 +239,8 @@ class ActionNode: # FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线", # compile example暂时不支持markdown - instruction = self.compile_instruction(schema="markdown", mode=mode) - example = self.compile_example(schema=schema, tag=TAG, mode=mode) + instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude) + example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude) # nodes = ", ".join(self.to_dict(mode=mode).keys()) constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT] constraint = "\n".join(constraints) @@ -291,11 +295,11 @@ class ActionNode: def set_context(self, context): self.set_recursive("context", context) - async def simple_fill(self, schema, mode, timeout=CONFIG.timeout): - prompt = self.compile(context=self.context, schema=schema, mode=mode) + async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None): + prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude) if schema != "raw": - mapping = self.get_mapping(mode) + mapping = self.get_mapping(mode, exclude=exclude) class_name = f"{self.key}_AN" content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout) self.content = content @@ -306,7 +310,7 @@ class ActionNode: return self - async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout): + async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]): """Fill the node(s) with mode. :param context: Everything we should know when filling node. @@ -323,6 +327,7 @@ class ActionNode: - simple: run only once - complex: run each node :param timeout: Timeout for llm invocation. + :param exclude: The keys of ActionNode to exclude. :return: self """ self.set_llm(llm) @@ -331,12 +336,14 @@ class ActionNode: schema = self.schema if strgy == "simple": - return await self.simple_fill(schema=schema, mode=mode, timeout=timeout) + return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) elif strgy == "complex": # 这里隐式假设了拥有children tmp = {} for _, i in self.children.items(): - child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout) + if exclude and i.key in exclude: + continue + child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) tmp.update(child.instruct_content.dict()) cls = self.create_children_class() self.instruct_content = cls(**tmp) diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 289354a11..de647f167 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -23,10 +23,10 @@ from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.fix_bug import FixBug from metagpt.actions.write_prd_an import ( + PROJECT_NAME, WP_IS_RELATIVE_NODE, WP_ISSUE_TYPE_NODE, WRITE_PRD_NODE, - WRITE_PRD_NODE_NO_NAME, ) from metagpt.config import CONFIG from metagpt.const import ( @@ -124,8 +124,8 @@ class WritePRD(Action): # logger.info(rsp) project_name = CONFIG.project_name if CONFIG.project_name else "" context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name) - write_prd_node = WRITE_PRD_NODE if not project_name else WRITE_PRD_NODE_NO_NAME - node = await write_prd_node.fill(context=context, llm=self.llm) # schema=schema + exclude = [PROJECT_NAME.key] if project_name else [] + node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema await self._rename_workspace(node) return node diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py index e33da2451..948d7d62f 100644 --- a/metagpt/actions/write_prd_an.py +++ b/metagpt/actions/write_prd_an.py @@ -141,6 +141,7 @@ NODES = [ LANGUAGE, PROGRAMMING_LANGUAGE, ORIGINAL_REQUIREMENTS, + PROJECT_NAME, PRODUCT_GOALS, USER_STORIES, COMPETITIVE_ANALYSIS, @@ -151,8 +152,7 @@ NODES = [ ANYTHING_UNCLEAR, ] -WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES + [PROJECT_NAME]) -WRITE_PRD_NODE_NO_NAME = ActionNode.from_children("WritePRD", NODES) +WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES) WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON]) WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON]) diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 41b2acbd5..f2f2bf51c 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,9 +4,8 @@ import json from pathlib import Path -import aiofiles - from metagpt.provider.openai_api import OpenAILLM as GPTAPI +from metagpt.utils.common import awrite ICL_SAMPLE = """Interface definition: ```text @@ -255,20 +254,14 @@ class UTGenerator: return doc - async def _store(self, data, base, folder, fname): - """Store data in a file.""" - file_path = self.get_file_path(Path(base) / folder, fname) - async with aiofiles.open(file_path, mode="w", encoding="utf-8") as file: - await file.write(data) - async def ask_gpt_and_save(self, question: str, tag: str, fname: str): """Generate questions and store both questions and answers""" messages = [self.icl_sample, question] result = await self.gpt_msgs_to_code(messages=messages) - await self._store(question, self.questions_path, tag, f"{fname}.txt") + await awrite(Path(self.questions_path) / tag / f"{fname}.txt", question) data = result.get("code", "") if result else "" - await self._store(data, self.ut_py_path, tag, f"{fname}.py") + await awrite(Path(self.ut_py_path) / tag / f"{fname}.py", data) async def _generate_ut(self, tag, paths): """Process the structure under a data path @@ -291,15 +284,3 @@ class UTGenerator: result = await GPTAPI().aask_code(messages=messages) return result - - def get_file_path(self, base: Path, fname: str): - """Save different file paths - - Args: - base (str): Path - fname (str): File name - """ - path = Path(base) - path.mkdir(parents=True, exist_ok=True) - file_path = path / fname - return str(file_path) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index ced17bb7f..f03de1da1 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -537,6 +537,14 @@ async def aread(file_path: str) -> str: return content +async def awrite(filename: str | Path, data: str): + """Write file asynchronously.""" + pathname = Path(filename) + pathname.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer: + await writer.write(data) + + async def read_file_block(filename: str | Path, lineno: int, end_lineno: int): if not Path(filename).exists(): return "" diff --git a/tests/metagpt/learn/test_text_to_embedding.py b/tests/metagpt/learn/test_text_to_embedding.py index e3d20a759..f9ad20ee7 100644 --- a/tests/metagpt/learn/test_text_to_embedding.py +++ b/tests/metagpt/learn/test_text_to_embedding.py @@ -12,7 +12,6 @@ import asyncio from pydantic import BaseModel from metagpt.learn.text_to_embedding import text_to_embedding -from metagpt.tools.openai_text_to_embedding import ResultEmbedding async def mock_text_to_embedding(): @@ -23,8 +22,7 @@ async def mock_text_to_embedding(): for i in inputs: seed = Input(**i) - data = await text_to_embedding(seed.input) - v = ResultEmbedding(**data) + v = await text_to_embedding(seed.input) assert len(v.data) > 0 diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index 5e49023a0..53708527f 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -9,6 +9,7 @@ import importlib import os import platform +import uuid from pathlib import Path from typing import Any, Set @@ -25,6 +26,8 @@ from metagpt.utils.common import ( OutputParser, any_to_str, any_to_str_set, + aread, + awrite, check_cmd_exists, concat_namespace, import_class_inst, @@ -170,6 +173,14 @@ class TestGetProjectRoot: async def test_read_file_block(self): assert await read_file_block(filename=__file__, lineno=6, end_lineno=6) == "@File : test_common.py\n" + @pytest.mark.asyncio + async def test_read_write(self): + pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp" + await awrite(pathname, "ABC") + data = await aread(pathname) + assert data == "ABC" + pathname.unlink(missing_ok=True) + if __name__ == "__main__": pytest.main([__file__, "-s"])