feat: add action.set_context(self.context) to Role._init_action

This commit is contained in:
莘权 马 2024-04-29 19:52:52 +08:00
parent 3fbc85c1b3
commit 2ed47e3eb2
3 changed files with 43 additions and 37 deletions

View file

@ -15,7 +15,7 @@
of SummarizeCode.
"""
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions import DebugError, RunCode, UserRequirement, WriteTest
from metagpt.actions.prepare_documents import PrepareDocuments
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.const import MESSAGE_ROUTE_TO_NONE
@ -49,26 +49,16 @@ class QaEngineer(Role):
# will overwrite _think() in future updates
self.set_actions(
[
PrepareDocuments(
send_to=any_to_str(self),
key_descriptions={
"project_path": 'the project path if exists in "Original Requirement"',
"reqa_file": 'the path of the source code file explicitly requested for unit test if exists in "Original Requirement"',
},
context=self.context,
),
WriteTest,
]
)
self._watch([PrepareDocuments, SummarizeCode, WriteTest, RunCode, DebugError])
self._watch([UserRequirement, PrepareDocuments, SummarizeCode, WriteTest, RunCode, DebugError])
self.test_round = 0
async def _write_test(self, message: Message) -> None:
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
changed_files = set(src_file_repo.changed_files.keys())
# Unit tests only.
if self.config.reqa_file and self.config.reqa_file not in changed_files:
changed_files.add(self.config.reqa_file)
reqa_file = self.context.kwargs.reqa_file or self.config.reqa_file
changed_files = {reqa_file} if reqa_file else set(src_file_repo.changed_files.keys())
for filename in changed_files:
# write tests
if not filename or "test" in filename:
@ -157,7 +147,8 @@ class QaEngineer(Role):
)
async def _act(self) -> Message:
await init_python_folder(self.project_repo.tests.workdir)
if self.project_path:
await init_python_folder(self.project_repo.tests.workdir)
if self.test_round > self.test_round_allowed:
result_msg = AIMessage(
content=f"Exceeding {self.test_round_allowed} rounds of tests, skip (writing code counts as a round, too)",
@ -182,6 +173,8 @@ class QaEngineer(Role):
elif msg.cause_by in run_filters:
# I ran my test code, time to fix bugs, if any
await self._debug_error(msg)
elif msg.cause_by == any_to_str(UserRequirement):
return await self._parse_user_requirement(msg)
self.test_round += 1
return AIMessage(
content=f"Round {self.test_round} of tests done",
@ -189,3 +182,17 @@ class QaEngineer(Role):
sent_from=self.profile,
send_to=MESSAGE_ROUTE_TO_NONE,
)
async def _parse_user_requirement(self, msg: Message) -> AIMessage:
action = PrepareDocuments(
send_to=any_to_str(self),
key_descriptions={
"project_path": 'the project path if exists in "Original Requirement"',
"reqa_file": 'the file name to rewrite unit test if exists in "Original Requirement"',
},
context=self.context,
)
rsp = await action.run([msg])
if not self.src_workspace:
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
return rsp

View file

@ -255,10 +255,9 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
return self
def _init_action(self, action: Action):
if not action.private_config:
action.set_llm(self.llm, override=True)
else:
action.set_llm(self.llm, override=False)
action.set_context(self.context)
override = not action.private_config
action.set_llm(self.llm, override=override)
action.set_prefix(self._get_prefix())
def set_action(self, action: Action):

View file

@ -88,27 +88,27 @@ class MockEnv(Environment):
@pytest.mark.parametrize(
("content", "send_to"),
[
# ("snake game", any_to_str(ProductManager)),
# (
# "Rewrite the PRD file of the project at '/Users/iorishinier/github/MetaGPT/workspace/snake_game', add 'moving enemy' to the original requirement",
# any_to_str(ProductManager),
# ),
# (
# "Add 'random moving enemy, and dispears after 10 seconds' design to the project at '/Users/iorishinier/github/MetaGPT/workspace/snake_game'",
# any_to_str(Architect),
# ),
# (
# 'Rewrite the tasks file of the project at "/Users/iorishinier/github/MetaGPT/workspace/snake_game"',
# any_to_str(ProjectManager),
# ),
("snake game", any_to_str(ProductManager)),
(
"Rewrite the PRD file of the project at '/Users/iorishinier/github/MetaGPT/workspace/snake_game', add 'moving enemy' to the original requirement",
any_to_str(ProductManager),
),
(
"Add 'random moving enemy, and dispears after 10 seconds' design to the project at '/Users/iorishinier/github/MetaGPT/workspace/snake_game'",
any_to_str(Architect),
),
(
'Rewrite the tasks file of the project at "/Users/iorishinier/github/MetaGPT/workspace/snake_game"',
any_to_str(ProjectManager),
),
(
"Rewrite 'main.py' of the project at '/Users/iorishinier/github/MetaGPT/workspace/snake_game'",
any_to_str(Engineer),
),
# (
# "Rewrite the unit test of 'test_main.py' at '/Users/iorishinier/github/MetaGPT/workspace/snake_game'",
# any_to_str(QaEngineer),
# ),
(
"Rewrite the unit test of 'main.py' at '/Users/iorishinier/github/MetaGPT/workspace/snake_game'",
any_to_str(QaEngineer),
),
],
)
async def test_env(content, send_to):
@ -120,7 +120,7 @@ async def test_env(content, send_to):
Architect(context=context),
ProjectManager(context=context),
Engineer(n_borg=5, use_code_review=True, context=context),
QaEngineer(context=context),
QaEngineer(context=context, test_round_allowed=2),
]
)
msg = UserMessage(content=content, send_to=send_to)