diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 633fc9841..8577338b6 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -9,7 +9,8 @@ NOTE: You should use typing.List instead of list to do type annotation. Because we can use typing to extract the type of the node, but we cannot use built-in list to extract. """ import json -from typing import Any, Dict, List, Optional, Tuple, Type +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Type, Union from pydantic import BaseModel, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -18,6 +19,18 @@ from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess from metagpt.utils.common import OutputParser, general_after_log +from metagpt.utils.human_interaction import HumanInteraction + + +class ReviewMode(Enum): + HUMAN = "human" + AUTO = "auto" + + +class ReviseMode(Enum): + HUMAN = "human" + AUTO = "auto" + TAG = "CONTENT" @@ -44,6 +57,58 @@ SIMPLE_TEMPLATE = """ Follow instructions of nodes, generate output and make sure it follows the format example. """ +REVIEW_TEMPLATE = """ +## context +Compare the keys of nodes_output and the corresponding requirements one by one. If a key that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys. + +### nodes_output +{nodes_output} + +----- + +## format example +[{tag}] +{{ + "key1": "comment1", + "key2": "comment2", + "keyn": "commentn" +}} +[/{tag}] + +## nodes: ": # " +- key1: # the first key name of mismatch key +- key2: # the second key name of mismatch key +- keyn: # the last key name of mismatch key + +## constraint +{constraint} + +## action +generate output and make sure it follows the format example. +""" + +REVISE_TEMPLATE = """ +## context +change the nodes_output key's value to meet its comment and no need to add extra comment. + +### nodes_output +{nodes_output} + +----- + +## format example +{example} + +## nodes: ": # " +{instruction} + +## constraint +{constraint} + +## action +generate output and make sure it follows the format example. +""" + def dict_to_markdown(d, prefix="- ", kv_sep="\n", postfix="\n"): markdown_str = "" @@ -104,6 +169,9 @@ class ActionNode: """增加子ActionNode""" self.children[node.key] = node + def get_child(self, key: str) -> Union["ActionNode", None]: + return self.children.get(key, None) + def add_children(self, nodes: List["ActionNode"]): """批量增加子ActionNode""" for node in nodes: @@ -151,6 +219,11 @@ class ActionNode: new_class = create_model(class_name, __validators__=validators, **mapping) return new_class + def create_class(self, mode: str = "auto", class_name: str = None, exclude=None): + class_name = class_name if class_name else f"{self.key}_AN" + mapping = self.get_mapping(mode=mode, exclude=exclude) + return self.create_model_class(class_name, mapping) + def create_children_class(self, exclude=None): """使用object内有的字段直接生成model_class""" class_name = f"{self.key}_AN" @@ -185,6 +258,25 @@ class ActionNode: return node_dict + def update_instruct_content(self, incre_data: dict[str, Any]): + assert self.instruct_content + origin_sc_dict = self.instruct_content.model_dump() + origin_sc_dict.update(incre_data) + output_class = self.create_class() + self.instruct_content = output_class(**origin_sc_dict) + + def keys(self, mode: str = "auto") -> list: + if mode == "children" or (mode == "auto" and self.children): + keys = [] + else: + keys = [self.key] + if mode == "root": + return keys + + for _, child_node in self.children.items(): + keys.append(child_node.key) + return keys + def compile_to(self, i: Dict, schema, kv_sep) -> str: if schema == "json": return json.dumps(i, indent=4) @@ -342,7 +434,170 @@ class ActionNode: if exclude and i.key in exclude: continue child = await i.simple_fill(schema=schema, mode=mode, timeout=timeout, exclude=exclude) - tmp.update(child.instruct_content.dict()) + tmp.update(child.instruct_content.model_dump()) cls = self.create_children_class() self.instruct_content = cls(**tmp) return self + + async def human_review(self) -> dict[str, str]: + review_comments = HumanInteraction().interact_with_instruct_content( + instruct_content=self.instruct_content, interact_type="review" + ) + + return review_comments + + def _makeup_nodes_output_with_req(self) -> dict[str, str]: + instruct_content_dict = self.instruct_content.model_dump() + nodes_output = {} + for key, value in instruct_content_dict.items(): + child = self.get_child(key) + nodes_output[key] = {"value": value, "requirement": child.instruction if child else self.instruction} + return nodes_output + + async def auto_review(self, template: str = REVIEW_TEMPLATE) -> dict[str, str]: + """use key's output value and its instruction to review the modification comment""" + nodes_output = self._makeup_nodes_output_with_req() + """nodes_output format: + { + "key": {"value": "output value", "requirement": "key instruction"} + } + """ + if not nodes_output: + return dict() + + prompt = template.format( + nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4), tag=TAG, constraint=FORMAT_CONSTRAINT + ) + + content = await self.llm.aask(prompt) + # Extract the dict of mismatch key and its comment. Due to the mismatch keys are unknown, here use the keys + # of ActionNode to judge if exist in `content` and then follow the `data_mapping` method to create model class. + keys = self.keys() + include_keys = [] + for key in keys: + if f'"{key}":' in content: + include_keys.append(key) + if not include_keys: + return dict() + + exclude_keys = list(set(keys).difference(include_keys)) + output_class_name = f"{self.key}_AN_REVIEW" + output_class = self.create_class(class_name=output_class_name, exclude=exclude_keys) + parsed_data = llm_output_postprocess( + output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" + ) + instruct_content = output_class(**parsed_data) + return instruct_content.model_dump() + + async def simple_review(self, review_mode: ReviewMode = ReviewMode.AUTO): + # generate review comments + if review_mode == ReviewMode.HUMAN: + review_comments = await self.human_review() + else: + review_comments = await self.auto_review() + + if not review_comments: + logger.warning("There are no review comments") + return review_comments + + async def review(self, strgy: str = "simple", review_mode: ReviewMode = ReviewMode.AUTO): + """only give the review comment of each exist and mismatch key + + :param strgy: simple/complex + - simple: run only once + - complex: run each node + """ + if not hasattr(self, "llm"): + raise RuntimeError("use `review` after `fill`") + assert review_mode in ReviewMode + assert self.instruct_content, 'review only support with `schema != "raw"`' + + if strgy == "simple": + review_comments = await self.simple_review(review_mode) + elif strgy == "complex": + # review each child node one-by-one + review_comments = {} + for _, child in self.children.items(): + child_review_comment = await child.simple_review(review_mode) + review_comments.update(child_review_comment) + + return review_comments + + async def human_revise(self) -> dict[str, str]: + review_contents = HumanInteraction().interact_with_instruct_content( + instruct_content=self.instruct_content, mapping=self.get_mapping(mode="auto"), interact_type="revise" + ) + # re-fill the ActionNode + self.update_instruct_content(review_contents) + return review_contents + + def _makeup_nodes_output_with_comment(self, review_comments: dict[str, str]) -> dict[str, str]: + instruct_content_dict = self.instruct_content.model_dump() + nodes_output = {} + for key, value in instruct_content_dict.items(): + if key in review_comments: + nodes_output[key] = {"value": value, "comment": review_comments[key]} + return nodes_output + + async def auto_revise(self, template: str = REVISE_TEMPLATE) -> dict[str, str]: + """revise the value of incorrect keys""" + # generate review comments + review_comments: dict = await self.auto_review() + include_keys = list(review_comments.keys()) + + # generate revise content + nodes_output = self._makeup_nodes_output_with_comment(review_comments) + keys = self.keys() + exclude_keys = list(set(keys).difference(include_keys)) + example = self.compile_example(schema="json", mode="auto", tag=TAG, exclude=exclude_keys) + instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys) + + prompt = template.format( + nodes_output=json.dumps(nodes_output, ensure_ascii=False, indent=4), + example=example, + instruction=instruction, + constraint=FORMAT_CONSTRAINT, + ) + + output_mapping = self.get_mapping(mode="auto", exclude=exclude_keys) + output_class_name = f"{self.key}_AN_REVISE" + content, scontent = await self._aask_v1( + prompt=prompt, output_class_name=output_class_name, output_data_mapping=output_mapping, schema="json" + ) + + # re-fill the ActionNode + sc_dict = scontent.model_dump() + self.update_instruct_content(sc_dict) + return sc_dict + + async def simple_revise(self, revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]: + if revise_mode == ReviseMode.HUMAN: + revise_contents = await self.human_revise() + else: + revise_contents = await self.auto_revise() + + return revise_contents + + async def revise(self, strgy: str = "simple", revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]: + """revise the content of ActionNode and update the instruct_content + + :param strgy: simple/complex + - simple: run only once + - complex: run each node + """ + if not hasattr(self, "llm"): + raise RuntimeError("use `revise` after `fill`") + assert revise_mode in ReviseMode + assert self.instruct_content, 'revise only support with `schema != "raw"`' + + if strgy == "simple": + revise_contents = await self.simple_revise(revise_mode) + elif strgy == "complex": + # revise each child node one-by-one + revise_contents = {} + for _, child in self.children.items(): + child_revise_content = await child.simple_revise(revise_mode) + revise_contents.update(child_revise_content) + self.update_instruct_content(revise_contents) + + return revise_contents diff --git a/metagpt/utils/human_interaction.py b/metagpt/utils/human_interaction.py new file mode 100644 index 000000000..3b245cac8 --- /dev/null +++ b/metagpt/utils/human_interaction.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : human interaction to get required type text + +import json +from typing import Any, Tuple, Type + +from pydantic import BaseModel + +from metagpt.logs import logger +from metagpt.utils.common import import_class + + +class HumanInteraction(object): + stop_list = ("q", "quit", "exit") + + def multilines_input(self, prompt: str = "Enter: ") -> str: + logger.warning("Enter your content, use Ctrl-D or Ctrl-Z ( windows ) to save it.") + logger.info(f"{prompt}\n") + lines = [] + while True: + try: + line = input() + lines.append(line) + except EOFError: + break + return "".join(lines) + + def check_input_type(self, input_str: str, req_type: Type) -> Tuple[bool, Any]: + check_ret = True + if req_type == str: + # required_type = str, just return True + return check_ret, input_str + try: + input_str = input_str.strip() + data = json.loads(input_str) + except Exception: + return False, None + + actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import + tmp_key = "tmp" + tmp_cls = actionnode_class.create_model_class(class_name=tmp_key.upper(), mapping={tmp_key: (req_type, ...)}) + try: + _ = tmp_cls(**{tmp_key: data}) + except Exception: + check_ret = False + return check_ret, data + + def input_until_valid(self, prompt: str, req_type: Type) -> Any: + # check the input with req_type until it's ok + while True: + input_content = self.multilines_input(prompt) + check_ret, structure_content = self.check_input_type(input_content, req_type) + if check_ret: + break + else: + logger.error(f"Input content can't meet required_type: {req_type}, please Re-Enter.") + return structure_content + + def input_num_until_valid(self, num_max: int) -> int: + while True: + input_num = input("Enter the num of the interaction key: ") + input_num = input_num.strip() + if input_num in self.stop_list: + return input_num + try: + input_num = int(input_num) + if 0 <= input_num < num_max: + return input_num + except Exception: + pass + + def interact_with_instruct_content( + self, instruct_content: BaseModel, mapping: dict = dict(), interact_type: str = "review" + ) -> dict[str, Any]: + assert interact_type in ["review", "revise"] + assert instruct_content + instruct_content_dict = instruct_content.model_dump() + num_fields_map = dict(zip(range(0, len(instruct_content_dict)), instruct_content_dict.keys())) + logger.info( + f"\n{interact_type.upper()} interaction\n" + f"Interaction data: {num_fields_map}\n" + f"Enter the num to interact with corresponding field or `q`/`quit`/`exit` to stop interaction.\n" + f"Enter the field content until it meet field required type.\n" + ) + + interact_contents = {} + while True: + input_num = self.input_num_until_valid(len(instruct_content_dict)) + if input_num in self.stop_list: + logger.warning("Stop human interaction") + break + + field = num_fields_map.get(input_num) + logger.info(f"You choose to interact with field: {field}, and do a `{interact_type}` operation.") + + if interact_type == "review": + prompt = "Enter your review comment: " + req_type = str + else: + prompt = "Enter your revise content: " + req_type = mapping.get(field)[0] # revise need input content match the required_type + + field_content = self.input_until_valid(prompt=prompt, req_type=req_type) + interact_contents[field] = field_content + + return interact_contents diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 384c4507b..fd2c83ac9 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -11,7 +11,7 @@ import pytest from pydantic import ValidationError from metagpt.actions import Action -from metagpt.actions.action_node import ActionNode +from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode from metagpt.environment import Environment from metagpt.llm import LLM from metagpt.roles import Role @@ -98,6 +98,83 @@ async def test_action_node_two_layer(): assert "579" in answer2.content +@pytest.mark.asyncio +async def test_action_node_review(): + key = "Project Name" + node_a = ActionNode( + key=key, + expected_type=str, + instruction='According to the content of "Original Requirements," name the project using snake case style ' + "with underline, like 'game_2048' or 'simple_crm.", + example="game_2048", + ) + + with pytest.raises(RuntimeError): + _ = await node_a.review() + + _ = await node_a.fill(context=None, llm=LLM()) + setattr(node_a.instruct_content, key, "game snake") # wrong content to review + + review_comments = await node_a.review(review_mode=ReviewMode.AUTO) + assert len(review_comments) == 1 + assert list(review_comments.keys())[0] == key + + review_comments = await node_a.review(strgy="complex", review_mode=ReviewMode.AUTO) + assert len(review_comments) == 0 + + node = ActionNode.from_children(key="WritePRD", nodes=[node_a]) + with pytest.raises(RuntimeError): + _ = await node.review() + + _ = await node.fill(context=None, llm=LLM()) + + review_comments = await node.review(review_mode=ReviewMode.AUTO) + assert len(review_comments) == 1 + assert list(review_comments.keys())[0] == key + + review_comments = await node.review(strgy="complex", review_mode=ReviewMode.AUTO) + assert len(review_comments) == 1 + assert list(review_comments.keys())[0] == key + + +@pytest.mark.asyncio +async def test_action_node_revise(): + key = "Project Name" + node_a = ActionNode( + key=key, + expected_type=str, + instruction='According to the content of "Original Requirements," name the project using snake case style ' + "with underline, like 'game_2048' or 'simple_crm.", + example="game_2048", + ) + + with pytest.raises(RuntimeError): + _ = await node_a.review() + + _ = await node_a.fill(context=None, llm=LLM()) + setattr(node_a.instruct_content, key, "game snake") # wrong content to revise + revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO) + assert len(revise_contents) == 1 + assert "game_snake" in getattr(node_a.instruct_content, key) + + revise_contents = await node_a.revise(strgy="complex", revise_mode=ReviseMode.AUTO) + assert len(revise_contents) == 0 + + node = ActionNode.from_children(key="WritePRD", nodes=[node_a]) + with pytest.raises(RuntimeError): + _ = await node.revise() + + _ = await node.fill(context=None, llm=LLM()) + setattr(node.instruct_content, key, "game snake") + revise_contents = await node.revise(revise_mode=ReviseMode.AUTO) + assert len(revise_contents) == 1 + assert "game_snake" in getattr(node.instruct_content, key) + + revise_contents = await node.revise(strgy="complex", revise_mode=ReviseMode.AUTO) + assert len(revise_contents) == 1 + assert "game_snake" in getattr(node.instruct_content, key) + + t_dict = { "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', "Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n', @@ -138,10 +215,10 @@ def test_create_model_class(): assert test_class.__name__ == "test_class" output = test_class(**t_dict) - print(output.schema()) - assert output.schema()["title"] == "test_class" - assert output.schema()["type"] == "object" - assert output.schema()["properties"]["Full API spec"] + print(output.model_json_schema()) + assert output.model_json_schema()["title"] == "test_class" + assert output.model_json_schema()["type"] == "object" + assert output.model_json_schema()["properties"]["Full API spec"] def test_create_model_class_with_fields_unrecognized(): diff --git a/tests/metagpt/utils/test_human_interaction.py b/tests/metagpt/utils/test_human_interaction.py new file mode 100644 index 000000000..038fc0d98 --- /dev/null +++ b/tests/metagpt/utils/test_human_interaction.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of human_interaction + +import pytest + +from pydantic import BaseModel + +from metagpt.utils.human_interaction import HumanInteraction + + +class InstructContent(BaseModel): + test_field1: str = "" + test_field2: list[str] = [] + + +data_mapping = { + "test_field1": (str, ...), + "test_field2": (list[str], ...) +} + +human_interaction = HumanInteraction() + + +def test_input_num(mocker): + mocker.patch("builtins.input", lambda _: "quit") + + interact_contents = human_interaction.interact_with_instruct_content(InstructContent(), data_mapping) + assert len(interact_contents) == 0 + + mocker.patch("builtins.input", lambda _: "1") + input_num = human_interaction.input_num_until_valid(2) + assert input_num == 1 + + +def test_check_input_type(): + ret, _ = human_interaction.check_input_type(input_str="test string", + req_type=str) + assert ret + + ret, _ = human_interaction.check_input_type(input_str='["test string"]', + req_type=list[str]) + assert ret + + ret, _ = human_interaction.check_input_type(input_str='{"key", "value"}', + req_type=list[str]) + assert not ret + + +global_index = 0 + + +def mock_input(*args, **kwargs): + """there are multi input call, return it by global_index""" + arr = ["1", '["test"]', "ignore", "quit"] + global global_index + global_index += 1 + if global_index == 3: + raise EOFError() + val = arr[global_index-1] + return val + + +def test_human_interact_valid_content(mocker): + mocker.patch("builtins.input", mock_input) + input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "review") + assert len(input_contents) == 1 + assert input_contents["test_field2"] == '["test"]' + + global global_index + global_index = 0 + input_contents = HumanInteraction().interact_with_instruct_content(InstructContent(), data_mapping, "revise") + assert len(input_contents) == 1 + assert input_contents["test_field2"] == ["test"]