diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 4c06d0d1d..6c65b33ef 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -11,7 +11,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because import json from typing import Any, Dict, List, Optional, Tuple, Type -from pydantic import BaseModel, create_model, field_validator, model_validator +from pydantic import BaseModel, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.config import CONFIG @@ -135,26 +135,21 @@ class ActionNode: @classmethod def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]): """基于pydantic v1的模型动态生成,用来检验结果类型正确性""" - new_class = create_model(class_name, **mapping) - @field_validator("*", mode="before") - @classmethod - def check_name(v, field): - if field.name not in mapping.keys(): - raise ValueError(f"Unrecognized block: {field.name}") - return v - - @model_validator(mode="before") - @classmethod - def check_missing_fields(values): + def check_fields(cls, values): required_fields = set(mapping.keys()) missing_fields = required_fields - set(values.keys()) if missing_fields: raise ValueError(f"Missing fields: {missing_fields}") + + unrecognized_fields = set(values.keys()) - required_fields + if unrecognized_fields: + logger.warning(f"Unrecognized fields: {unrecognized_fields}") return values - new_class.__validator_check_name = classmethod(check_name) - new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields) + validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)} + + new_class = create_model(class_name, __validators__=validators, **mapping) return new_class def create_children_class(self, exclude=None): diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 74b4df27f..25aceaa2e 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -8,6 +8,7 @@ from typing import List, Tuple import pytest +from pydantic import ValidationError from metagpt.actions import Action from metagpt.actions.action_node import ActionNode @@ -113,6 +114,10 @@ t_dict = { "Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?", } +t_dict_min = { + "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', +} + WRITE_TASKS_OUTPUT_MAPPING = { "Required Python third-party packages": (str, ...), "Required Other language third-party packages": (str, ...), @@ -139,11 +144,19 @@ def test_create_model_class(): assert output.schema()["properties"]["Full API spec"] -def test_create_model_class_missing(): +def test_create_model_class_with_fields_unrecognized(): test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING) assert test_class.__name__ == "test_class" - _ = test_class(**t_dict) # 这里应该要挂掉 + _ = test_class(**t_dict) # just warning + + +def test_create_model_class_with_fields_missing(): + test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) + assert test_class.__name__ == "test_class" + + with pytest.raises(ValidationError): + _ = test_class(**t_dict_min) def test_create_model_class_with_mapping():