diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 286cf534d..b4d8c32df 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union from pydantic import BaseModel, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential +from metagpt.actions.action_outcls_registry import register_action_outcls from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess @@ -201,6 +202,7 @@ class ActionNode: return {} if exclude and self.key in exclude else self.get_self_mapping() @classmethod + @register_action_outcls def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" diff --git a/metagpt/actions/action_outcls_registry.py b/metagpt/actions/action_outcls_registry.py new file mode 100644 index 000000000..780a061b4 --- /dev/null +++ b/metagpt/actions/action_outcls_registry.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : registry to store Dynamic Model from ActionNode.create_model_class to keep it as same Class +# with same class name and mapping + +from functools import wraps + + +action_outcls_registry = dict() + + +def register_action_outcls(func): + """ + Due to `create_model` return different Class even they have same class name and mapping. + In order to do a comparison, use outcls_id to identify same Class with same class name and field definition + """ + @wraps(func) + def decorater(*args, **kwargs): + """ + arr example + [, 'test', {'field': (str, Ellipsis)}] + """ + arr = list(args) + list(kwargs.values()) + """ + outcls_id example + "_test_{'field': (str, Ellipsis)}" + """ + for idx, item in enumerate(arr): + if isinstance(item, dict): + arr[idx] = dict(sorted(item.items())) + outcls_id = "_".join([str(i) for i in arr]) + # eliminate typing influence + outcls_id = outcls_id.replace("typing.List", "list").replace("typing.Dict", "dict") + + if outcls_id in action_outcls_registry: + return action_outcls_registry[outcls_id] + + out_cls = func(*args, **kwargs) + action_outcls_registry[outcls_id] = out_cls + return out_cls + + return decorater diff --git a/tests/metagpt/actions/test_action_outcls_registry.py b/tests/metagpt/actions/test_action_outcls_registry.py new file mode 100644 index 000000000..e949ac16b --- /dev/null +++ b/tests/metagpt/actions/test_action_outcls_registry.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : unittest of action_outcls_registry + +from typing import List +from metagpt.actions.action_node import ActionNode + + +def test_action_outcls_registry(): + class_name = "test" + out_mapping = {"field": (list[str], ...), "field1": (str, ...)} + out_data = {"field": ["field value1", "field value2"], "field1": "field1 value1"} + + outcls = ActionNode.create_model_class(class_name, mapping=out_mapping) + outinst = outcls(**out_data) + + outcls1 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping) + outinst1 = outcls1(**out_data) + assert outinst1 == outinst + + outcls2 = ActionNode(key="", + expected_type=str, + instruction="", + example="").create_model_class(class_name, out_mapping) + outinst2 = outcls2(**out_data) + assert outinst2 == outinst + + out_mapping = {"field1": (str, ...), "field": (list[str], ...)} # different order + outcls3 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping) + outinst3 = outcls3(**out_data) + assert outinst3 == outinst + + out_mapping2 = {"field1": (str, ...), "field": (List[str], ...)} # typing case + outcls4 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping2) + outinst4 = outcls4(**out_data) + assert outinst4 == outinst + + out_data2 = {"field2": ["field2 value1", "field2 value2"], "field1": "field1 value1"} + out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} # List first + outcls5 = ActionNode.create_model_class(class_name, out_mapping) + outinst5 = outcls5(**out_data2) + + out_mapping = {"field1": (str, ...), "field2": (list[str], ...)} + outcls6 = ActionNode.create_model_class(class_name, out_mapping) + outinst6 = outcls6(**out_data2) + assert outinst5 == outinst6 diff --git a/tests/metagpt/serialize_deserialize/test_architect.py b/tests/metagpt/serialize_deserialize/test_architect.py index 343662494..a6823197a 100644 --- a/tests/metagpt/serialize_deserialize/test_architect.py +++ b/tests/metagpt/serialize_deserialize/test_architect.py @@ -19,5 +19,6 @@ async def test_architect_serdeser(): new_role = Architect(**ser_role_dict) assert new_role.name == "Bob" assert len(new_role.actions) == 1 + assert len(new_role.rc.watch) == 1 assert isinstance(new_role.actions[0], Action) await new_role.actions[0].run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_schema.py b/tests/metagpt/serialize_deserialize/test_schema.py index b55b82088..c5a457a1e 100644 --- a/tests/metagpt/serialize_deserialize/test_schema.py +++ b/tests/metagpt/serialize_deserialize/test_schema.py @@ -31,15 +31,17 @@ def test_message_serdeser_from_create_model(): assert new_message.cause_by == any_to_str(WriteCode) assert new_message.cause_by in [any_to_str(WriteCode)] - assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=` - assert new_message.instruct_content != ic_inst + assert new_message.instruct_content == ic_obj(**out_data) + assert new_message.instruct_content == ic_inst assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump() + assert new_message == message mock_msg = MockMessage() message = Message(content="test_ic", instruct_content=mock_msg) ser_data = message.model_dump() new_message = Message(**ser_data) assert new_message.instruct_content == mock_msg + assert new_message == message def test_message_without_postprocess(): @@ -54,6 +56,7 @@ def test_message_without_postprocess(): ser_data["instruct_content"] = None new_message = MockICMessage(**ser_data) assert new_message.instruct_content != ic_obj(**out_data) + assert new_message != message def test_message_serdeser_from_basecontext(): @@ -83,6 +86,7 @@ def test_message_serdeser_from_basecontext(): new_code_ctxt_msg = Message(**ser_data) assert new_code_ctxt_msg.instruct_content == code_ctxt assert new_code_ctxt_msg.instruct_content.code_doc.filename == "game.py" + assert new_code_ctxt_msg == code_ctxt_msg testing_ctxt = TestingContext( filename="test.py", @@ -94,3 +98,4 @@ def test_message_serdeser_from_basecontext(): new_testing_ctxt_msg = Message(**ser_data) assert new_testing_ctxt_msg.instruct_content == testing_ctxt assert new_testing_ctxt_msg.instruct_content.test_doc.filename == "test.py" + assert new_testing_ctxt_msg == testing_ctxt_msg