Merge branch 'reply_to_human_when_finish_current_task' into 'mgx_ops'

优化:当智能体结束执行时,添加回复检测,要求智能体至少报告一次结果。调整swe-agent测试代码

See merge request pub/MetaGPT!326
This commit is contained in:
林义章 2024-08-20 12:59:19 +00:00
commit 31dbd55d76
4 changed files with 71 additions and 28 deletions

View file

@ -26,7 +26,9 @@ from metagpt.prompts.di.role_zero import (
QUICK_THINK_PROMPT,
QUICK_THINK_SYSTEM_PROMPT,
REGENERATE_PROMPT,
REPORT_TO_HUMAN_PROMPT,
ROLE_INSTRUCTION,
SUMMARY_PROMPT,
SYSTEM_PROMPT,
THOUGHT_GUIDANCE,
)
@ -85,6 +87,7 @@ class RoleZero(Role):
memory_k: int = 20 # number of memories (messages) to use as historical context
use_fixed_sop: bool = False
requirements_constraints: str = "" # the constraints in user requirements
use_summary: bool = True # whether to summarize at the end
@model_validator(mode="after")
def set_plan_and_tool(self) -> "RoleZero":
@ -234,9 +237,10 @@ class RoleZero(Role):
if self.use_fixed_sop:
return await super()._act()
commands, ok = await self._parse_commands()
commands, ok = await self._parse_commands(self.command_rsp)
if not ok:
error_msg = commands
self.rc.memory.add(UserMessage(content=error_msg))
return error_msg
logger.info(f"Commands: \n{commands}")
outputs = await self._run_commands(commands)
@ -339,7 +343,7 @@ class RoleZero(Role):
command_rsp = await self.llm.aask(regenerate_req)
return command_rsp
async def _parse_commands(self) -> Tuple[List[Dict], bool]:
async def _parse_commands(self, command_rsp) -> Tuple[List[Dict], bool]:
"""Retrieves commands from the Large Language Model (LLM).
This function attempts to retrieve a list of commands from the LLM by
@ -351,20 +355,20 @@ class RoleZero(Role):
- A boolean flag indicating success (True) or failure (False).
"""
try:
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
commands = CodeParser.parse_code(block=None, lang="json", text=command_rsp)
if commands.endswith("]") and not commands.startswith("["):
commands = "[" + commands
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON for: {self.command_rsp}. Trying to repair...")
logger.warning(f"Failed to parse JSON for: {command_rsp}. Trying to repair...")
commands = await self.llm.aask(
msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp, json_decode_error=str(e))
msg=JSON_REPAIR_PROMPT.format(json_data=command_rsp, json_decode_error=str(e))
)
try:
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except json.JSONDecodeError:
# repair escape error of code and math
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
commands = CodeParser.parse_code(block=None, lang="json", text=command_rsp)
new_command = repair_escape_error(commands)
commands = json.loads(
repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON)
@ -372,8 +376,7 @@ class RoleZero(Role):
except Exception as e:
tb = traceback.format_exc()
print(tb)
error_msg = UserMessage(content=str(e))
self.rc.memory.add(error_msg)
error_msg = str(e)
return error_msg, False
# 为了对LLM不按格式生成进行容错
@ -426,8 +429,7 @@ class RoleZero(Role):
command_output = "Current task is finished. If all tasks are finished, use 'end' to stop."
elif cmd["command_name"] == "end":
self._set_state(-1)
command_output = ""
command_output = await self._end()
# output from bash.run may be empty, add decorations to the output to ensure visibility.
elif cmd["command_name"] == "Bash.run":
@ -486,3 +488,23 @@ class RoleZero(Role):
if not isinstance(self.rc.env, MGXEnv):
return "Not in MGXEnv, command will not be executed."
return await self.rc.env.reply_to_human(content, sent_from=self)
async def _end(self):
self._set_state(-1)
memory = self.rc.memory.get(self.memory_k)
# Ensure reply to the human before the "end" command is executed.
if not any(["reply_to_human" in memory.content for memory in self.get_memories(k=5)]):
reply_to_human_prompt = REPORT_TO_HUMAN_PROMPT.format(
requirements_constraints=self.requirements_constraints,
)
reply_content = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(reply_to_human_prompt)]))
await self.reply_to_human(content=reply_content)
self.rc.memory.add(AIMessage(content=reply_content, cause_by=RunCommand))
outputs = ""
# Summary of the Completed Task and Deliverables
if self.use_summary:
summary_prompt = SUMMARY_PROMPT.format(
requirements_constraints=self.requirements_constraints,
)
outputs = await self.llm.aask(self.llm.format_msg(memory + [UserMessage(summary_prompt)]))
return outputs

View file

@ -9,6 +9,7 @@ from metagpt.prompts.di.swe_agent import (
NEXT_STEP_TEMPLATE,
)
from metagpt.roles.di.role_zero import RoleZero
from metagpt.schema import Message
from metagpt.tools.libs.git import git_create_pull
from metagpt.tools.libs.terminal import Bash
@ -32,8 +33,6 @@ class SWEAgent(RoleZero):
async def _think(self) -> bool:
await self._format_instruction()
res = await super()._think()
if self.run_eval:
await self._parse_commands_for_eval()
return res
def _update_tool_execution(self):
@ -55,6 +54,12 @@ class SWEAgent(RoleZero):
bash_state = json.loads(state_output)
self.cmd_prompt_current_state = CURRENT_BASH_STATE.format(**bash_state).strip()
async def _act(self) -> Message:
message = await super()._act()
if self.run_eval:
self._parse_commands_for_eval()
return message
async def _parse_commands_for_eval(self):
"""
Handles actions based on parsed commands.
@ -64,24 +69,19 @@ class SWEAgent(RoleZero):
This function is specifically added for SWE bench evaluation.
"""
# only import when evaluation is needed
from metagpt.tools.swe_agent_commands.swe_agent_utils import extract_patch
# If todo switches to None, it indicates that this is the final round of reactions, and the Swe-Agent will stop. Use git diff to store any changes made.
if not self.rc.todo:
from metagpt.tools.swe_agent_commands.swe_agent_utils import extract_patch
commands, ok = await self._parse_commands()
if not ok:
return
for cmd in commands:
if "end" != cmd.get("command_name", ""):
return
try:
diff_output = await self.terminal.run("git diff --cached")
clear_diff = extract_patch(diff_output)
logger.info(f"Diff output: \n{clear_diff}")
if clear_diff:
self.output_diff = clear_diff
try:
diff_output = await self.terminal.run("git diff --cached")
clear_diff = extract_patch(diff_output)
logger.info(f"Diff output: \n{clear_diff}")
if clear_diff:
self.output_diff = clear_diff
except Exception as e:
logger.error(f"Error during submission: {e}")
except Exception as e:
logger.error(f"Error during submission: {e}")
def _retrieve_experience(self) -> str:
return MINIMAL_EXAMPLE

View file

@ -30,6 +30,8 @@ class TeamLeader(RoleZero):
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = SimpleExpRetriever()
use_summary: bool = False
def _update_tool_execution(self):
self.tool_execution_map.update(
{