Merge pull request #1234 from iorisa/refactor/role/act_by_order

refactor: optimizing act_by_order mode of `Role`
This commit is contained in:
Alexander Wu 2024-05-17 11:00:57 +08:00 committed by GitHub
commit 6b70f7b0ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 12 additions and 31 deletions

View file

@ -9,7 +9,7 @@
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.roles.role import Role
from metagpt.roles.role import Role, RoleReactMode
from metagpt.utils.common import any_to_name
@ -35,17 +35,8 @@ class ProductManager(Role):
self.set_actions([PrepareDocuments, WritePRD])
self._watch([UserRequirement, PrepareDocuments])
self.todo_action = any_to_name(PrepareDocuments)
async def _think(self) -> bool:
"""Decide what to do"""
if self.git_repo and not self.config.git_reinit:
self._set_state(1)
else:
self._set_state(0)
self.config.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self.rc.todo)
self.rc.react_mode = RoleReactMode.BY_ORDER
self.todo_action = any_to_name(WritePRD)
async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)

View file

@ -370,6 +370,12 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
self.recovered = False # avoid max_react_loop out of work
return True
if self.rc.react_mode == RoleReactMode.BY_ORDER:
if self.rc.max_react_loop != len(self.actions):
self.rc.max_react_loop = len(self.actions)
self._set_state(self.rc.state + 1)
return self.rc.state >= 0 and self.rc.state < len(self.actions)
prompt = self._get_prefix()
prompt += STATE_TEMPLATE.format(
history=self.rc.history,
@ -460,8 +466,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
rsp = Message(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act
while actions_taken < self.rc.max_react_loop:
# think
await self._think()
if self.rc.todo is None:
todo = await self._think()
if not todo:
break
# act
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
@ -469,15 +475,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
actions_taken += 1
return rsp # return output from the last action
async def _act_by_order(self) -> Message:
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state
rsp = Message(content="No actions taken yet") # return default message if actions=[]
for i in range(start_idx, len(self.states)):
self._set_state(i)
rsp = await self._act()
return rsp # return output from the last action
async def _plan_and_act(self) -> Message:
"""first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically."""
@ -518,10 +515,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
async def react(self) -> Message:
"""Entry to one of three strategies by which Role reacts to the observed Message"""
if self.rc.react_mode == RoleReactMode.REACT:
if self.rc.react_mode == RoleReactMode.REACT or self.rc.react_mode == RoleReactMode.BY_ORDER:
rsp = await self._react()
elif self.rc.react_mode == RoleReactMode.BY_ORDER:
rsp = await self._act_by_order()
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
rsp = await self._plan_and_act()
else:

View file

@ -10,7 +10,6 @@ import json
import pytest
from metagpt.actions import WritePRD
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.context import Context
from metagpt.logs import logger
@ -30,11 +29,7 @@ async def test_product_manager(new_filename):
rsp = await product_manager.run(MockMessages.req)
assert context.git_repo
assert context.repo
assert rsp.cause_by == any_to_str(PrepareDocuments)
assert REQUIREMENT_FILENAME in context.repo.docs.changed_files
# write prd
rsp = await product_manager.run(rsp)
assert rsp.cause_by == any_to_str(WritePRD)
logger.info(rsp)
assert len(rsp.content) > 0