mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
fix pydantic v2 model validation for custom class
This commit is contained in:
parent
5649fac62d
commit
0b9becf93f
2 changed files with 24 additions and 16 deletions
|
|
@ -11,7 +11,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, create_model, field_validator, model_validator
|
||||
from pydantic import BaseModel, create_model, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -135,26 +135,21 @@ class ActionNode:
|
|||
@classmethod
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
new_class = create_model(class_name, **mapping)
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def check_name(v, field):
|
||||
if field.name not in mapping.keys():
|
||||
raise ValueError(f"Unrecognized block: {field.name}")
|
||||
return v
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_missing_fields(values):
|
||||
def check_fields(cls, values):
|
||||
required_fields = set(mapping.keys())
|
||||
missing_fields = required_fields - set(values.keys())
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing fields: {missing_fields}")
|
||||
|
||||
unrecognized_fields = set(values.keys()) - required_fields
|
||||
if unrecognized_fields:
|
||||
logger.warning(f"Unrecognized fields: {unrecognized_fields}")
|
||||
return values
|
||||
|
||||
new_class.__validator_check_name = classmethod(check_name)
|
||||
new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
|
||||
validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
|
||||
|
||||
new_class = create_model(class_name, __validators__=validators, **mapping)
|
||||
return new_class
|
||||
|
||||
def create_children_class(self, exclude=None):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -113,6 +114,10 @@ t_dict = {
|
|||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
t_dict_min = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
|
|
@ -139,11 +144,19 @@ def test_create_model_class():
|
|||
assert output.schema()["properties"]["Full API spec"]
|
||||
|
||||
|
||||
def test_create_model_class_missing():
|
||||
def test_create_model_class_with_fields_unrecognized():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
_ = test_class(**t_dict) # 这里应该要挂掉
|
||||
_ = test_class(**t_dict) # just warning
|
||||
|
||||
|
||||
def test_create_model_class_with_fields_missing():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
_ = test_class(**t_dict_min)
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue