From f6824b078cc18fd21bba20f99da8f761b0510ffd Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Fri, 2 Feb 2024 16:12:47 +0800 Subject: [PATCH 1/2] fix ContextMixin ut error --- metagpt/context_mixin.py | 17 ++++++----------- metagpt/roles/role.py | 14 +++++++++----- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py index 060150f4d..cf0604606 100644 --- a/metagpt/context_mixin.py +++ b/metagpt/context_mixin.py @@ -33,20 +33,15 @@ class ContextMixin(BaseModel): private_llm: Optional[BaseLLM] = Field(default=None, exclude=True) @model_validator(mode="after") - def validate_extra(self): - self._process_extra(**(self.model_extra or {})) + def validate_context_mixin_extra(self): + self._process_context_mixin_extra(**(self.model_extra or {})) return self - def _process_extra( - self, - context: Optional[Context] = None, - config: Optional[Config] = None, - llm: Optional[BaseLLM] = None, - ): + def _process_context_mixin_extra(self, **kwargs): """Process the extra field""" - self.set_context(context) - self.set_config(config) - self.set_llm(llm) + self.set_context(kwargs.get("context")) + self.set_config(kwargs.get("config")) + self.set_llm(kwargs.get("llm")) def set(self, k, v, override=False): """Set attribute""" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 5da39f80f..20cd4da99 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -23,7 +23,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Iterable, Optional, Set, Type, Union +from typing import Iterable, Optional, Set, Type, Union from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator @@ -121,7 +121,7 @@ class RoleContext(BaseModel): class Role(SerializationMixin, ContextMixin, BaseModel): """Role/Agent""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore") + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") name: str = "" profile: str = "" @@ -149,16 +149,20 @@ class Role(SerializationMixin, ContextMixin, BaseModel): __hash__ = object.__hash__ # support Role as hashable type in `Environment.members` - def __init__(self, **data: Any): + @model_validator(mode="after") + def validate_role_extra(self): + self._process_role_extra(**(self.model_extra or {})) + return self + + def _process_role_extra(self, **kwargs): self.pydantic_rebuild_model() - super().__init__(**data) if self.is_human: self.llm = HumanProvider(None) self._check_actions() self.llm.system_prompt = self._get_prefix() - self._watch(data.get("watch") or [UserRequirement]) + self._watch(kwargs.get("watch") or [UserRequirement]) if self.latest_observed_msg: self.recovered = True From 3125f4c0c799837cbab655a58e86f52ceb5fcb9f Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Fri, 2 Feb 2024 16:35:51 +0800 Subject: [PATCH 2/2] remove extra value after model_validator in Role/ContextMixin --- metagpt/context_mixin.py | 11 ++++++----- metagpt/roles/role.py | 7 ++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py index cf0604606..59daa692f 100644 --- a/metagpt/context_mixin.py +++ b/metagpt/context_mixin.py @@ -34,14 +34,15 @@ class ContextMixin(BaseModel): @model_validator(mode="after") def validate_context_mixin_extra(self): - self._process_context_mixin_extra(**(self.model_extra or {})) + self._process_context_mixin_extra() return self - def _process_context_mixin_extra(self, **kwargs): + def _process_context_mixin_extra(self): """Process the extra field""" - self.set_context(kwargs.get("context")) - self.set_config(kwargs.get("config")) - self.set_llm(kwargs.get("llm")) + kwargs = self.model_extra or {} + self.set_context(kwargs.pop("context", None)) + self.set_config(kwargs.pop("config", None)) + self.set_llm(kwargs.pop("llm", None)) def set(self, k, v, override=False): """Set attribute""" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 20cd4da99..c098f95af 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -151,18 +151,19 @@ class Role(SerializationMixin, ContextMixin, BaseModel): @model_validator(mode="after") def validate_role_extra(self): - self._process_role_extra(**(self.model_extra or {})) + self._process_role_extra() return self - def _process_role_extra(self, **kwargs): + def _process_role_extra(self): self.pydantic_rebuild_model() + kwargs = self.model_extra or {} if self.is_human: self.llm = HumanProvider(None) self._check_actions() self.llm.system_prompt = self._get_prefix() - self._watch(kwargs.get("watch") or [UserRequirement]) + self._watch(kwargs.pop("watch", [UserRequirement])) if self.latest_observed_msg: self.recovered = True