add: SWE Agent

This commit is contained in:
seeker 2024-06-28 21:05:46 +08:00
parent c0c2b5b218
commit 9b11ac5c34
17 changed files with 1575 additions and 21 deletions

View file

@ -4,14 +4,18 @@ import inspect
import json
import re
import traceback
from typing import Callable, Literal, Tuple
from typing import Callable, Dict, List, Literal, Tuple, Union
from pydantic import model_validator
from metagpt.actions import Action
from metagpt.actions.di.run_command import RunCommand
from metagpt.logs import logger
from metagpt.prompts.di.role_zero import CMD_PROMPT, ROLE_INSTRUCTION, JSON_REPAIR_PROMPT
from metagpt.prompts.di.role_zero import (
CMD_PROMPT,
JSON_REPAIR_PROMPT,
ROLE_INSTRUCTION,
)
from metagpt.roles import Role
from metagpt.schema import AIMessage, Message, UserMessage
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
@ -21,8 +25,8 @@ from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
from metagpt.utils.report import ThoughtReporter
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType
@register_tool(include_functions=["ask_human", "reply_to_human"])
@ -163,25 +167,15 @@ class RoleZero(Role):
if self.use_fixed_sop:
return await super()._act()
try:
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError as e:
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except Exception as e:
tb = traceback.format_exc()
print(tb)
error_msg = UserMessage(content=str(e))
self.rc.memory.add(error_msg)
commands, ok = await self._get_commands()
if not ok:
error_msg = commands
return error_msg
# 为了对LLM不按格式生成进行容错
if isinstance(commands, dict):
commands = commands["commands"] if "commands" in commands else [commands]
logger.info(f"Commands: \n{commands}")
outputs = await self._run_commands(commands)
logger.info(f"Commands outputs: \n{outputs}")
self.rc.memory.add(UserMessage(content=outputs))
return AIMessage(
content=f"Complete run with outputs: {outputs}",
sent_from=self.name,
@ -208,6 +202,36 @@ class RoleZero(Role):
actions_taken += 1
return rsp # return output from the last action
async def _get_commands(self) -> Tuple[Union[UserMessage, List[Dict]], bool]:
"""Retrieves commands from the Large Language Model (LLM).
This function attempts to retrieve a list of commands from the LLM by
processing the response (`self.command_rsp`). It handles potential errors
during parsing and LLM response formats.
Returns:
A tuple containing:
- A `UserMessage` object or dict representing the commands.
- A boolean flag indicating success (True) or failure (False).
"""
try:
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError:
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except Exception as e:
tb = traceback.format_exc()
logger.debug(tb)
error_msg = UserMessage(content=str(e))
self.rc.memory.add(error_msg)
return error_msg, False
# 为了对LLM不按格式生成进行容错
if isinstance(commands, dict):
commands = commands["commands"] if "commands" in commands else [commands]
return commands, True
async def _run_commands(self, commands) -> str:
outputs = []
for cmd in commands:

74
metagpt/roles/di/swe.py Normal file
View file

@ -0,0 +1,74 @@
import json
import os
from pydantic import Field
from metagpt.logs import logger
from metagpt.prompts.di.swe import (
MINIMAL_EXAMPLE,
NEXT_STEP_TEMPLATE,
SWE_AGENT_SYSTEM_TEMPLATE,
)
from metagpt.roles.di.role_zero import RoleZero
from metagpt.tools.libs.terminal import Bash
from metagpt.tools.swe_agent_commands.swe_agent_utils import extract_patch
class SWE(RoleZero):
name: str = "SweAgent"
profile: str = "Software Engineer"
goal: str = "Resolve GitHub issue"
_bash_window_size: int = 100
_system_msg: str = SWE_AGENT_SYSTEM_TEMPLATE
system_msg: list[str] = [_system_msg.format(WINDOW=_bash_window_size)]
_instruction: str = NEXT_STEP_TEMPLATE
tools: list[str] = ["Bash"]
terminal: Bash = Field(default_factory=Bash, exclude=True)
output_diff: str = ""
max_react_loop: int = 30
async def _think(self) -> bool:
self._set_system_msg()
self._format_instruction()
res = await super()._think()
await self._handle_action()
return res
def _set_system_msg(self):
if os.getenv("WINDOW"):
self._bash_window_size = int(os.getenv("WINDOW"))
self.system_msg = [self._system_msg.format(WINDOW=self._bash_window_size)]
def _format_instruction(self):
state_output = self.terminal.run("state")
bash_state = json.loads(state_output)
self.instruction = self._instruction.format(
WINDOW=self._bash_window_size, examples=MINIMAL_EXAMPLE, **bash_state
).strip()
return self.instruction
async def _handle_action(self):
commands, ok = await self._get_commands()
if not ok:
return
for cmd in commands:
if "submit" not in cmd.get("args", {}).get("cmd", ""):
return
try:
# Generate patch by git diff
diff_output = self.terminal.run("git diff")
clear_diff = extract_patch(diff_output)
logger.info(f"Diff output: \n{clear_diff}")
if clear_diff:
self.output_diff = clear_diff
except Exception as e:
logger.error(f"Error during submission: {e}")
def _update_tool_execution(self):
self.tool_execution_map.update({"Bash.run": self.terminal.run})
def _retrieve_experience(self) -> str:
return MINIMAL_EXAMPLE