feat: merge geekan:env_refactor

This commit is contained in:
莘权 马 2023-12-14 21:14:19 +08:00
commit c24b3fff1b
2 changed files with 79 additions and 19 deletions

View file

@ -27,6 +27,8 @@ SIMPLE_TEMPLATE = """
## context
{context}
-----
## format example
{example}
@ -37,7 +39,7 @@ SIMPLE_TEMPLATE = """
{constraint}
## action
Fill in the above nodes based on the context. Answer in format example.
Fill in the above nodes based on the format example.
"""
@ -108,6 +110,16 @@ class ActionNode:
"""获得子ActionNode的字典以key索引"""
return {k: (v.expected_type, ...) for k, v in self.children.items()}
def get_self_mapping(self) -> Dict[str, Type]:
"""get self key: type mapping"""
return {self.key: (self.expected_type, ...)}
def get_mapping(self, mode="children") -> Dict[str, Type]:
"""get key: type mapping under mode"""
if mode == "children" or (mode == "auto" and self.children):
return self.get_children_mapping()
return self.get_self_mapping()
@classmethod
def create_model_class(cls, class_name: str, mapping: Dict[str, Type]):
"""基于pydantic v1的模型动态生成用来检验结果类型正确性"""
@ -160,8 +172,8 @@ class ActionNode:
mapping = self.get_children_mapping()
return self.create_model_class(class_name, mapping)
def to_dict(self, format_func=None, mode="all") -> Dict:
"""将当前节点与子节点都按照node: format的格式组织字典"""
def to_dict(self, format_func=None, mode="auto") -> Dict:
"""将当前节点与子节点都按照node: format的格式组织字典"""
# 如果没有提供格式化函数,使用默认的格式化方式
if format_func is None:
@ -171,7 +183,7 @@ class ActionNode:
formatted_value = format_func(self)
# 创建当前节点的键值对
if mode == "children":
if mode == "children" or (mode == "auto" and self.children):
node_dict = {}
else:
node_dict = {self.key: formatted_value}
@ -227,7 +239,7 @@ class ActionNode:
mode="root": NotImplemented
"""
# FIXME: json instruction会带来 "Project name": "web_2048 # 项目名称使用下划线",
# FIXME: json instruction会带来格式问题,如:"Project name": "web_2048 # 项目名称使用下划线",
self.instruction = self.compile_instruction(to="markdown", mode=mode)
self.example = self.compile_example(to=to, tag="CONTENT", mode=mode)
prompt = template.format(
@ -269,19 +281,59 @@ class ActionNode:
def get(self, key):
return self.instruct_content.dict()[key]
async def fill(self, context, llm, to="json"):
"""运行这个ActionNode并且填槽可以采用不同策略比如只运行子节点"""
self.llm = llm
prompt = self.compile(context=context, to=to)
mapping = self.get_children_mapping()
def set_recursive(self, name, value):
setattr(self, name, value)
for _, i in self.children.items():
i.set_recursive(name, value)
def set_llm(self, llm):
self.set_recursive("llm", llm)
def set_context(self, context):
self.set_recursive("context", context)
async def simple_fill(self, to, mode):
prompt = self.compile(context=self.context, to=to, mode=mode)
mapping = self.get_mapping(mode)
class_name = f"{self.key}_AN"
# 需要传入llm并且实际在ActionNode中执行。需要规划好具体的执行方法
output = await self._aask_v1(prompt, class_name, mapping, format=to)
self.content = output.content
self.instruct_content = output.instruct_content
return self
async def fill(self, context, llm, to="json", mode="auto", strgy="simple"):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
:param llm: Large Language Model with pre-defined system message.
:param to: json/markdown, determine example and output format.
- json: it's easy to open source LLM with json format
- markdown: when generating code, markdown is always better
:param mode: auto/children/root
- auto: automated fill children's nodes and gather outputs, if no children, fill itself
- children: fill children's nodes and gather outputs
- root: fill root's node and gather output
:param strgy: simple/complex
- simple: run only once
- complex: run each node
:return: self
"""
self.set_llm(llm)
self.set_context(context)
if strgy == "simple":
return await self.simple_fill(to, mode)
elif strgy == "complex":
# 这里隐式假设了拥有children
tmp = {}
for _, i in self.children.items():
child = await i.simple_fill(to, mode)
tmp.update(child.instruct_content.dict())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)
return self
def action_node_from_tuple_example():
# 示例:列表中包含元组

View file

@ -16,6 +16,13 @@ LANGUAGE = ActionNode(
example="en_us",
)
PROGRAMMING_LANGUAGE = ActionNode(
key="Programming Language",
expected_type=str,
instruction="Python/JavaScript or other mainstream programming language.",
example="Python",
)
ORIGINAL_REQUIREMENTS = ActionNode(
key="Original Requirements",
expected_type=str,
@ -59,14 +66,14 @@ COMPETITIVE_QUADRANT_CHART = ActionNode(
expected_type=str,
instruction="Use mermaid quadrantChart syntax. Distribute scores evenly between 0 and 1",
example="""quadrantChart
title Reach and engagement of campaigns
x-axis Low Reach --> High Reach
y-axis Low Engagement --> High Engagement
quadrant-1 We should expand
quadrant-2 Need to promote
quadrant-3 Re-evaluate
quadrant-4 May be improved
"Campaign: A": [0.3, 0.6]
title "Reach and engagement of campaigns"
x-axis "Low Reach" --> "High Reach"
y-axis "Low Engagement" --> "High Engagement"
quadrant-1 "We should expand"
quadrant-2 "Need to promote"
quadrant-3 "Re-evaluate"
quadrant-4 "May be improved"
"Campaign A": [0.3, 0.6]
"Campaign B": [0.45, 0.23]
"Campaign C": [0.57, 0.69]
"Campaign D": [0.78, 0.34]
@ -124,6 +131,7 @@ REASON = ActionNode(
NODES = [
LANGUAGE,
PROGRAMMING_LANGUAGE,
ORIGINAL_REQUIREMENTS,
PROJECT_NAME,
PRODUCT_GOALS,