mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
fix actions/roles ser&deser
This commit is contained in:
parent
caacfcff7a
commit
c70c8358d3
4 changed files with 30 additions and 25 deletions
|
|
@ -117,23 +117,21 @@ class SearchAndSummarize(Action):
|
|||
|
||||
@root_validator
|
||||
def validate_engine_and_run_func(cls, values):
|
||||
engine = values.get('engine')
|
||||
search_func = values.get('search_func')
|
||||
engine = values.get("engine")
|
||||
search_func = values.get("search_func")
|
||||
config = Config()
|
||||
|
||||
if engine is None:
|
||||
engine = config.search_engine
|
||||
config_data = {
|
||||
'engine': engine,
|
||||
'run_func': search_func
|
||||
}
|
||||
search_engine = SearchEngine(**config_data)
|
||||
try:
|
||||
search_engine = SearchEngine(engine=engine, run_func=search_func)
|
||||
except pydantic.ValidationError:
|
||||
search_engine = None
|
||||
|
||||
values['search_engine'] = search_engine
|
||||
values["search_engine"] = search_engine
|
||||
return values
|
||||
|
||||
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
|
||||
print(context)
|
||||
if self.search_engine is None:
|
||||
logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature")
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -226,17 +226,14 @@ class WritePRD(Action):
|
|||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
assistant_search_action: Action = None
|
||||
|
||||
async def run(self, requirements, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput:
|
||||
# self.assistant_search_action = SearchAndSummarize()
|
||||
if self.assistant_search_action is None:
|
||||
self.assistant_search_action = SearchAndSummarize()
|
||||
# self.assistant_search_action = SearchAndSummarize()
|
||||
rsp = await self.assistant_search_action.run(context=requirements)
|
||||
info = f"### Search Results\n{self.assistant_search_action.result}\n\n### Search Summary\n{rsp}"
|
||||
if self.assistant_search_action.result:
|
||||
logger.info(self.assistant_search_action.result)
|
||||
sas = SearchAndSummarize()
|
||||
# rsp = await sas.run(context=requirements, system_text=SEARCH_AND_SUMMARIZE_SYSTEM_EN_US)
|
||||
rsp = ""
|
||||
info = f"### Search Results\n{sas.result}\n\n### Search Summary\n{rsp}"
|
||||
if sas.result:
|
||||
logger.info(sas.result)
|
||||
logger.info(rsp)
|
||||
|
||||
prompt_template, format_example = get_template(templates, format)
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class RoleSetting(BaseModel):
|
|||
|
||||
class RoleContext(BaseModel):
|
||||
"""Role Runtime Context"""
|
||||
env: "Environment" = Field(default=None)
|
||||
env: "Environment" = Field(default=None, exclude=True)
|
||||
memory: Memory = Field(default_factory=Memory)
|
||||
long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
|
||||
state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None
|
||||
|
|
@ -133,7 +133,7 @@ class Role(BaseModel):
|
|||
_role_id: str = ""
|
||||
_states: list[str] = Field(default=[])
|
||||
_actions: list[Action] = Field(default=[])
|
||||
_rc: RoleContext = RoleContext()
|
||||
_rc: RoleContext = Field(default=RoleContext, exclude=True)
|
||||
|
||||
# builtin variables
|
||||
recovered: bool = False # to tag if a recovered role
|
||||
|
|
@ -143,7 +143,8 @@ class Role(BaseModel):
|
|||
"_llm": LLM() if not is_human else HumanProvider(),
|
||||
"_role_id": _role_id,
|
||||
"_states": [],
|
||||
"_actions": []
|
||||
"_actions": [],
|
||||
"_rc": RoleContext()
|
||||
}
|
||||
|
||||
class Config:
|
||||
|
|
@ -169,6 +170,8 @@ class Role(BaseModel):
|
|||
self._private_attributes["_setting"] = RoleSetting(name=self.name, profile=self.profile, goal=self.goal,
|
||||
desc=self.desc, constraints=self.constraints,
|
||||
is_human=self.is_human)
|
||||
self._private_attributes["_role_id"] = str(self._setting)
|
||||
|
||||
for key in self._private_attributes.keys():
|
||||
if key in kwargs:
|
||||
object.__setattr__(self, key, kwargs[key])
|
||||
|
|
@ -176,10 +179,15 @@ class Role(BaseModel):
|
|||
setting = RoleSetting(**kwargs[key])
|
||||
object.__setattr__(self, "_setting", setting)
|
||||
elif key == "_rc":
|
||||
_rc = RoleContext
|
||||
_rc = RoleContext()
|
||||
object.__setattr__(self, "_rc", _rc)
|
||||
else:
|
||||
object.__setattr__(self, key, self._private_attributes[key])
|
||||
if key == "_rc":
|
||||
# # Warning, if use self._private_attributes["_rc"],
|
||||
# # self._rc will be a shared object between roles, so init one or reset it inside `_reset`
|
||||
object.__setattr__(self, key, RoleContext())
|
||||
else:
|
||||
object.__setattr__(self, key, self._private_attributes[key])
|
||||
|
||||
# deserialize child classes dynamically for inherited `role`
|
||||
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
|
||||
|
|
@ -192,6 +200,7 @@ class Role(BaseModel):
|
|||
def _reset(self):
|
||||
object.__setattr__(self, "_states", [])
|
||||
object.__setattr__(self, "_actions", [])
|
||||
# object.__setattr__(self, "_rc", RoleContext())
|
||||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \
|
||||
|
|
@ -289,7 +298,6 @@ class Role(BaseModel):
|
|||
for idx, action in enumerate(actions):
|
||||
if not isinstance(action, Action):
|
||||
## 默认初始化
|
||||
# import pdb; pdb.set_trace()
|
||||
i = action(name="", llm=self._llm)
|
||||
else:
|
||||
if self._setting.is_human and not isinstance(action.llm, HumanProvider):
|
||||
|
|
|
|||
|
|
@ -51,7 +51,9 @@ def format_trackback_info(limit: int = 2):
|
|||
def serialize_decorator(func):
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return await func(self, *args, **kwargs)
|
||||
result = await func(self, *args, **kwargs)
|
||||
self.serialize() # Team.serialize
|
||||
return result
|
||||
except KeyboardInterrupt as kbi:
|
||||
logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}")
|
||||
self.serialize() # Team.serialize
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue