Merge pull request #827 from shenchucheng/fix-context-mixin-ut-error

fix ContextMixin ut error
This commit is contained in:
geekan 2024-02-02 17:10:39 +08:00 committed by GitHub
commit 8b4db5f374
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 17 additions and 16 deletions

View file

@ -33,20 +33,16 @@ class ContextMixin(BaseModel):
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
@model_validator(mode="after")
def validate_extra(self):
self._process_extra(**(self.model_extra or {}))
def validate_context_mixin_extra(self):
self._process_context_mixin_extra()
return self
def _process_extra(
self,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
):
def _process_context_mixin_extra(self):
"""Process the extra field"""
self.set_context(context)
self.set_config(config)
self.set_llm(llm)
kwargs = self.model_extra or {}
self.set_context(kwargs.pop("context", None))
self.set_config(kwargs.pop("config", None))
self.set_llm(kwargs.pop("llm", None))
def set(self, k, v, override=False):
"""Set attribute"""

View file

@ -23,7 +23,7 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Iterable, Optional, Set, Type, Union
from typing import Iterable, Optional, Set, Type, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -121,7 +121,7 @@ class RoleContext(BaseModel):
class Role(SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
name: str = ""
profile: str = ""
@ -149,16 +149,21 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
def __init__(self, **data: Any):
@model_validator(mode="after")
def validate_role_extra(self):
self._process_role_extra()
return self
def _process_role_extra(self):
self.pydantic_rebuild_model()
super().__init__(**data)
kwargs = self.model_extra or {}
if self.is_human:
self.llm = HumanProvider(None)
self._check_actions()
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
self._watch(kwargs.pop("watch", [UserRequirement]))
if self.latest_observed_msg:
self.recovered = True