fix ContextMixin ut error

This commit is contained in:
shenchucheng 2024-02-02 16:12:47 +08:00
parent 0118712ff8
commit f6824b078c
2 changed files with 15 additions and 16 deletions

View file

@ -33,20 +33,15 @@ class ContextMixin(BaseModel):
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
@model_validator(mode="after")
def validate_extra(self):
self._process_extra(**(self.model_extra or {}))
def validate_context_mixin_extra(self):
self._process_context_mixin_extra(**(self.model_extra or {}))
return self
def _process_extra(
self,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
):
def _process_context_mixin_extra(self, **kwargs):
"""Process the extra field"""
self.set_context(context)
self.set_config(config)
self.set_llm(llm)
self.set_context(kwargs.get("context"))
self.set_config(kwargs.get("config"))
self.set_llm(kwargs.get("llm"))
def set(self, k, v, override=False):
"""Set attribute"""

View file

@ -23,7 +23,7 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Iterable, Optional, Set, Type, Union
from typing import Iterable, Optional, Set, Type, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@ -121,7 +121,7 @@ class RoleContext(BaseModel):
class Role(SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
name: str = ""
profile: str = ""
@ -149,16 +149,20 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
def __init__(self, **data: Any):
@model_validator(mode="after")
def validate_role_extra(self):
self._process_role_extra(**(self.model_extra or {}))
return self
def _process_role_extra(self, **kwargs):
self.pydantic_rebuild_model()
super().__init__(**data)
if self.is_human:
self.llm = HumanProvider(None)
self._check_actions()
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
self._watch(kwargs.get("watch") or [UserRequirement])
if self.latest_observed_msg:
self.recovered = True