add ActionNode review/revise

This commit is contained in:
better629 2024-01-08 16:09:14 +08:00 committed by 莘权 马
parent 662102d188
commit 68e53d2862
4 changed files with 520 additions and 7 deletions

View file

@ -9,7 +9,8 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
"""
import json
from typing import Any, Dict, List, Optional, Tuple, Type
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
@ -18,6 +19,18 @@ from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
from metagpt.utils.common import OutputParser, general_after_log
from metagpt.utils.human_interaction import HumanInteraction
class ReviewMode(Enum):
HUMAN = "human"
AUTO = "auto"
class ReviseMode(Enum):
HUMAN = "human"
AUTO = "auto"
TAG = "CONTENT"
@ -44,6 +57,58 @@ SIMPLE_TEMPLATE = """
Follow instructions of nodes, generate output and make sure it follows the format example.
"""
REVIEW_TEMPLATE = """
## context
Compare the keys of nodes_output and the corresponding requirements one by one. If a key that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
### nodes_output
{nodes_output}
-----
## format example
[{tag}]
{{
"key1": "comment1",
"key2": "comment2",
"keyn": "commentn"
}}
[/{tag}]
## nodes: "<node>: <type> # <instruction>"
- key1: <class \'str\'> # the first key name of mismatch key
- key2: <class \'str\'> # the second key name of mismatch key
- keyn: <class \'str\'> # the last key name of mismatch key
## constraint
{constraint}
## action
generate output and make sure it follows the format example.
"""
REVISE_TEMPLATE = """
## context
change the nodes_output key's value to meet its comment and no need to add extra comment.
### nodes_output
{nodes_output}
-----
## format example
{example}
## nodes: "<node>: <type> # <instruction>"
{instruction}
## constraint
{constraint}
## action
generate output and make sure it follows the format example.
"""
def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
markdown_str = ""
@ -104,6 +169,9 @@ class ActionNode:
"""增加子ActionNode"""
self.children[node.key] = node
def get_child(self, key: str) -> Union["ActionNode", None]:
return self.children.get(key, None)
def add_children(self, nodes: List["ActionNode"]):
"""批量增加子ActionNode"""
for node in nodes:
@ -151,6 +219,11 @@ class ActionNode:
new_class = create_model(class_name, __validators__=validators, **mapping)
return new_class
def create_class(self, mode: str = "auto", class_name: str = None, exclude=None):
class_name = class_name if class_name else f"{self.key}_AN"
mapping = self.get_mapping(mode=mode, exclude=exclude)
return self.create_model_class(class_name, mapping)
def create_children_class(self, exclude=None):
"""使用object内有的字段直接生成model_class"""
class_name = f"{self.key}_AN"
@ -185,6 +258,25 @@ class ActionNode:
return node_dict
def update_instruct_content(self, incre_data: dict[str, Any]):
assert self.instruct_content
origin_sc_dict = self.instruct_content.model_dump()
origin_sc_dict.update(incre_data)
output_class = self.create_class()
self.instruct_content = output_class(**origin_sc_dict)
def keys(self, mode: str = "auto") -> list:
if mode == "children" or (mode == "auto" and self.children):
keys = []
else:
keys = [self.key]
if mode == "root":
return keys
for _, child_node in self.children.items():
keys.append(child_node.key)
return keys
def compile_to(self, i: Dict, schema, kv_sep) -> str:
if schema == "json":
return json.dumps(i, indent=4)
@ -342,7 +434,170 @@ class ActionNode:
if exclude and i.key in exclude:
continue
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
tmp.update(child.instruct_content.dict())
tmp.update(child.instruct_content.model_dump())
cls = self.create_children_class()
self.instruct_content = cls(**tmp)
return self
async def human_review(self) -> dict[str, str]:
review_comments = HumanInteraction().interact_with_instruct_content(
instruct_content=self.instruct_content, interact_type="review"
)
return review_comments
def _makeup_nodes_output_with_req(self) -> dict[str, str]:
instruct_content_dict = self.instruct_content.model_dump()
nodes_output = {}
for key, value in instruct_content_dict.items():
child = self.get_child(key)
nodes_output[key] = {"value": value, "requirement": child.instruction if child else self.instruction}
return nodes_output
async def auto_review(self, template: str = REVIEW_TEMPLATE) -> dict[str, str]:
"""use key's output value and its instruction to review the modification comment"""
nodes_output = self._makeup_nodes_output_with_req()
"""nodes_output format:
{
"key": {"value": "output value", "requirement": "key instruction"}
}
"""
if not nodes_output:
return dict()
prompt = template.format(
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4), tag=TAG, constraint=FORMAT_CONSTRAINT
)
content = await self.llm.aask(prompt)
# Extract the dict of mismatch key and its comment. Due to the mismatch keys are unknown, here use the keys
# of ActionNode to judge if exist in `content` and then follow the `data_mapping` method to create model class.
keys = self.keys()
include_keys = []
for key in keys:
if f'"{key}":' in content:
include_keys.append(key)
if not include_keys:
return dict()
exclude_keys = list(set(keys).difference(include_keys))
output_class_name = f"{self.key}_AN_REVIEW"
output_class = self.create_class(class_name=output_class_name, exclude=exclude_keys)
parsed_data = llm_output_postprocess(
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
)
instruct_content = output_class(**parsed_data)
return instruct_content.model_dump()
async def simple_review(self, review_mode: ReviewMode = ReviewMode.AUTO):
# generate review comments
if review_mode == ReviewMode.HUMAN:
review_comments = await self.human_review()
else:
review_comments = await self.auto_review()
if not review_comments:
logger.warning("There are no review comments")
return review_comments
async def review(self, strgy: str = "simple", review_mode: ReviewMode = ReviewMode.AUTO):
"""only give the review comment of each exist and mismatch key
:param strgy: simple/complex
- simple: run only once
- complex: run each node
"""
if not hasattr(self, "llm"):
raise RuntimeError("use `review` after `fill`")
assert review_mode in ReviewMode
assert self.instruct_content, 'review only support with `schema != "raw"`'
if strgy == "simple":
review_comments = await self.simple_review(review_mode)
elif strgy == "complex":
# review each child node one-by-one
review_comments = {}
for _, child in self.children.items():
child_review_comment = await child.simple_review(review_mode)
review_comments.update(child_review_comment)
return review_comments
async def human_revise(self) -> dict[str, str]:
review_contents = HumanInteraction().interact_with_instruct_content(
instruct_content=self.instruct_content, mapping=self.get_mapping(mode="auto"), interact_type="revise"
)
# re-fill the ActionNode
self.update_instruct_content(review_contents)
return review_contents
def _makeup_nodes_output_with_comment(self, review_comments: dict[str, str]) -> dict[str, str]:
instruct_content_dict = self.instruct_content.model_dump()
nodes_output = {}
for key, value in instruct_content_dict.items():
if key in review_comments:
nodes_output[key] = {"value": value, "comment": review_comments[key]}
return nodes_output
async def auto_revise(self, template: str = REVISE_TEMPLATE) -> dict[str, str]:
"""revise the value of incorrect keys"""
# generate review comments
review_comments: dict = await self.auto_review()
include_keys = list(review_comments.keys())
# generate revise content
nodes_output = self._makeup_nodes_output_with_comment(review_comments)
keys = self.keys()
exclude_keys = list(set(keys).difference(include_keys))
example = self.compile_example(schema="json", mode="auto", tag=TAG, exclude=exclude_keys)
instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys)
prompt = template.format(
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4),
example=example,
instruction=instruction,
constraint=FORMAT_CONSTRAINT,
)
output_mapping = self.get_mapping(mode="auto", exclude=exclude_keys)
output_class_name = f"{self.key}_AN_REVISE"
content, scontent = await self._aask_v1(
prompt=prompt, output_class_name=output_class_name, output_data_mapping=output_mapping, schema="json"
)
# re-fill the ActionNode
sc_dict = scontent.model_dump()
self.update_instruct_content(sc_dict)
return sc_dict
async def simple_revise(self, revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
if revise_mode == ReviseMode.HUMAN:
revise_contents = await self.human_revise()
else:
revise_contents = await self.auto_revise()
return revise_contents
async def revise(self, strgy: str = "simple", revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
"""revise the content of ActionNode and update the instruct_content
:param strgy: simple/complex
- simple: run only once
- complex: run each node
"""
if not hasattr(self, "llm"):
raise RuntimeError("use `revise` after `fill`")
assert revise_mode in ReviseMode
assert self.instruct_content, 'revise only support with `schema != "raw"`'
if strgy == "simple":
revise_contents = await self.simple_revise(revise_mode)
elif strgy == "complex":
# revise each child node one-by-one
revise_contents = {}
for _, child in self.children.items():
child_revise_content = await child.simple_revise(revise_mode)
revise_contents.update(child_revise_content)
self.update_instruct_content(revise_contents)
return revise_contents

View file

@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : human interaction to get required type text
import json
from typing import Any, Tuple, Type
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.utils.common import import_class
class HumanInteraction(object):
stop_list = ("q", "quit", "exit")
def multilines_input(self, prompt: str = "Enter: ") -> str:
logger.warning("Enter your content, use Ctrl-D or Ctrl-Z ( windows ) to save it.")
logger.info(f"{prompt}\n")
lines = []
while True:
try:
line = input()
lines.append(line)
except EOFError:
break
return "".join(lines)
def check_input_type(self, input_str: str, req_type: Type) -> Tuple[bool, Any]:
check_ret = True
if req_type == str:
# required_type = str, just return True
return check_ret, input_str
try:
input_str = input_str.strip()
data = json.loads(input_str)
except Exception:
return False, None
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
tmp_key = "tmp"
tmp_cls = actionnode_class.create_model_class(class_name=tmp_key.upper(), mapping={tmp_key: (req_type, ...)})
try:
_ = tmp_cls(**{tmp_key: data})
except Exception:
check_ret = False
return check_ret, data
def input_until_valid(self, prompt: str, req_type: Type) -> Any:
# check the input with req_type until it's ok
while True:
input_content = self.multilines_input(prompt)
check_ret, structure_content = self.check_input_type(input_content, req_type)
if check_ret:
break
else:
logger.error(f"Input content can't meet required_type: {req_type}, please Re-Enter.")
return structure_content
def input_num_until_valid(self, num_max: int) -> int:
while True:
input_num = input("Enter the num of the interaction key: ")
input_num = input_num.strip()
if input_num in self.stop_list:
return input_num
try:
input_num = int(input_num)
if 0 <= input_num < num_max:
return input_num
except Exception:
pass
def interact_with_instruct_content(
self, instruct_content: BaseModel, mapping: dict = dict(), interact_type: str = "review"
) -> dict[str, Any]:
assert interact_type in ["review", "revise"]
assert instruct_content
instruct_content_dict = instruct_content.model_dump()
num_fields_map = dict(zip(range(0, len(instruct_content_dict)), instruct_content_dict.keys()))
logger.info(
f"\n{interact_type.upper()} interaction\n"
f"Interaction data: {num_fields_map}\n"
f"Enter the num to interact with corresponding field or `q`/`quit`/`exit` to stop interaction.\n"
f"Enter the field content until it meet field required type.\n"
)
interact_contents = {}
while True:
input_num = self.input_num_until_valid(len(instruct_content_dict))
if input_num in self.stop_list:
logger.warning("Stop human interaction")
break
field = num_fields_map.get(input_num)
logger.info(f"You choose to interact with field: {field}, and do a `{interact_type}` operation.")
if interact_type == "review":
prompt = "Enter your review comment: "
req_type = str
else:
prompt = "Enter your revise content: "
req_type = mapping.get(field)[0] # revise need input content match the required_type
field_content = self.input_until_valid(prompt=prompt, req_type=req_type)
interact_contents[field] = field_content
return interact_contents

View file

@ -11,7 +11,7 @@ import pytest
from pydantic import ValidationError
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode
from metagpt.environment import Environment
from metagpt.llm import LLM
from metagpt.roles import Role
@ -98,6 +98,83 @@ async def test_action_node_two_layer():
assert "579" in answer2.content
@pytest.mark.asyncio
async def test_action_node_review():
key = "Project Name"
node_a = ActionNode(
key=key,
expected_type=str,
instruction='According to the content of "Original Requirements," name the project using snake case style '
"with underline, like 'game_2048' or 'simple_crm.",
example="game_2048",
)
with pytest.raises(RuntimeError):
_ = await node_a.review()
_ = await node_a.fill(context=None, llm=LLM())
setattr(node_a.instruct_content, key, "game snake") # wrong content to review
review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
assert len(review_comments) == 1
assert list(review_comments.keys())[0] == key
review_comments = await node_a.review(strgy="complex", review_mode=ReviewMode.AUTO)
assert len(review_comments) == 0
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
with pytest.raises(RuntimeError):
_ = await node.review()
_ = await node.fill(context=None, llm=LLM())
review_comments = await node.review(review_mode=ReviewMode.AUTO)
assert len(review_comments) == 1
assert list(review_comments.keys())[0] == key
review_comments = await node.review(strgy="complex", review_mode=ReviewMode.AUTO)
assert len(review_comments) == 1
assert list(review_comments.keys())[0] == key
@pytest.mark.asyncio
async def test_action_node_revise():
key = "Project Name"
node_a = ActionNode(
key=key,
expected_type=str,
instruction='According to the content of "Original Requirements," name the project using snake case style '
"with underline, like 'game_2048' or 'simple_crm.",
example="game_2048",
)
with pytest.raises(RuntimeError):
_ = await node_a.review()
_ = await node_a.fill(context=None, llm=LLM())
setattr(node_a.instruct_content, key, "game snake") # wrong content to revise
revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
assert "game_snake" in getattr(node_a.instruct_content, key)
revise_contents = await node_a.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 0
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
with pytest.raises(RuntimeError):
_ = await node.revise()
_ = await node.fill(context=None, llm=LLM())
setattr(node.instruct_content, key, "game snake")
revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
assert "game_snake" in getattr(node.instruct_content, key)
revise_contents = await node.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
assert "game_snake" in getattr(node.instruct_content, key)
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
@ -138,10 +215,10 @@ def test_create_model_class():
assert test_class.__name__ == "test_class"
output = test_class(**t_dict)
print(output.schema())
assert output.schema()["title"] == "test_class"
assert output.schema()["type"] == "object"
assert output.schema()["properties"]["Full API spec"]
print(output.model_json_schema())
assert output.model_json_schema()["title"] == "test_class"
assert output.model_json_schema()["type"] == "object"
assert output.model_json_schema()["properties"]["Full API spec"]
def test_create_model_class_with_fields_unrecognized():

View file

@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of human_interaction
import pytest
from pydantic import BaseModel
from metagpt.utils.human_interaction import HumanInteraction
class InstructContent(BaseModel):
test_field1: str = ""
test_field2: list[str] = []
data_mapping = {
"test_field1": (str, ...),
"test_field2": (list[str], ...)
}
human_interaction = HumanInteraction()
def test_input_num(mocker):
mocker.patch("builtins.input", lambda _: "quit")
interact_contents = human_interaction.interact_with_instruct_content(InstructContent(), data_mapping)
assert len(interact_contents) == 0
mocker.patch("builtins.input", lambda _: "1")
input_num = human_interaction.input_num_until_valid(2)
assert input_num == 1
def test_check_input_type():
ret, _ = human_interaction.check_input_type(input_str="test string",
req_type=str)
assert ret
ret, _ = human_interaction.check_input_type(input_str='["test string"]',
req_type=list[str])
assert ret
ret, _ = human_interaction.check_input_type(input_str='{"key", "value"}',
req_type=list[str])
assert not ret
global_index = 0
def mock_input(*args, **kwargs):
"""there are multi input call, return it by global_index"""
arr = ["1", '["test"]', "ignore", "quit"]
global global_index
global_index += 1
if global_index == 3:
raise EOFError()
val = arr[global_index-1]
return val
def test_human_interact_valid_content(mocker):
mocker.patch("builtins.input", mock_input)
input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "review")
assert len(input_contents) == 1
assert input_contents["test_field2"] == '["test"]'
global global_index
global_index = 0
input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "revise")
assert len(input_contents) == 1
assert input_contents["test_field2"] == ["test"]