fix pydantic v2 model validation for custom class

This commit is contained in:
geekan 2024-01-02 19:27:42 +08:00
parent 5649fac62d
commit 0b9becf93f
2 changed files with 24 additions and 16 deletions

View file

@ -11,7 +11,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
import json
from typing import Any, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, create_model, field_validator, model_validator
from pydantic import BaseModel, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
@ -135,26 +135,21 @@ class ActionNode:
@classmethod
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
"""基于pydantic v1的模型动态生成用来检验结果类型正确性"""
new_class = create_model(class_name, **mapping)
@field_validator("*", mode="before")
@classmethod
def check_name(v, field):
if field.name not in mapping.keys():
raise ValueError(f"Unrecognized block: {field.name}")
return v
@model_validator(mode="before")
@classmethod
def check_missing_fields(values):
def check_fields(cls, values):
required_fields = set(mapping.keys())
missing_fields = required_fields - set(values.keys())
if missing_fields:
raise ValueError(f"Missing fields: {missing_fields}")
unrecognized_fields = set(values.keys()) - required_fields
if unrecognized_fields:
logger.warning(f"Unrecognized fields: {unrecognized_fields}")
return values
new_class.__validator_check_name = classmethod(check_name)
new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
new_class = create_model(class_name, __validators__=validators, **mapping)
return new_class
def create_children_class(self, exclude=None):

View file

@ -8,6 +8,7 @@
from typing import List, Tuple
import pytest
from pydantic import ValidationError
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
@ -113,6 +114,10 @@ t_dict = {
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
}
t_dict_min = {
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
}
WRITE_TASKS_OUTPUT_MAPPING = {
"Required Python third-party packages": (str, ...),
"Required Other language third-party packages": (str, ...),
@ -139,11 +144,19 @@ def test_create_model_class():
assert output.schema()["properties"]["Full API spec"]
def test_create_model_class_missing():
def test_create_model_class_with_fields_unrecognized():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
assert test_class.__name__ == "test_class"
_ = test_class(**t_dict) # 这里应该要挂掉
_ = test_class(**t_dict) # just warning
def test_create_model_class_with_fields_missing():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
with pytest.raises(ValidationError):
_ = test_class(**t_dict_min)
def test_create_model_class_with_mapping():