Merge branch 'geekan/dev' into feature/rebuild

This commit is contained in:
莘权 马 2024-02-02 17:20:14 +08:00
commit 0e864dc9da
4 changed files with 43 additions and 42 deletions

View file

@@ -33,20 +33,16 @@ class ContextMixin(BaseModel):
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
@model_validator(mode="after")
def validate_extra(self):
self._process_extra(**(self.model_extra or {}))
def validate_context_mixin_extra(self):
self._process_context_mixin_extra()
return self
def _process_extra(
self,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
):
def _process_context_mixin_extra(self):
"""Process the extra field"""
self.set_context(context)
self.set_config(config)
self.set_llm(llm)
kwargs = self.model_extra or {}
self.set_context(kwargs.pop("context", None))
self.set_config(kwargs.pop("config", None))
self.set_llm(kwargs.pop("llm", None))
def set(self, k, v, override=False):
"""Set attribute"""

View file

@@ -23,7 +23,7 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Iterable, Optional, Set, Type, Union
from typing import Iterable, Optional, Set, Type, Union
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
@@ -121,7 +121,7 @@ class RoleContext(BaseModel):
class Role(SerializationMixin, ContextMixin, BaseModel):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
name: str = ""
profile: str = ""
@@ -149,16 +149,21 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
def __init__(self, **data: Any):
@model_validator(mode="after")
def validate_role_extra(self):
self._process_role_extra()
return self
def _process_role_extra(self):
self.pydantic_rebuild_model()
super().__init__(**data)
kwargs = self.model_extra or {}
if self.is_human:
self.llm = HumanProvider(None)
self._check_actions()
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
self._watch(kwargs.pop("watch", [UserRequirement]))
if self.latest_observed_msg:
self.recovered = True