add action_outcls decorator to support init same class with same class name and fields

This commit is contained in:
better629 2024-01-10 19:13:19 +08:00
parent ba6793383f
commit 58c2c55ee9
5 changed files with 98 additions and 2 deletions

View file

@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@ -201,6 +202,7 @@ class ActionNode:
return {} if exclude and self.key in exclude else self.get_self_mapping()
@classmethod
@register_action_outcls
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
"""基于pydantic v1的模型动态生成用来检验结果类型正确性"""

View file

@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : registry to store Dynamic Model from ActionNode.create_model_class to keep it as same Class
# with same class name and mapping
from functools import wraps
action_outcls_registry = dict()
def register_action_outcls(func):
"""
Due to `create_model` return different Class even they have same class name and mapping.
In order to do a comparison, use outcls_id to identify same Class with same class name and field definition
"""
@wraps(func)
def decorater(*args, **kwargs):
"""
arr example
[<class 'metagpt.actions.action_node.ActionNode'>, 'test', {'field': (str, Ellipsis)}]
"""
arr = list(args) + list(kwargs.values())
"""
outcls_id example
"<class 'metagpt.actions.action_node.ActionNode'>_test_{'field': (str, Ellipsis)}"
"""
for idx, item in enumerate(arr):
if isinstance(item, dict):
arr[idx] = dict(sorted(item.items()))
outcls_id = "_".join([str(i) for i in arr])
# eliminate typing influence
outcls_id = outcls_id.replace("typing.List", "list").replace("typing.Dict", "dict")
if outcls_id in action_outcls_registry:
return action_outcls_registry[outcls_id]
out_cls = func(*args, **kwargs)
action_outcls_registry[outcls_id] = out_cls
return out_cls
return decorater

View file

@ -0,0 +1,46 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of action_outcls_registry
from typing import List
from metagpt.actions.action_node import ActionNode
def test_action_outcls_registry():
class_name = "test"
out_mapping = {"field": (list[str], ...), "field1": (str, ...)}
out_data = {"field": ["field value1", "field value2"], "field1": "field1 value1"}
outcls = ActionNode.create_model_class(class_name, mapping=out_mapping)
outinst = outcls(**out_data)
outcls1 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping)
outinst1 = outcls1(**out_data)
assert outinst1 == outinst
outcls2 = ActionNode(key="",
expected_type=str,
instruction="",
example="").create_model_class(class_name, out_mapping)
outinst2 = outcls2(**out_data)
assert outinst2 == outinst
out_mapping = {"field1": (str, ...), "field": (list[str], ...)} # different order
outcls3 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping)
outinst3 = outcls3(**out_data)
assert outinst3 == outinst
out_mapping2 = {"field1": (str, ...), "field": (List[str], ...)} # typing case
outcls4 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping2)
outinst4 = outcls4(**out_data)
assert outinst4 == outinst
out_data2 = {"field2": ["field2 value1", "field2 value2"], "field1": "field1 value1"}
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} # List first
outcls5 = ActionNode.create_model_class(class_name, out_mapping)
outinst5 = outcls5(**out_data2)
out_mapping = {"field1": (str, ...), "field2": (list[str], ...)}
outcls6 = ActionNode.create_model_class(class_name, out_mapping)
outinst6 = outcls6(**out_data2)
assert outinst5 == outinst6

View file

@ -19,5 +19,6 @@ async def test_architect_serdeser():
new_role = Architect(**ser_role_dict)
assert new_role.name == "Bob"
assert len(new_role.actions) == 1
assert len(new_role.rc.watch) == 1
assert isinstance(new_role.actions[0], Action)
await new_role.actions[0].run(with_messages="write a cli snake game")

View file

@ -31,15 +31,17 @@ def test_message_serdeser_from_create_model():
assert new_message.cause_by == any_to_str(WriteCode)
assert new_message.cause_by in [any_to_str(WriteCode)]
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
assert new_message.instruct_content != ic_inst
assert new_message.instruct_content == ic_obj(**out_data)
assert new_message.instruct_content == ic_inst
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
assert new_message == message
mock_msg = MockMessage()
message = Message(content="test_ic", instruct_content=mock_msg)
ser_data = message.model_dump()
new_message = Message(**ser_data)
assert new_message.instruct_content == mock_msg
assert new_message == message
def test_message_without_postprocess():
@ -54,6 +56,7 @@ def test_message_without_postprocess():
ser_data["instruct_content"] = None
new_message = MockICMessage(**ser_data)
assert new_message.instruct_content != ic_obj(**out_data)
assert new_message != message
def test_message_serdeser_from_basecontext():
@ -83,6 +86,7 @@ def test_message_serdeser_from_basecontext():
new_code_ctxt_msg = Message(**ser_data)
assert new_code_ctxt_msg.instruct_content == code_ctxt
assert new_code_ctxt_msg.instruct_content.code_doc.filename == "game.py"
assert new_code_ctxt_msg == code_ctxt_msg
testing_ctxt = TestingContext(
filename="test.py",
@ -94,3 +98,4 @@ def test_message_serdeser_from_basecontext():
new_testing_ctxt_msg = Message(**ser_data)
assert new_testing_ctxt_msg.instruct_content == testing_ctxt
assert new_testing_ctxt_msg.instruct_content.test_doc.filename == "test.py"
assert new_testing_ctxt_msg == testing_ctxt_msg