diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py index 060150f4d..59daa692f 100644 --- a/metagpt/context_mixin.py +++ b/metagpt/context_mixin.py @@ -33,20 +33,16 @@ class ContextMixin(BaseModel): private_llm: Optional[BaseLLM] = Field(default=None, exclude=True) @model_validator(mode="after") - def validate_extra(self): - self._process_extra(**(self.model_extra or {})) + def validate_context_mixin_extra(self): + self._process_context_mixin_extra() return self - def _process_extra( - self, - context: Optional[Context] = None, - config: Optional[Config] = None, - llm: Optional[BaseLLM] = None, - ): + def _process_context_mixin_extra(self): """Process the extra field""" - self.set_context(context) - self.set_config(config) - self.set_llm(llm) + kwargs = self.model_extra or {} + self.set_context(kwargs.pop("context", None)) + self.set_config(kwargs.pop("config", None)) + self.set_llm(kwargs.pop("llm", None)) def set(self, k, v, override=False): """Set attribute""" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 5da39f80f..c098f95af 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -23,7 +23,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Iterable, Optional, Set, Type, Union +from typing import Iterable, Optional, Set, Type, Union from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator @@ -121,7 +121,7 @@ class RoleContext(BaseModel): class Role(SerializationMixin, ContextMixin, BaseModel): """Role/Agent""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore") + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") name: str = "" profile: str = "" @@ -149,16 +149,21 @@ class Role(SerializationMixin, ContextMixin, BaseModel): __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` - def __init__(self, **data: Any): + @model_validator(mode="after") + def validate_role_extra(self): + self._process_role_extra() + return self + + def _process_role_extra(self): self.pydantic_rebuild_model() - super().__init__(**data) + kwargs = self.model_extra or {} if self.is_human: self.llm = HumanProvider(None) self._check_actions() self.llm.system_prompt = self._get_prefix() - self._watch(data.get("watch") or [UserRequirement]) + self._watch(kwargs.pop("watch", [UserRequirement])) if self.latest_observed_msg: self.recovered = True