fix actions/roles ser&deser

This commit is contained in:
better629 2023-11-30 19:31:26 +08:00
parent caacfcff7a
commit c70c8358d3
4 changed files with 30 additions and 25 deletions

View file

@ -117,23 +117,21 @@ class SearchAndSummarize(Action):
@root_validator
def validate_engine_and_run_func(cls, values):
engine = values.get('engine')
search_func = values.get('search_func')
engine = values.get("engine")
search_func = values.get("search_func")
config = Config()
if engine is None:
engine = config.search_engine
config_data = {
'engine': engine,
'run_func': search_func
}
search_engine = SearchEngine(**config_data)
try:
search_engine = SearchEngine(engine=engine, run_func=search_func)
except pydantic.ValidationError:
search_engine = None
values['search_engine'] = search_engine
values["search_engine"] = search_engine
return values
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
print(context)
if self.search_engine is None:
logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature")
return ""

View file

@ -226,17 +226,14 @@ class WritePRD(Action):
name: str = ""
content: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
assistant_search_action: Action = None
async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput:
# self.assistant_search_action = SearchAndSummarize()
if self.assistant_search_action is None:
self.assistant_search_action = SearchAndSummarize()
# self.assistant_search_action = SearchAndSummarize()
rsp = await self.assistant_search_action.run(context=requirements)
info = f"### Search Results\n{self.assistant_search_action.result}\n\n### Search Summary\n{rsp}"
if self.assistant_search_action.result:
logger.info(self.assistant_search_action.result)
sas = SearchAndSummarize()
# rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
rsp = ""
info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}"
if sas.result:
logger.info(sas.result)
logger.info(rsp)
prompt_template, format_example = get_template(templates, format)

View file

@ -88,7 +88,7 @@ class RoleSetting(BaseModel):
class RoleContext(BaseModel):
"""Role Runtime Context"""
env: "Environment" = Field(default=None)
env: "Environment" = Field(default=None, exclude=True)
memory: Memory = Field(default_factory=Memory)
long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None
@ -133,7 +133,7 @@ class Role(BaseModel):
_role_id: str = ""
_states: list[str] = Field(default=[])
_actions: list[Action] = Field(default=[])
_rc: RoleContext = RoleContext()
_rc: RoleContext = Field(default=RoleContext, exclude=True)
# builtin variables
recovered: bool = False # to tag if a recovered role
@ -143,7 +143,8 @@ class Role(BaseModel):
"_llm": LLM() if not is_human else HumanProvider(),
"_role_id": _role_id,
"_states": [],
"_actions": []
"_actions": [],
"_rc": RoleContext()
}
class Config:
@ -169,6 +170,8 @@ class Role(BaseModel):
self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal,
desc=self.desc, constraints=self.constraints,
is_human=self.is_human)
self._private_attributes["_role_id"] = str(self._setting)
for key in self._private_attributes.keys():
if key in kwargs:
object.__setattr__(self, key, kwargs[key])
@ -176,10 +179,15 @@ class Role(BaseModel):
setting = RoleSetting(**kwargs[key])
object.__setattr__(self, "_setting", setting)
elif key == "_rc":
_rc = RoleContext
_rc = RoleContext()
object.__setattr__(self, "_rc", _rc)
else:
object.__setattr__(self, key, self._private_attributes[key])
if key == "_rc":
# # Warning, if use self._private_attributes["_rc"],
# # self._rc will be a shared object between roles, so init one or reset it inside `_reset`
object.__setattr__(self, key, RoleContext())
else:
object.__setattr__(self, key, self._private_attributes[key])
# deserialize child classes dynamically for inherited `role`
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
@ -192,6 +200,7 @@ class Role(BaseModel):
def _reset(self):
object.__setattr__(self, "_states", [])
object.__setattr__(self, "_actions", [])
# object.__setattr__(self, "_rc", RoleContext())
def serialize(self, stg_path: Path = None):
stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \
@ -289,7 +298,6 @@ class Role(BaseModel):
for idx, action in enumerate(actions):
if not isinstance(action, Action):
## 默认初始化
# import pdb; pdb.set_trace()
i = action(name="", llm=self._llm)
else:
if self._setting.is_human and not isinstance(action.llm, HumanProvider):

View file

@ -51,7 +51,9 @@ def format_trackback_info(limit: int = 2):
def serialize_decorator(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
result = await func(self, *args, **kwargs)
self.serialize() # Team.serialize
return result
except KeyboardInterrupt as kbi:
logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}")
self.serialize() # Team.serialize