mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add action_outcls decorator to support init same class with same class name and fields
This commit is contained in:
parent
ba6793383f
commit
58c2c55ee9
5 changed files with 98 additions and 2 deletions
|
|
@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|||
from pydantic import BaseModel, create_model, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action_outcls_registry import register_action_outcls
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
|
|
@ -201,6 +202,7 @@ class ActionNode:
|
|||
return {} if exclude and self.key in exclude else self.get_self_mapping()
|
||||
|
||||
@classmethod
|
||||
@register_action_outcls
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
|
||||
|
|
|
|||
42
metagpt/actions/action_outcls_registry.py
Normal file
42
metagpt/actions/action_outcls_registry.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : registry to store Dynamic Model from ActionNode.create_model_class to keep it as same Class
|
||||
# with same class name and mapping
|
||||
|
||||
from functools import wraps
|
||||
|
||||
|
||||
action_outcls_registry = dict()
|
||||
|
||||
|
||||
def register_action_outcls(func):
|
||||
"""
|
||||
Due to `create_model` return different Class even they have same class name and mapping.
|
||||
In order to do a comparison, use outcls_id to identify same Class with same class name and field definition
|
||||
"""
|
||||
@wraps(func)
|
||||
def decorater(*args, **kwargs):
|
||||
"""
|
||||
arr example
|
||||
[<class 'metagpt.actions.action_node.ActionNode'>, 'test', {'field': (str, Ellipsis)}]
|
||||
"""
|
||||
arr = list(args) + list(kwargs.values())
|
||||
"""
|
||||
outcls_id example
|
||||
"<class 'metagpt.actions.action_node.ActionNode'>_test_{'field': (str, Ellipsis)}"
|
||||
"""
|
||||
for idx, item in enumerate(arr):
|
||||
if isinstance(item, dict):
|
||||
arr[idx] = dict(sorted(item.items()))
|
||||
outcls_id = "_".join([str(i) for i in arr])
|
||||
# eliminate typing influence
|
||||
outcls_id = outcls_id.replace("typing.List", "list").replace("typing.Dict", "dict")
|
||||
|
||||
if outcls_id in action_outcls_registry:
|
||||
return action_outcls_registry[outcls_id]
|
||||
|
||||
out_cls = func(*args, **kwargs)
|
||||
action_outcls_registry[outcls_id] = out_cls
|
||||
return out_cls
|
||||
|
||||
return decorater
|
||||
46
tests/metagpt/actions/test_action_outcls_registry.py
Normal file
46
tests/metagpt/actions/test_action_outcls_registry.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of action_outcls_registry
|
||||
|
||||
from typing import List
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
|
||||
def test_action_outcls_registry():
|
||||
class_name = "test"
|
||||
out_mapping = {"field": (list[str], ...), "field1": (str, ...)}
|
||||
out_data = {"field": ["field value1", "field value2"], "field1": "field1 value1"}
|
||||
|
||||
outcls = ActionNode.create_model_class(class_name, mapping=out_mapping)
|
||||
outinst = outcls(**out_data)
|
||||
|
||||
outcls1 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping)
|
||||
outinst1 = outcls1(**out_data)
|
||||
assert outinst1 == outinst
|
||||
|
||||
outcls2 = ActionNode(key="",
|
||||
expected_type=str,
|
||||
instruction="",
|
||||
example="").create_model_class(class_name, out_mapping)
|
||||
outinst2 = outcls2(**out_data)
|
||||
assert outinst2 == outinst
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field": (list[str], ...)} # different order
|
||||
outcls3 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping)
|
||||
outinst3 = outcls3(**out_data)
|
||||
assert outinst3 == outinst
|
||||
|
||||
out_mapping2 = {"field1": (str, ...), "field": (List[str], ...)} # typing case
|
||||
outcls4 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping2)
|
||||
outinst4 = outcls4(**out_data)
|
||||
assert outinst4 == outinst
|
||||
|
||||
out_data2 = {"field2": ["field2 value1", "field2 value2"], "field1": "field1 value1"}
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} # List first
|
||||
outcls5 = ActionNode.create_model_class(class_name, out_mapping)
|
||||
outinst5 = outcls5(**out_data2)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (list[str], ...)}
|
||||
outcls6 = ActionNode.create_model_class(class_name, out_mapping)
|
||||
outinst6 = outcls6(**out_data2)
|
||||
assert outinst5 == outinst6
|
||||
|
|
@ -19,5 +19,6 @@ async def test_architect_serdeser():
|
|||
new_role = Architect(**ser_role_dict)
|
||||
assert new_role.name == "Bob"
|
||||
assert len(new_role.actions) == 1
|
||||
assert len(new_role.rc.watch) == 1
|
||||
assert isinstance(new_role.actions[0], Action)
|
||||
await new_role.actions[0].run(with_messages="write a cli snake game")
|
||||
|
|
|
|||
|
|
@ -31,15 +31,17 @@ def test_message_serdeser_from_create_model():
|
|||
assert new_message.cause_by == any_to_str(WriteCode)
|
||||
assert new_message.cause_by in [any_to_str(WriteCode)]
|
||||
|
||||
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
|
||||
assert new_message.instruct_content != ic_inst
|
||||
assert new_message.instruct_content == ic_obj(**out_data)
|
||||
assert new_message.instruct_content == ic_inst
|
||||
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
|
||||
assert new_message == message
|
||||
|
||||
mock_msg = MockMessage()
|
||||
message = Message(content="test_ic", instruct_content=mock_msg)
|
||||
ser_data = message.model_dump()
|
||||
new_message = Message(**ser_data)
|
||||
assert new_message.instruct_content == mock_msg
|
||||
assert new_message == message
|
||||
|
||||
|
||||
def test_message_without_postprocess():
|
||||
|
|
@ -54,6 +56,7 @@ def test_message_without_postprocess():
|
|||
ser_data["instruct_content"] = None
|
||||
new_message = MockICMessage(**ser_data)
|
||||
assert new_message.instruct_content != ic_obj(**out_data)
|
||||
assert new_message != message
|
||||
|
||||
|
||||
def test_message_serdeser_from_basecontext():
|
||||
|
|
@ -83,6 +86,7 @@ def test_message_serdeser_from_basecontext():
|
|||
new_code_ctxt_msg = Message(**ser_data)
|
||||
assert new_code_ctxt_msg.instruct_content == code_ctxt
|
||||
assert new_code_ctxt_msg.instruct_content.code_doc.filename == "game.py"
|
||||
assert new_code_ctxt_msg == code_ctxt_msg
|
||||
|
||||
testing_ctxt = TestingContext(
|
||||
filename="test.py",
|
||||
|
|
@ -94,3 +98,4 @@ def test_message_serdeser_from_basecontext():
|
|||
new_testing_ctxt_msg = Message(**ser_data)
|
||||
assert new_testing_ctxt_msg.instruct_content == testing_ctxt
|
||||
assert new_testing_ctxt_msg.instruct_content.test_doc.filename == "test.py"
|
||||
assert new_testing_ctxt_msg == testing_ctxt_msg
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue