feat: Action Node + exclude parameter

refactor: awrite
This commit is contained in:
莘权 马 2023-12-27 22:46:39 +08:00
parent 0adabfe53f
commit 8bf7d3186a
7 changed files with 58 additions and 53 deletions

View file

@ -117,19 +117,20 @@ class ActionNode:
obj.add_children(nodes)
return obj
def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]:
def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""获得子ActionNode的字典以key索引"""
return {k: (v.expected_type, ...) for k, v in self.children.items()}
exclude = exclude or []
return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude}
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
"""get self key: type mapping"""
return {self.key: (self.expected_type, ...)}
def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]:
def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
"""get key: type mapping under mode"""
if mode == "children" or (mode == "auto" and self.children):
return self.get_children_mapping()
return self.get_self_mapping()
return self.get_children_mapping(exclude=exclude)
return {} if exclude and self.key in exclude else self.get_self_mapping()
@classmethod
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
@ -154,13 +155,13 @@ class ActionNode:
new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
return new_class
def create_children_class(self):
def create_children_class(self, exclude=None):
"""使用object内有的字段直接生成model_class"""
class_name = f"{self.key}_AN"
mapping = self.get_children_mapping()
mapping = self.get_children_mapping(exclude=exclude)
return self.create_model_class(class_name, mapping)
def to_dict(self, format_func=None, mode="auto") -> Dict:
def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
"""将当前节点与子节点都按照node: format的格式组织成字典"""
# 如果没有提供格式化函数,使用默认的格式化方式
@ -180,7 +181,10 @@ class ActionNode:
return node_dict
# 遍历子节点并递归调用 to_dict 方法
exclude = exclude or []
for _, child_node in self.children.items():
if child_node.key in exclude:
continue
node_dict.update(child_node.to_dict(format_func))
return node_dict
@ -201,25 +205,25 @@ class ActionNode:
else: # markdown
return f"[{tag}]\n" + text + f"\n[/{tag}]"
def _compile_f(self, schema, mode, tag, format_func, kv_sep) -> str:
nodes = self.to_dict(format_func=format_func, mode=mode)
def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
text = self.compile_to(nodes, schema, kv_sep)
return self.tagging(text, schema, tag)
def compile_instruction(self, schema="markdown", mode="children", tag="") -> str:
def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
"""compile to raw/json/markdown template with all/root/children nodes"""
format_func = lambda i: f"{i.expected_type} # {i.instruction}"
return self._compile_f(schema, mode, tag, format_func, kv_sep=": ")
return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)
def compile_example(self, schema="json", mode="children", tag="") -> str:
def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
"""compile to raw/json/markdown examples with all/root/children nodes"""
# 这里不能使用f-string因为转译为str后再json.dumps会额外加上引号无法作为有效的example
# 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list而是str
format_func = lambda i: i.example
return self._compile_f(schema, mode, tag, format_func, kv_sep="\n")
return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str:
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str:
"""
mode: all/root/children
mode="children": 编译所有子节点为一个统一模板包括instruction与example
@ -235,8 +239,8 @@ class ActionNode:
# FIXME: json instruction会带来格式问题"Project name": "web_2048 # 项目名称使用下划线",
# compile example暂时不支持markdown
instruction = self.compile_instruction(schema="markdown", mode=mode)
example = self.compile_example(schema=schema, tag=TAG, mode=mode)
instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
# nodes = ", ".join(self.to_dict(mode=mode).keys())
constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
constraint = "\n".join(constraints)
@ -291,11 +295,11 @@ class ActionNode:
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout):
prompt = self.compile(context=self.context, schema=schema, mode=mode)
async def simple_fill(self, schema, mode, timeout=CONFIG.timeout, exclude=None):
prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
if schema != "raw":
mapping = self.get_mapping(mode)
mapping = self.get_mapping(mode, exclude=exclude)
class_name = f"{self.key}_AN"
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema, timeout=timeout)
self.content = content
@ -306,7 +310,7 @@ class ActionNode:
return self
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout):
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple", timeout=CONFIG.timeout, exclude=[]):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
@ -323,6 +327,7 @@ class ActionNode:
- simple: run only once
- complex: run each node
:param timeout: Timeout for llm invocation.
:param exclude: The keys of ActionNode to exclude.
:return: self
"""
self.set_llm(llm)
@ -331,12 +336,14 @@ class ActionNode:
schema = self.schema
if strgy == "simple":
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout)
return await self.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
elif strgy == "complex":
# 这里隐式假设了拥有children
tmp = {}
for _, i in self.children.items():
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout)
if exclude and i.key in exclude:
continue
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
tmp.update(child.instruct_content.dict())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)

View file

@ -23,10 +23,10 @@ from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.fix_bug import FixBug
from metagpt.actions.write_prd_an import (
PROJECT_NAME,
WP_IS_RELATIVE_NODE,
WP_ISSUE_TYPE_NODE,
WRITE_PRD_NODE,
WRITE_PRD_NODE_NO_NAME,
)
from metagpt.config import CONFIG
from metagpt.const import (
@ -124,8 +124,8 @@ class WritePRD(Action):
# logger.info(rsp)
project_name = CONFIG.project_name if CONFIG.project_name else ""
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
write_prd_node = WRITE_PRD_NODE if not project_name else WRITE_PRD_NODE_NO_NAME
node = await write_prd_node.fill(context=context, llm=self.llm) # schema=schema
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
await self._rename_workspace(node)
return node

View file

@ -141,6 +141,7 @@ NODES = [
LANGUAGE,
PROGRAMMING_LANGUAGE,
ORIGINAL_REQUIREMENTS,
PROJECT_NAME,
PRODUCT_GOALS,
USER_STORIES,
COMPETITIVE_ANALYSIS,
@ -151,8 +152,7 @@ NODES = [
ANYTHING_UNCLEAR,
]
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES + [PROJECT_NAME])
WRITE_PRD_NODE_NO_NAME = ActionNode.from_children("WritePRD", NODES)
WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES)
WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON])
WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON])

View file

@ -4,9 +4,8 @@
import json
from pathlib import Path
import aiofiles
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
from metagpt.utils.common import awrite
ICL_SAMPLE = """Interface definition:
```text
@ -255,20 +254,14 @@ class UTGenerator:
return doc
async def _store(self, data, base, folder, fname):
"""Store data in a file."""
file_path = self.get_file_path(Path(base) / folder, fname)
async with aiofiles.open(file_path, mode="w", encoding="utf-8") as file:
await file.write(data)
async def ask_gpt_and_save(self, question: str, tag: str, fname: str):
"""Generate questions and store both questions and answers"""
messages = [self.icl_sample, question]
result = await self.gpt_msgs_to_code(messages=messages)
await self._store(question, self.questions_path, tag, f"{fname}.txt")
await awrite(Path(self.questions_path) / tag / f"{fname}.txt", question)
data = result.get("code", "") if result else ""
await self._store(data, self.ut_py_path, tag, f"{fname}.py")
await awrite(Path(self.ut_py_path) / tag / f"{fname}.py", data)
async def _generate_ut(self, tag, paths):
"""Process the structure under a data path
@ -291,15 +284,3 @@ class UTGenerator:
result = await GPTAPI().aask_code(messages=messages)
return result
def get_file_path(self, base: Path, fname: str):
"""Save different file paths
Args:
base (str): Path
fname (str): File name
"""
path = Path(base)
path.mkdir(parents=True, exist_ok=True)
file_path = path / fname
return str(file_path)

View file

@ -537,6 +537,14 @@ async def aread(file_path: str) -> str:
return content
async def awrite(filename: str | Path, data: str):
"""Write file asynchronously."""
pathname = Path(filename)
pathname.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(str(pathname), mode="w", encoding="utf-8") as writer:
await writer.write(data)
async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
if not Path(filename).exists():
return ""

View file

@ -12,7 +12,6 @@ import asyncio
from pydantic import BaseModel
from metagpt.learn.text_to_embedding import text_to_embedding
from metagpt.tools.openai_text_to_embedding import ResultEmbedding
async def mock_text_to_embedding():
@ -23,8 +22,7 @@ async def mock_text_to_embedding():
for i in inputs:
seed = Input(**i)
data = await text_to_embedding(seed.input)
v = ResultEmbedding(**data)
v = await text_to_embedding(seed.input)
assert len(v.data) > 0

View file

@ -9,6 +9,7 @@
import importlib
import os
import platform
import uuid
from pathlib import Path
from typing import Any, Set
@ -25,6 +26,8 @@ from metagpt.utils.common import (
OutputParser,
any_to_str,
any_to_str_set,
aread,
awrite,
check_cmd_exists,
concat_namespace,
import_class_inst,
@ -170,6 +173,14 @@ class TestGetProjectRoot:
async def test_read_file_block(self):
assert await read_file_block(filename=__file__, lineno=6, end_lineno=6) == "@File : test_common.py\n"
@pytest.mark.asyncio
async def test_read_write(self):
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
await awrite(pathname, "ABC")
data = await aread(pathname)
assert data == "ABC"
pathname.unlink(missing_ok=True)
if __name__ == "__main__":
pytest.main([__file__, "-s"])