From 6505087c788eaa4ea53f9680f83472aea83c4fdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Sun, 28 Apr 2024 22:01:40 +0800 Subject: [PATCH] refactor: optimizing act_by_order mode --- metagpt/roles/product_manager.py | 15 +++----------- metagpt/roles/role.py | 23 ++++++++------------- tests/metagpt/roles/test_product_manager.py | 5 ----- 3 files changed, 12 insertions(+), 31 deletions(-) diff --git a/metagpt/roles/product_manager.py b/metagpt/roles/product_manager.py index fbe139a99..9db9f7d9e 100644 --- a/metagpt/roles/product_manager.py +++ b/metagpt/roles/product_manager.py @@ -9,7 +9,7 @@ from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.prepare_documents import PrepareDocuments -from metagpt.roles.role import Role +from metagpt.roles.role import Role, RoleReactMode from metagpt.utils.common import any_to_name @@ -35,17 +35,8 @@ class ProductManager(Role): self.set_actions([PrepareDocuments, WritePRD]) self._watch([UserRequirement, PrepareDocuments]) - self.todo_action = any_to_name(PrepareDocuments) - - async def _think(self) -> bool: - """Decide what to do""" - if self.git_repo and not self.config.git_reinit: - self._set_state(1) - else: - self._set_state(0) - self.config.git_reinit = False - self.todo_action = any_to_name(WritePRD) - return bool(self.rc.todo) + self.rc.react_mode = RoleReactMode.BY_ORDER + self.todo_action = any_to_name(WritePRD) async def _observe(self, ignore_memory=False) -> int: return await super()._observe(ignore_memory=True) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index e0f8a7ea6..35d8423e5 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -365,6 +365,12 @@ class Role(SerializationMixin, ContextMixin, BaseModel): self.recovered = False # avoid max_react_loop out of work return True + if self.rc.react_mode == RoleReactMode.BY_ORDER: + if self.rc.max_react_loop != len(self.actions): + self.rc.max_react_loop = len(self.actions) + self._set_state(self.rc.state + 1) + return self.rc.state >= 0 and self.rc.state < len(self.actions) + prompt = self._get_prefix() prompt += STATE_TEMPLATE.format( history=self.rc.history, @@ -455,8 +461,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): rsp = Message(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act while actions_taken < self.rc.max_react_loop: # think - await self._think() - if self.rc.todo is None: + todo = await self._think() + if not todo: break # act logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}") @@ -464,15 +470,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel): actions_taken += 1 return rsp # return output from the last action - async def _act_by_order(self) -> Message: - """switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ...""" - start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state - rsp = Message(content="No actions taken yet") # return default message if actions=[] - for i in range(start_idx, len(self.states)): - self._set_state(i) - rsp = await self._act() - return rsp # return output from the last action - async def _plan_and_act(self) -> Message: """first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically.""" @@ -513,10 +510,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): async def react(self) -> Message: """Entry to one of three strategies by which Role reacts to the observed Message""" - if self.rc.react_mode == RoleReactMode.REACT: + if self.rc.react_mode == RoleReactMode.REACT or self.rc.react_mode == RoleReactMode.BY_ORDER: rsp = await self._react() - elif self.rc.react_mode == RoleReactMode.BY_ORDER: - rsp = await self._act_by_order() elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT: rsp = await self._plan_and_act() else: diff --git a/tests/metagpt/roles/test_product_manager.py b/tests/metagpt/roles/test_product_manager.py index 59b5aa81a..143eef2f2 100644 --- a/tests/metagpt/roles/test_product_manager.py +++ b/tests/metagpt/roles/test_product_manager.py @@ -10,7 +10,6 @@ import json import pytest from metagpt.actions import WritePRD -from metagpt.actions.prepare_documents import PrepareDocuments from metagpt.const import REQUIREMENT_FILENAME from metagpt.context import Context from metagpt.logs import logger @@ -30,11 +29,7 @@ async def test_product_manager(new_filename): rsp = await product_manager.run(MockMessages.req) assert context.git_repo assert context.repo - assert rsp.cause_by == any_to_str(PrepareDocuments) assert REQUIREMENT_FILENAME in context.repo.docs.changed_files - - # write prd - rsp = await product_manager.run(rsp) assert rsp.cause_by == any_to_str(WritePRD) logger.info(rsp) assert len(rsp.content) > 0