mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add ActionNode review/revise
This commit is contained in:
parent
662102d188
commit
68e53d2862
4 changed files with 520 additions and 7 deletions
|
|
@ -9,7 +9,8 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
|
|||
we can use typing to extract the type of the node, but we cannot use built-in list to extract.
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import BaseModel, create_model, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
|
@ -18,6 +19,18 @@ from metagpt.llm import BaseLLM
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
from metagpt.utils.common import OutputParser, general_after_log
|
||||
from metagpt.utils.human_interaction import HumanInteraction
|
||||
|
||||
|
||||
class ReviewMode(Enum):
|
||||
HUMAN = "human"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class ReviseMode(Enum):
|
||||
HUMAN = "human"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
TAG = "CONTENT"
|
||||
|
||||
|
|
@ -44,6 +57,58 @@ SIMPLE_TEMPLATE = """
|
|||
Follow instructions of nodes, generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
REVIEW_TEMPLATE = """
|
||||
## context
|
||||
Compare the keys of nodes_output and the corresponding requirements one by one. If a key that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
|
||||
|
||||
### nodes_output
|
||||
{nodes_output}
|
||||
|
||||
-----
|
||||
|
||||
## format example
|
||||
[{tag}]
|
||||
{{
|
||||
"key1": "comment1",
|
||||
"key2": "comment2",
|
||||
"keyn": "commentn"
|
||||
}}
|
||||
[/{tag}]
|
||||
|
||||
## nodes: "<node>: <type> # <instruction>"
|
||||
- key1: <class \'str\'> # the first key name of mismatch key
|
||||
- key2: <class \'str\'> # the second key name of mismatch key
|
||||
- keyn: <class \'str\'> # the last key name of mismatch key
|
||||
|
||||
## constraint
|
||||
{constraint}
|
||||
|
||||
## action
|
||||
generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
REVISE_TEMPLATE = """
|
||||
## context
|
||||
change the nodes_output key's value to meet its comment and no need to add extra comment.
|
||||
|
||||
### nodes_output
|
||||
{nodes_output}
|
||||
|
||||
-----
|
||||
|
||||
## format example
|
||||
{example}
|
||||
|
||||
## nodes: "<node>: <type> # <instruction>"
|
||||
{instruction}
|
||||
|
||||
## constraint
|
||||
{constraint}
|
||||
|
||||
## action
|
||||
generate output and make sure it follows the format example.
|
||||
"""
|
||||
|
||||
|
||||
def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"):
|
||||
markdown_str = ""
|
||||
|
|
@ -104,6 +169,9 @@ class ActionNode:
|
|||
"""增加子ActionNode"""
|
||||
self.children[node.key] = node
|
||||
|
||||
def get_child(self, key: str) -> Union["ActionNode", None]:
|
||||
return self.children.get(key, None)
|
||||
|
||||
def add_children(self, nodes: List["ActionNode"]):
|
||||
"""批量增加子ActionNode"""
|
||||
for node in nodes:
|
||||
|
|
@ -151,6 +219,11 @@ class ActionNode:
|
|||
new_class = create_model(class_name, __validators__=validators, **mapping)
|
||||
return new_class
|
||||
|
||||
def create_class(self, mode: str = "auto", class_name: str = None, exclude=None):
|
||||
class_name = class_name if class_name else f"{self.key}_AN"
|
||||
mapping = self.get_mapping(mode=mode, exclude=exclude)
|
||||
return self.create_model_class(class_name, mapping)
|
||||
|
||||
def create_children_class(self, exclude=None):
|
||||
"""使用object内有的字段直接生成model_class"""
|
||||
class_name = f"{self.key}_AN"
|
||||
|
|
@ -185,6 +258,25 @@ class ActionNode:
|
|||
|
||||
return node_dict
|
||||
|
||||
def update_instruct_content(self, incre_data: dict[str, Any]):
|
||||
assert self.instruct_content
|
||||
origin_sc_dict = self.instruct_content.model_dump()
|
||||
origin_sc_dict.update(incre_data)
|
||||
output_class = self.create_class()
|
||||
self.instruct_content = output_class(**origin_sc_dict)
|
||||
|
||||
def keys(self, mode: str = "auto") -> list:
|
||||
if mode == "children" or (mode == "auto" and self.children):
|
||||
keys = []
|
||||
else:
|
||||
keys = [self.key]
|
||||
if mode == "root":
|
||||
return keys
|
||||
|
||||
for _, child_node in self.children.items():
|
||||
keys.append(child_node.key)
|
||||
return keys
|
||||
|
||||
def compile_to(self, i: Dict, schema, kv_sep) -> str:
|
||||
if schema == "json":
|
||||
return json.dumps(i, indent=4)
|
||||
|
|
@ -342,7 +434,170 @@ class ActionNode:
|
|||
if exclude and i.key in exclude:
|
||||
continue
|
||||
child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude)
|
||||
tmp.update(child.instruct_content.dict())
|
||||
tmp.update(child.instruct_content.model_dump())
|
||||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
return self
|
||||
|
||||
async def human_review(self) -> dict[str, str]:
|
||||
review_comments = HumanInteraction().interact_with_instruct_content(
|
||||
instruct_content=self.instruct_content, interact_type="review"
|
||||
)
|
||||
|
||||
return review_comments
|
||||
|
||||
def _makeup_nodes_output_with_req(self) -> dict[str, str]:
|
||||
instruct_content_dict = self.instruct_content.model_dump()
|
||||
nodes_output = {}
|
||||
for key, value in instruct_content_dict.items():
|
||||
child = self.get_child(key)
|
||||
nodes_output[key] = {"value": value, "requirement": child.instruction if child else self.instruction}
|
||||
return nodes_output
|
||||
|
||||
async def auto_review(self, template: str = REVIEW_TEMPLATE) -> dict[str, str]:
|
||||
"""use key's output value and its instruction to review the modification comment"""
|
||||
nodes_output = self._makeup_nodes_output_with_req()
|
||||
"""nodes_output format:
|
||||
{
|
||||
"key": {"value": "output value", "requirement": "key instruction"}
|
||||
}
|
||||
"""
|
||||
if not nodes_output:
|
||||
return dict()
|
||||
|
||||
prompt = template.format(
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4), tag=TAG, constraint=FORMAT_CONSTRAINT
|
||||
)
|
||||
|
||||
content = await self.llm.aask(prompt)
|
||||
# Extract the dict of mismatch key and its comment. Due to the mismatch keys are unknown, here use the keys
|
||||
# of ActionNode to judge if exist in `content` and then follow the `data_mapping` method to create model class.
|
||||
keys = self.keys()
|
||||
include_keys = []
|
||||
for key in keys:
|
||||
if f'"{key}":' in content:
|
||||
include_keys.append(key)
|
||||
if not include_keys:
|
||||
return dict()
|
||||
|
||||
exclude_keys = list(set(keys).difference(include_keys))
|
||||
output_class_name = f"{self.key}_AN_REVIEW"
|
||||
output_class = self.create_class(class_name=output_class_name, exclude=exclude_keys)
|
||||
parsed_data = llm_output_postprocess(
|
||||
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
|
||||
)
|
||||
instruct_content = output_class(**parsed_data)
|
||||
return instruct_content.model_dump()
|
||||
|
||||
async def simple_review(self, review_mode: ReviewMode = ReviewMode.AUTO):
|
||||
# generate review comments
|
||||
if review_mode == ReviewMode.HUMAN:
|
||||
review_comments = await self.human_review()
|
||||
else:
|
||||
review_comments = await self.auto_review()
|
||||
|
||||
if not review_comments:
|
||||
logger.warning("There are no review comments")
|
||||
return review_comments
|
||||
|
||||
async def review(self, strgy: str = "simple", review_mode: ReviewMode = ReviewMode.AUTO):
|
||||
"""only give the review comment of each exist and mismatch key
|
||||
|
||||
:param strgy: simple/complex
|
||||
- simple: run only once
|
||||
- complex: run each node
|
||||
"""
|
||||
if not hasattr(self, "llm"):
|
||||
raise RuntimeError("use `review` after `fill`")
|
||||
assert review_mode in ReviewMode
|
||||
assert self.instruct_content, 'review only support with `schema != "raw"`'
|
||||
|
||||
if strgy == "simple":
|
||||
review_comments = await self.simple_review(review_mode)
|
||||
elif strgy == "complex":
|
||||
# review each child node one-by-one
|
||||
review_comments = {}
|
||||
for _, child in self.children.items():
|
||||
child_review_comment = await child.simple_review(review_mode)
|
||||
review_comments.update(child_review_comment)
|
||||
|
||||
return review_comments
|
||||
|
||||
async def human_revise(self) -> dict[str, str]:
|
||||
review_contents = HumanInteraction().interact_with_instruct_content(
|
||||
instruct_content=self.instruct_content, mapping=self.get_mapping(mode="auto"), interact_type="revise"
|
||||
)
|
||||
# re-fill the ActionNode
|
||||
self.update_instruct_content(review_contents)
|
||||
return review_contents
|
||||
|
||||
def _makeup_nodes_output_with_comment(self, review_comments: dict[str, str]) -> dict[str, str]:
|
||||
instruct_content_dict = self.instruct_content.model_dump()
|
||||
nodes_output = {}
|
||||
for key, value in instruct_content_dict.items():
|
||||
if key in review_comments:
|
||||
nodes_output[key] = {"value": value, "comment": review_comments[key]}
|
||||
return nodes_output
|
||||
|
||||
async def auto_revise(self, template: str = REVISE_TEMPLATE) -> dict[str, str]:
|
||||
"""revise the value of incorrect keys"""
|
||||
# generate review comments
|
||||
review_comments: dict = await self.auto_review()
|
||||
include_keys = list(review_comments.keys())
|
||||
|
||||
# generate revise content
|
||||
nodes_output = self._makeup_nodes_output_with_comment(review_comments)
|
||||
keys = self.keys()
|
||||
exclude_keys = list(set(keys).difference(include_keys))
|
||||
example = self.compile_example(schema="json", mode="auto", tag=TAG, exclude=exclude_keys)
|
||||
instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys)
|
||||
|
||||
prompt = template.format(
|
||||
nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4),
|
||||
example=example,
|
||||
instruction=instruction,
|
||||
constraint=FORMAT_CONSTRAINT,
|
||||
)
|
||||
|
||||
output_mapping = self.get_mapping(mode="auto", exclude=exclude_keys)
|
||||
output_class_name = f"{self.key}_AN_REVISE"
|
||||
content, scontent = await self._aask_v1(
|
||||
prompt=prompt, output_class_name=output_class_name, output_data_mapping=output_mapping, schema="json"
|
||||
)
|
||||
|
||||
# re-fill the ActionNode
|
||||
sc_dict = scontent.model_dump()
|
||||
self.update_instruct_content(sc_dict)
|
||||
return sc_dict
|
||||
|
||||
async def simple_revise(self, revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
|
||||
if revise_mode == ReviseMode.HUMAN:
|
||||
revise_contents = await self.human_revise()
|
||||
else:
|
||||
revise_contents = await self.auto_revise()
|
||||
|
||||
return revise_contents
|
||||
|
||||
async def revise(self, strgy: str = "simple", revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
|
||||
"""revise the content of ActionNode and update the instruct_content
|
||||
|
||||
:param strgy: simple/complex
|
||||
- simple: run only once
|
||||
- complex: run each node
|
||||
"""
|
||||
if not hasattr(self, "llm"):
|
||||
raise RuntimeError("use `revise` after `fill`")
|
||||
assert revise_mode in ReviseMode
|
||||
assert self.instruct_content, 'revise only support with `schema != "raw"`'
|
||||
|
||||
if strgy == "simple":
|
||||
revise_contents = await self.simple_revise(revise_mode)
|
||||
elif strgy == "complex":
|
||||
# revise each child node one-by-one
|
||||
revise_contents = {}
|
||||
for _, child in self.children.items():
|
||||
child_revise_content = await child.simple_revise(revise_mode)
|
||||
revise_contents.update(child_revise_content)
|
||||
self.update_instruct_content(revise_contents)
|
||||
|
||||
return revise_contents
|
||||
|
|
|
|||
107
metagpt/utils/human_interaction.py
Normal file
107
metagpt/utils/human_interaction.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : human interaction to get required type text
|
||||
|
||||
import json
|
||||
from typing import Any, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
class HumanInteraction(object):
|
||||
stop_list = ("q", "quit", "exit")
|
||||
|
||||
def multilines_input(self, prompt: str = "Enter: ") -> str:
|
||||
logger.warning("Enter your content, use Ctrl-D or Ctrl-Z ( windows ) to save it.")
|
||||
logger.info(f"{prompt}\n")
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
line = input()
|
||||
lines.append(line)
|
||||
except EOFError:
|
||||
break
|
||||
return "".join(lines)
|
||||
|
||||
def check_input_type(self, input_str: str, req_type: Type) -> Tuple[bool, Any]:
|
||||
check_ret = True
|
||||
if req_type == str:
|
||||
# required_type = str, just return True
|
||||
return check_ret, input_str
|
||||
try:
|
||||
input_str = input_str.strip()
|
||||
data = json.loads(input_str)
|
||||
except Exception:
|
||||
return False, None
|
||||
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
tmp_key = "tmp"
|
||||
tmp_cls = actionnode_class.create_model_class(class_name=tmp_key.upper(), mapping={tmp_key: (req_type, ...)})
|
||||
try:
|
||||
_ = tmp_cls(**{tmp_key: data})
|
||||
except Exception:
|
||||
check_ret = False
|
||||
return check_ret, data
|
||||
|
||||
def input_until_valid(self, prompt: str, req_type: Type) -> Any:
|
||||
# check the input with req_type until it's ok
|
||||
while True:
|
||||
input_content = self.multilines_input(prompt)
|
||||
check_ret, structure_content = self.check_input_type(input_content, req_type)
|
||||
if check_ret:
|
||||
break
|
||||
else:
|
||||
logger.error(f"Input content can't meet required_type: {req_type}, please Re-Enter.")
|
||||
return structure_content
|
||||
|
||||
def input_num_until_valid(self, num_max: int) -> int:
|
||||
while True:
|
||||
input_num = input("Enter the num of the interaction key: ")
|
||||
input_num = input_num.strip()
|
||||
if input_num in self.stop_list:
|
||||
return input_num
|
||||
try:
|
||||
input_num = int(input_num)
|
||||
if 0 <= input_num < num_max:
|
||||
return input_num
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def interact_with_instruct_content(
|
||||
self, instruct_content: BaseModel, mapping: dict = dict(), interact_type: str = "review"
|
||||
) -> dict[str, Any]:
|
||||
assert interact_type in ["review", "revise"]
|
||||
assert instruct_content
|
||||
instruct_content_dict = instruct_content.model_dump()
|
||||
num_fields_map = dict(zip(range(0, len(instruct_content_dict)), instruct_content_dict.keys()))
|
||||
logger.info(
|
||||
f"\n{interact_type.upper()} interaction\n"
|
||||
f"Interaction data: {num_fields_map}\n"
|
||||
f"Enter the num to interact with corresponding field or `q`/`quit`/`exit` to stop interaction.\n"
|
||||
f"Enter the field content until it meet field required type.\n"
|
||||
)
|
||||
|
||||
interact_contents = {}
|
||||
while True:
|
||||
input_num = self.input_num_until_valid(len(instruct_content_dict))
|
||||
if input_num in self.stop_list:
|
||||
logger.warning("Stop human interaction")
|
||||
break
|
||||
|
||||
field = num_fields_map.get(input_num)
|
||||
logger.info(f"You choose to interact with field: {field}, and do a `{interact_type}` operation.")
|
||||
|
||||
if interact_type == "review":
|
||||
prompt = "Enter your review comment: "
|
||||
req_type = str
|
||||
else:
|
||||
prompt = "Enter your revise content: "
|
||||
req_type = mapping.get(field)[0] # revise need input content match the required_type
|
||||
|
||||
field_content = self.input_until_valid(prompt=prompt, req_type=req_type)
|
||||
interact_contents[field] = field_content
|
||||
|
||||
return interact_contents
|
||||
|
|
@ -11,7 +11,7 @@ import pytest
|
|||
from pydantic import ValidationError
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.roles import Role
|
||||
|
|
@ -98,6 +98,83 @@ async def test_action_node_two_layer():
|
|||
assert "579" in answer2.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_review():
|
||||
key = "Project Name"
|
||||
node_a = ActionNode(
|
||||
key=key,
|
||||
expected_type=str,
|
||||
instruction='According to the content of "Original Requirements," name the project using snake case style '
|
||||
"with underline, like 'game_2048' or 'simple_crm.",
|
||||
example="game_2048",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to review
|
||||
|
||||
review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
assert list(review_comments.keys())[0] == key
|
||||
|
||||
review_comments = await node_a.review(strgy="complex", review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 0
|
||||
|
||||
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node.review()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
|
||||
review_comments = await node.review(review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
assert list(review_comments.keys())[0] == key
|
||||
|
||||
review_comments = await node.review(strgy="complex", review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
assert list(review_comments.keys())[0] == key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_revise():
|
||||
key = "Project Name"
|
||||
node_a = ActionNode(
|
||||
key=key,
|
||||
expected_type=str,
|
||||
instruction='According to the content of "Original Requirements," name the project using snake case style '
|
||||
"with underline, like 'game_2048' or 'simple_crm.",
|
||||
example="game_2048",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to revise
|
||||
revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
assert "game_snake" in getattr(node_a.instruct_content, key)
|
||||
|
||||
revise_contents = await node_a.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 0
|
||||
|
||||
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node.revise()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
setattr(node.instruct_content, key, "game snake")
|
||||
revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
assert "game_snake" in getattr(node.instruct_content, key)
|
||||
|
||||
revise_contents = await node.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
assert "game_snake" in getattr(node.instruct_content, key)
|
||||
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
|
|
@ -138,10 +215,10 @@ def test_create_model_class():
|
|||
assert test_class.__name__ == "test_class"
|
||||
|
||||
output = test_class(**t_dict)
|
||||
print(output.schema())
|
||||
assert output.schema()["title"] == "test_class"
|
||||
assert output.schema()["type"] == "object"
|
||||
assert output.schema()["properties"]["Full API spec"]
|
||||
print(output.model_json_schema())
|
||||
assert output.model_json_schema()["title"] == "test_class"
|
||||
assert output.model_json_schema()["type"] == "object"
|
||||
assert output.model_json_schema()["properties"]["Full API spec"]
|
||||
|
||||
|
||||
def test_create_model_class_with_fields_unrecognized():
|
||||
|
|
|
|||
74
tests/metagpt/utils/test_human_interaction.py
Normal file
74
tests/metagpt/utils/test_human_interaction.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of human_interaction
|
||||
|
||||
import pytest
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.utils.human_interaction import HumanInteraction
|
||||
|
||||
|
||||
class InstructContent(BaseModel):
|
||||
test_field1: str = ""
|
||||
test_field2: list[str] = []
|
||||
|
||||
|
||||
data_mapping = {
|
||||
"test_field1": (str, ...),
|
||||
"test_field2": (list[str], ...)
|
||||
}
|
||||
|
||||
human_interaction = HumanInteraction()
|
||||
|
||||
|
||||
def test_input_num(mocker):
|
||||
mocker.patch("builtins.input", lambda _: "quit")
|
||||
|
||||
interact_contents = human_interaction.interact_with_instruct_content(InstructContent(), data_mapping)
|
||||
assert len(interact_contents) == 0
|
||||
|
||||
mocker.patch("builtins.input", lambda _: "1")
|
||||
input_num = human_interaction.input_num_until_valid(2)
|
||||
assert input_num == 1
|
||||
|
||||
|
||||
def test_check_input_type():
|
||||
ret, _ = human_interaction.check_input_type(input_str="test string",
|
||||
req_type=str)
|
||||
assert ret
|
||||
|
||||
ret, _ = human_interaction.check_input_type(input_str='["test string"]',
|
||||
req_type=list[str])
|
||||
assert ret
|
||||
|
||||
ret, _ = human_interaction.check_input_type(input_str='{"key", "value"}',
|
||||
req_type=list[str])
|
||||
assert not ret
|
||||
|
||||
|
||||
global_index = 0
|
||||
|
||||
|
||||
def mock_input(*args, **kwargs):
|
||||
"""there are multi input call, return it by global_index"""
|
||||
arr = ["1", '["test"]', "ignore", "quit"]
|
||||
global global_index
|
||||
global_index += 1
|
||||
if global_index == 3:
|
||||
raise EOFError()
|
||||
val = arr[global_index-1]
|
||||
return val
|
||||
|
||||
|
||||
def test_human_interact_valid_content(mocker):
|
||||
mocker.patch("builtins.input", mock_input)
|
||||
input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "review")
|
||||
assert len(input_contents) == 1
|
||||
assert input_contents["test_field2"] == '["test"]'
|
||||
|
||||
global global_index
|
||||
global_index = 0
|
||||
input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "revise")
|
||||
assert len(input_contents) == 1
|
||||
assert input_contents["test_field2"] == ["test"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue