From e199c6b476bca1de6ca35c63af3c082eb91650f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 6 Aug 2024 15:03:24 +0800 Subject: [PATCH] fixbug: Optional not working --- metagpt/actions/action_node.py | 18 ++++++++++++++++-- tests/metagpt/actions/test_action_node.py | 18 +++++++++++++++--- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 31e4cc0fc..4c07ed99d 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -236,13 +236,27 @@ class ActionNode: def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): """基于pydantic v2的模型动态生成,用来检验结果类型正确性""" + def is_optional_type(tp): + if typing.get_origin(tp) is Union: + args = typing.get_args(tp) + non_none_types = [arg for arg in args if arg is not type(None)] + return len(non_none_types) == 1 and len(args) == 2 + return False + def check_fields(cls, values): - required_fields = set(mapping.keys()) + all_fields = set(mapping.keys()) + required_fields = set() + for k, v in mapping.items(): + type_v, field_info = v + if is_optional_type(type_v): + continue + required_fields.add(k) + missing_fields = required_fields - set(values.keys()) if missing_fields: raise ValueError(f"Missing fields: {missing_fields}") - unrecognized_fields = set(values.keys()) - required_fields + unrecognized_fields = set(values.keys()) - all_fields if unrecognized_fields: logger.warning(f"Unrecognized fields: {unrecognized_fields}") return values diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 989e2249c..338d87242 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -6,7 +6,7 @@ @File : test_action_node.py """ from pathlib import Path -from typing import List, Tuple +from typing import List, Optional, Tuple import pytest from pydantic import BaseModel, Field, ValidationError @@ -302,6 +302,18 @@ def test_action_node_from_pydantic_and_print_everything(): assert "tasks" in code, "tasks should be in code" +def test_optional(): + mapping = { + "Logic Analysis": (Optional[List[Tuple[str, str]]], Field(default=None)), + "Task list": (Optional[List[str]], None), + "Anything UNCLEAR": (Optional[str], None), + } + m = {"Anything UNCLEAR": "a"} + t = ActionNode.create_model_class("test_class_1", mapping) + + t1 = t(**m) + assert t1 + + if __name__ == "__main__": - test_create_model_class() - test_create_model_class_with_mapping() + pytest.main([__file__, "-s"])