add ActionNode review/revise

This commit is contained in:
better629 2024-01-08 16:09:14 +08:00 committed by 莘权 马
parent 662102d188
commit 68e53d2862
4 changed files with 520 additions and 7 deletions

View file

@ -11,7 +11,7 @@ import pytest
from pydantic import ValidationError
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode
from metagpt.environment import Environment
from metagpt.llm import LLM
from metagpt.roles import Role
@ -98,6 +98,83 @@ async def test_action_node_two_layer():
assert "579" in answer2.content
@pytest.mark.asyncio
async def test_action_node_review():
key = "Project Name"
node_a = ActionNode(
key=key,
expected_type=str,
instruction='According to the content of "Original Requirements," name the project using snake case style '
"with underline, like 'game_2048' or 'simple_crm.",
example="game_2048",
)
with pytest.raises(RuntimeError):
_ = await node_a.review()
_ = await node_a.fill(context=None, llm=LLM())
setattr(node_a.instruct_content, key, "game snake") # wrong content to review
review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
assert len(review_comments) == 1
assert list(review_comments.keys())[0] == key
review_comments = await node_a.review(strgy="complex", review_mode=ReviewMode.AUTO)
assert len(review_comments) == 0
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
with pytest.raises(RuntimeError):
_ = await node.review()
_ = await node.fill(context=None, llm=LLM())
review_comments = await node.review(review_mode=ReviewMode.AUTO)
assert len(review_comments) == 1
assert list(review_comments.keys())[0] == key
review_comments = await node.review(strgy="complex", review_mode=ReviewMode.AUTO)
assert len(review_comments) == 1
assert list(review_comments.keys())[0] == key
@pytest.mark.asyncio
async def test_action_node_revise():
key = "Project Name"
node_a = ActionNode(
key=key,
expected_type=str,
instruction='According to the content of "Original Requirements," name the project using snake case style '
"with underline, like 'game_2048' or 'simple_crm.",
example="game_2048",
)
with pytest.raises(RuntimeError):
_ = await node_a.review()
_ = await node_a.fill(context=None, llm=LLM())
setattr(node_a.instruct_content, key, "game snake") # wrong content to revise
revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
assert "game_snake" in getattr(node_a.instruct_content, key)
revise_contents = await node_a.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 0
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
with pytest.raises(RuntimeError):
_ = await node.revise()
_ = await node.fill(context=None, llm=LLM())
setattr(node.instruct_content, key, "game snake")
revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
assert "game_snake" in getattr(node.instruct_content, key)
revise_contents = await node.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
assert "game_snake" in getattr(node.instruct_content, key)
t_dict = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
@ -138,10 +215,10 @@ def test_create_model_class():
assert test_class.__name__ == "test_class"
output = test_class(**t_dict)
print(output.schema())
assert output.schema()["title"] == "test_class"
assert output.schema()["type"] == "object"
assert output.schema()["properties"]["Full API spec"]
print(output.model_json_schema())
assert output.model_json_schema()["title"] == "test_class"
assert output.model_json_schema()["type"] == "object"
assert output.model_json_schema()["properties"]["Full API spec"]
def test_create_model_class_with_fields_unrecognized():

View file

@ -0,0 +1,74 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of human_interaction
import pytest
from pydantic import BaseModel
from metagpt.utils.human_interaction import HumanInteraction
class InstructContent(BaseModel):
test_field1: str = ""
test_field2: list[str] = []
data_mapping = {
"test_field1": (str, ...),
"test_field2": (list[str], ...)
}
human_interaction = HumanInteraction()
def test_input_num(mocker):
mocker.patch("builtins.input", lambda _: "quit")
interact_contents = human_interaction.interact_with_instruct_content(InstructContent(), data_mapping)
assert len(interact_contents) == 0
mocker.patch("builtins.input", lambda _: "1")
input_num = human_interaction.input_num_until_valid(2)
assert input_num == 1
def test_check_input_type():
ret, _ = human_interaction.check_input_type(input_str="test string",
req_type=str)
assert ret
ret, _ = human_interaction.check_input_type(input_str='["test string"]',
req_type=list[str])
assert ret
ret, _ = human_interaction.check_input_type(input_str='{"key", "value"}',
req_type=list[str])
assert not ret
global_index = 0
def mock_input(*args, **kwargs):
"""there are multi input call, return it by global_index"""
arr = ["1", '["test"]', "ignore", "quit"]
global global_index
global_index += 1
if global_index == 3:
raise EOFError()
val = arr[global_index-1]
return val
def test_human_interact_valid_content(mocker):
mocker.patch("builtins.input", mock_input)
input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "review")
assert len(input_contents) == 1
assert input_contents["test_field2"] == '["test"]'
global global_index
global_index = 0
input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "revise")
assert len(input_contents) == 1
assert input_contents["test_field2"] == ["test"]