refine code

This commit is contained in:
geekan 2023-12-19 23:58:18 +08:00
parent 62f34db137
commit 0f78d4ea51
3 changed files with 30 additions and 30 deletions

View file

@ -112,15 +112,15 @@ class ActionNode(Generic[T]):
obj.add_children(nodes)
return obj
def get_children_mapping(self) -> Dict[str, Type]:
def get_children_mapping(self) -> Dict[str, Tuple[Type, Any]]:
"""获得子ActionNode的字典以key索引"""
return {k: (v.expected_type, ...) for k, v in self.children.items()}
def get_self_mapping(self) -> Dict[str, Type]:
def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
"""get self key: type mapping"""
return {self.key: (self.expected_type, ...)}
def get_mapping(self, mode="children") -> Dict[str, Type]:
def get_mapping(self, mode="children") -> Dict[str, Tuple[Type, Any]]:
"""get key: type mapping under mode"""
if mode == "children" or (mode == "auto" and self.children):
return self.get_children_mapping()
@ -175,46 +175,46 @@ class ActionNode(Generic[T]):
return node_dict
# 遍历子节点并递归调用 to_dict 方法
for child_key, child_node in self.children.items():
for _, child_node in self.children.items():
node_dict.update(child_node.to_dict(format_func))
return node_dict
def compile_to(self, i: Dict, to) -> str:
if to == "json":
def compile_to(self, i: Dict, schema) -> str:
if schema == "json":
return json.dumps(i, indent=4)
elif to == "markdown":
elif schema == "markdown":
return dict_to_markdown(i)
else:
return str(i)
def tagging(self, text, to, tag="") -> str:
def tagging(self, text, schema, tag="") -> str:
if not tag:
return text
if to == "json":
if schema == "json":
return f"[{tag}]\n" + text + f"\n[/{tag}]"
else:
return f"[{tag}]\n" + text + f"\n[/{tag}]"
def _compile_f(self, to, mode, tag, format_func) -> str:
def _compile_f(self, schema, mode, tag, format_func) -> str:
nodes = self.to_dict(format_func=format_func, mode=mode)
text = self.compile_to(nodes, to)
return self.tagging(text, to, tag)
text = self.compile_to(nodes, schema)
return self.tagging(text, schema, tag)
def compile_instruction(self, to="raw", mode="children", tag="") -> str:
def compile_instruction(self, schema="raw", mode="children", tag="") -> str:
"""compile to raw/json/markdown template with all/root/children nodes"""
format_func = lambda i: f"{i.expected_type} # {i.instruction}"
return self._compile_f(to, mode, tag, format_func)
return self._compile_f(schema, mode, tag, format_func)
def compile_example(self, to="raw", mode="children", tag="") -> str:
def compile_example(self, schema="raw", mode="children", tag="") -> str:
"""compile to raw/json/markdown examples with all/root/children nodes"""
# 这里不能使用f-string因为转译为str后再json.dumps会额外加上引号无法作为有效的example
# 错误示例:"File list": "['main.py', 'const.py', 'game.py']", 注意这里值不是list而是str
format_func = lambda i: i.example
return self._compile_f(to, mode, tag, format_func)
return self._compile_f(schema, mode, tag, format_func)
def compile(self, context, to="json", mode="children", template=SIMPLE_TEMPLATE) -> str:
def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE) -> str:
"""
mode: all/root/children
mode="children": 编译所有子节点为一个统一模板包括instruction与example
@ -224,8 +224,8 @@ class ActionNode(Generic[T]):
# FIXME: json instruction会带来格式问题"Project name": "web_2048 # 项目名称使用下划线",
# compile example暂时不支持markdown
self.instruction = self.compile_instruction(to="markdown", mode=mode)
self.example = self.compile_example(to=to, tag="CONTENT", mode=mode)
self.instruction = self.compile_instruction(schema="markdown", mode=mode)
self.example = self.compile_example(schema=schema, tag="CONTENT", mode=mode)
prompt = template.format(
context=context, example=self.example, instruction=self.instruction, constraint=CONSTRAINT
)
@ -272,22 +272,22 @@ class ActionNode(Generic[T]):
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, to, mode):
prompt = self.compile(context=self.context, to=to, mode=mode)
async def simple_fill(self, schema, mode):
prompt = self.compile(context=self.context, schema=schema, mode=mode)
mapping = self.get_mapping(mode)
class_name = f"{self.key}_AN"
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=to)
content, scontent = await self._aask_v1(prompt, class_name, mapping, schema=schema)
self.content = content
self.instruct_content = scontent
return self
async def fill(self, context, llm, to="json", mode="auto", strgy="simple"):
async def fill(self, context, llm, schema="json", mode="auto", strgy="simple"):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
:param llm: Large Language Model with pre-defined system message.
:param to: json/markdown, determine example and output format.
:param schema: json/markdown, determine example and output format.
- json: it's easy to open source LLM with json format
- markdown: when generating code, markdown is always better
:param mode: auto/children/root
@ -303,12 +303,12 @@ class ActionNode(Generic[T]):
self.set_context(context)
if strgy == "simple":
return await self.simple_fill(to, mode)
return await self.simple_fill(schema, mode)
elif strgy == "complex":
# 这里隐式假设了拥有children
tmp = {}
for _, i in self.children.items():
child = await i.simple_fill(to, mode)
child = await i.simple_fill(schema, mode)
tmp.update(child.instruct_content.dict())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)

View file

@ -81,12 +81,12 @@ class WriteDesign(Action):
return ActionOutput(content=changed_files.json(), instruct_content=changed_files)
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
return node
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, to=schema)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
system_design_doc.content = node.instruct_content.json(ensure_ascii=False)
return system_design_doc

View file

@ -121,7 +121,7 @@ class WritePRD(Action):
# logger.info(rsp)
project_name = CONFIG.project_name if CONFIG.project_name else ""
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, to=schema)
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, schema=schema)
await self._rename_workspace(node)
return node
@ -134,7 +134,7 @@ class WritePRD(Action):
if not CONFIG.project_name:
CONFIG.project_name = Path(CONFIG.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, to=schema)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
prd_doc.content = node.instruct_content.json(ensure_ascii=False)
await self._rename_workspace(node)
return prd_doc