update path

This commit is contained in:
stellahsr 2024-03-26 15:28:37 +08:00
parent e88b0fdf16
commit 50f4953ea7
17 changed files with 17 additions and 13 deletions

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from pathlib import Path
import numpy as np
from datasets import load_dataset, load_from_disk
from benchmark.swe_bench.inference.const import SCIKIT_LEARN_IDS
def load_oracle_dataset(dataset_name_or_path: str = "", split: str = "test", existing_ids: list = None):
    """Load a SWE-bench "oracle" dataset split, sorted by prompt length and filtered.

    Args:
        dataset_name_or_path: Local dataset directory (loaded via ``load_from_disk``)
            or a HuggingFace dataset name (loaded via ``load_dataset``).
        split: Split to select when the loaded object is a DatasetDict.
        existing_ids: Instance ids already processed; matching rows are dropped.
            Defaults to None (no filtering).

    Returns:
        The selected split, sorted ascending by ``len(text)`` and, when
        ``SCIKIT_LEARN_IDS`` is non-empty, restricted to that subset.

    Raises:
        ValueError: If ``split`` is not present in the loaded dataset.
    """
    # Fix: avoid the mutable default argument (a shared list across calls).
    existing_ids = [] if existing_ids is None else existing_ids
    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
    if split not in dataset:
        raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
    dataset = dataset[split]
    # Sort instances shortest-first so cheaper prompts are processed before long ones.
    lens = np.array(list(map(len, dataset["text"])))
    dataset = dataset.select(np.argsort(lens))
    if len(existing_ids) > 0:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in existing_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )
    if len(SCIKIT_LEARN_IDS) > 0:
        dataset = dataset.filter(
            lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
            desc="Filtering out subset_instance_ids",
            load_from_cache_file=False,
        )
    return dataset

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from typing import Literal, Union
from metagpt.actions.di.ask_review import ReviewConst
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.schema import Message
class GitAgent(DataInterpreter):
    """A DataInterpreter role that resolves git issues by generating patch files."""

    name: str = "Jacky"
    profile: str = "Solve git issues proficiently"
    auto_run: bool = True
    use_plan: bool = True
    use_reflection: bool = False
    react_mode: Literal["plan_and_act", "react"] = "react"
    # Paths of the script files the generated patch is expected to modify.
    script_names: Union[str, list[str]] = []
    instance_id: str = ""

    async def critique(self, result, review_format):
        """Ask the LLM to judge `result` against `review_format`.

        Returns:
            The raw LLM answer; the prompt asks for a True/False verdict.
        """
        review_result = (
            "Finally, return a boolean value (True or False) to indicate the result of the review. "
            "Note: If the result is good enough, return True; otherwise, return False."
        )
        status = await self.llm.aask(
            [
                Message(content=review_format, role="user"),
                Message(content=result, role="assistant"),
                Message(content=review_result, role="user"),
            ]
        )
        logger.info(status)
        return status

    async def review_patch(self, code):
        """Review generated `code` against each script in `self.script_names`.

        Returns:
            tuple[str, str]: newline-joined per-script review results, and the
            last review prompt issued (the bare format when no scripts exist).
        """
        review_format = (
            "Please ensure that the code {code} and original script {original_script} can fix the issue {memory} in patch format. "
            "If it is not in patch format, please convert it to patch format."
        )
        results = []
        # Fix: `review_prompt` was unbound (UnboundLocalError at the return
        # statement) whenever `script_names` was empty.
        review_prompt = review_format
        for script in self.script_names:
            with open(script, "r", encoding="utf-8") as fp:
                original_script = fp.read()
            memory = self.get_memories()[0].content
            review_prompt = review_format.format(code=code, original_script=original_script, memory=memory)
            # todo: extract issue and remove image urls
            result = await self.llm.aask(review_prompt)
            results.append(result)
        # fixme: merge results to a single patch file
        result = "\n".join(results)
        return result, review_prompt

    async def _write_and_exec_code(self, max_retry: int = 3):
        """Write code, review it as a patch, and execute it, retrying up to `max_retry` times."""
        counter = 0
        success = False
        # plan info
        plan_status = self.planner.get_plan_status() if self.use_plan else ""
        # tool info
        if self.tools:
            context = (
                self.working_memory.get()[-1].content if self.working_memory.get() else ""
            )  # thoughts from _think stage in 'react' mode
            plan = self.planner.plan if self.use_plan else None
            tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
        else:
            tool_info = ""
        while not success and counter < max_retry:
            ### write code ###
            code, cause_by = await self._write_code(counter, plan_status, tool_info)
            self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
            result, format_prompt = await self.review_patch(code)
            success = await self.critique(result, format_prompt)
            await self.execute_code.run(code)
            ### execute code ###
            # todo: execute: git apply
            ### process execution result ###
            counter += 1
            if not success and counter >= max_retry:
                logger.info("coding failed!")
                review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
                if ReviewConst.CHANGE_WORDS[0] in review:
                    counter = 0  # redo the task again with help of human suggestions
        return code, result, success

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import pandas as pd
from metagpt.const import METAGPT_ROOT
# Paths to the curated SWE-bench subset files bundled with the repository.
SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
# NOTE(review): "SKLERARN" looks like a typo for "SKLEARN"; renaming would break importers.
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"
# Root directory holding the checked-out task repositories.
TESTBED = METAGPT_ROOT / "benchmark" / "swe-bench" / "data" / "repos"
# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
# This collection represents a subset specifically related to scikit-learn content.
SCIKIT_LEARN_IDS = [
    "scikit-learn__scikit-learn-11578",
    "scikit-learn__scikit-learn-10297",
    "scikit-learn__scikit-learn-25747",
    "scikit-learn__scikit-learn-15512",
    "scikit-learn__scikit-learn-15119",
    "scikit-learn__scikit-learn-10870",
    "scikit-learn__scikit-learn-15100",
    "scikit-learn__scikit-learn-14496",
    "scikit-learn__scikit-learn-14890",
    "scikit-learn__scikit-learn-10428",
    "scikit-learn__scikit-learn-25744",
    "scikit-learn__scikit-learn-11542",
    "scikit-learn__scikit-learn-10198",
    "scikit-learn__scikit-learn-10459",
]
# Matplotlib subset of instance identifiers; commented-out entries are deliberately excluded.
MATPLOTLIB_IDS = [
    "matplotlib__matplotlib-24362",
    "matplotlib__matplotlib-20584",
    "matplotlib__matplotlib-23188",
    "matplotlib__matplotlib-24403",
    # 'matplotlib__matplotlib-21443',
    # 'matplotlib__matplotlib-23047'
]
def read_subset_instance(path=SUBSET_DATASET, tag="scikit-learn"):
    """Read the subset-instance CSV and return instance ids containing `tag`.

    Args:
        path: CSV file with `instance_id_pass` and `instance_id_fail` columns.
        tag: Substring used to filter instance ids (e.g. a repo prefix).

    Returns:
        list[str]: Matching pass ids followed by matching fail ids; [] on error.
    """
    try:
        # Fix: the file is a .csv, so read it with read_csv (read_excel fails on CSV input).
        df = pd.read_csv(path)
        pass_filters = df["instance_id_pass"].dropna().tolist()
        fail_filters = df["instance_id_fail"].dropna().tolist()
        # Guard against non-string cells (uneven columns are padded with NaN).
        pass_filters = [s for s in pass_filters if isinstance(s, str) and tag in s]
        fail_filters = [s for s in fail_filters if isinstance(s, str) and tag in s]
        return pass_filters + fail_filters
    except FileNotFoundError:
        print(f"File not found: {path}")
        return []
    except Exception as e:
        print(f"An error occurred: {e}")
        return []


if __name__ == "__main__":
    print(read_subset_instance(tag="matplotlib__matplotlib"))

View file

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import runpy
import sys
# Preserve the original argv so it can be restored after the nested run.
original_argv = sys.argv.copy()
try:
    # Set the command-line arguments to pass to the script.
    dataset_path = "SWE-bench_oracle"  # "SWE-bench_bm25_27K" # "SWE-bench_13k"
    sys.argv = ["run_api.py", "--dataset_name_or_path", f"princeton-nlp/{dataset_path}", "--output_dir", "./outputs"]
    # Execute run_api.py as if it were launched as the main module.
    runpy.run_path(path_name="run_api.py", run_name="__main__")
finally:
    # Restore the original sys.argv to avoid side effects on subsequent code.
    sys.argv = original_argv

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from tenacity import retry, stop_after_attempt, wait_random_exponential
from benchmark.swe_bench.gitagent import GitAgent
from benchmark.swe_bench.make_datasets.make_dataset import reset_task_env
from benchmark.swe_bench.utils.utils import extract_scripts_from_codetext
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.recovery_util import save_history
# Example unified-diff layout the agent must follow when emitting patches.
PATCH_FORMAT = """
```diff
--- original_file.py
+++ modified_file.py
@@ -line_number,context_lines +line_number,context_lines @@
- original line of code to be replaced or removed
+ new line of code to be added or to replace the original
```
"""
def _prepare(inputs):
requirement = "Please rewrite the code to address the issues. "
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
cleaned_user_message = re.sub("<patch>.*?</patch>", "", user_message, flags=re.DOTALL)
issues = re.findall("<issue>(.*?)</issue>", user_message, flags=re.DOTALL)
return requirement, system_messages, cleaned_user_message, issues
def construct_prompt(inputs, script_names):
    """Build the instruction prompt and split `inputs` via `_prepare`.

    Returns:
        tuple: (requirement, system_messages, cleaned_user_message, issues, prompt).
    """
    segments = [
        f"You only need to modify the code file listed here {script_names}.",
        "Notice: ",
        "1. Analysis the issue, especially for the ValueError, and identify influence code lines.\n",
        "2. Only change a few lines, and make sure I can use git diff and git apply to resolve the issue .\n",
        "3. I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.\n",
        f"4. use the format as : {PATCH_FORMAT}",
    ]
    prompt = "".join(segments)
    requirement, system_messages, cleaned_user_message, issues = _prepare(inputs)
    return requirement, system_messages, cleaned_user_message, issues, prompt
@handle_exception(exception_type=Exception)
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
async def run_agent(inputs, agent, **kwargs):
    """Run `agent` on a prepared SWE-bench prompt and return its last generated cell.

    Args:
        inputs: Raw instance text (first line is the system message).
        agent: A GitAgent instance to execute the task.
        **kwargs: Supports `script_names`, the files the patch may modify.

    Returns:
        The source of the agent's last generated cell.
    """
    script_names = kwargs.get("script_names", [])
    requirement, system_messages, cleaned_user_message, issues, prompt = construct_prompt(inputs, script_names)
    # NOTE(review): this strips EVERY space from the messages, not just
    # leading/trailing whitespace — presumably to shrink token count; confirm intent.
    system_messages = system_messages.replace(" ", "")
    cleaned_user_message = cleaned_user_message.replace(" ", "")
    await agent.run([requirement, system_messages, cleaned_user_message, prompt])
    return agent.get_last_cell_source()
async def run_instance(instance, use_reflection=True):
    """Run a GitAgent against one SWE-bench instance and return its final output.

    Args:
        instance: Dataset row; must provide at least the "text" field.
        use_reflection: Passed through to the agent's reflection behaviour.

    Returns:
        The agent's last generated cell source, or None if the task repository
        could not be prepared.
    """
    ga = GitAgent(use_reflection=use_reflection)
    script_names = extract_scripts_from_codetext(instance["text"])
    ga.script_names = script_names
    # Reset the task repository to its base commit before the agent runs.
    patch, repo, repo_path = reset_task_env(instance)
    if repo_path is None:
        return
    response = await run_agent(f"{instance['text']}\n\n", agent=ga, script_names=script_names)
    logger.info(f"Final response: {response}")
    save_history(ga)
    return response

View file

@ -0,0 +1,112 @@
import json
from pathlib import Path
import fire
from tqdm.auto import tqdm
from benchmark.swe_bench.data.load_dataset import load_oracle_dataset
from benchmark.swe_bench.inference.run_agent import run_instance
from benchmark.swe_bench.utils.utils import check_existing_ids, extract_diff
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils import count_string_tokens
# Replace with your own
# Maximum prompt size (in tokens); instances with longer text are filtered out.
MAX_TOKEN = 128000
async def openai_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    existing_ids,
    use_reflection,
):
    """
    Runs inference on a dataset using the openai API.

    Args:
        test_dataset (datasets.Dataset): The dataset to run inference on.
        model_name_or_path (str): The name or path of the model to use.
        output_file (str): The path to the output file.
        existing_ids (set): A set of ids that have already been processed.
        use_reflection (bool): Whether the agent uses reflection while coding.
    """
    # Drop instances whose prompt would exceed the model context window.
    test_dataset = test_dataset.filter(
        lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
        desc="Filtering",
        load_from_cache_file=False,
    )
    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    logger.info(f"Filtered to {len(test_dataset)} instances")
    # Append mode so an interrupted run can resume without losing prior results.
    with open(output_file, "a+") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue
            # Fix: removed the dead `data` list that was appended to but never read.
            repo_prefix = datum["repo"].replace("/", "__")
            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            output_dict["text"] = f"{datum['text']}\n\n"
            logger.info(f"{repo_prefix}_{datum['version']}")
            response = await run_instance(instance=datum, use_reflection=use_reflection)
            if response is None:
                continue
            logger.info(f"Final response: {response}")
            output_dict["full_output"] = response
            output_dict["model_patch"] = extract_diff(response)
            # Flush each record immediately so partial progress survives a crash.
            print(json.dumps(output_dict), file=f, flush=True)
async def main(
    dataset_name_or_path,
    split="test",
    model_name_or_path=config.llm.model,
    output_dir="outputs",
    use_reflection=True,
):
    """
    Performs inference on SWE-bench dataset using the Data Interpreter.

    Args:
        dataset_name_or_path: HuggingFace dataset name or local path
        split: Dataset split to use (default: test)
        model_name_or_path: Name of the model to use (default: config.llm.model)
        output_dir: Path to the output directory (default: outputs)
        use_reflection: Whether the agent reflects on failed attempts (default: True)
    """
    model_nickname = Path(model_name_or_path).name if isinstance(model_name_or_path, Path) else model_name_or_path
    output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
    output_file = Path(output_dir, output_file + ".jsonl")
    print(output_file.absolute())
    output_file.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Will write to {output_file}")
    # check existing results
    existing_ids = check_existing_ids(output_file)
    # load dataset
    # Fix: forward `split` — it was previously accepted but ignored, so any
    # non-default split silently fell back to "test".
    dataset = load_oracle_dataset(dataset_name_or_path, split=split)
    inference_args = {
        "test_dataset": dataset,
        "model_name_or_path": model_name_or_path,
        "output_file": output_file,
        "existing_ids": existing_ids,
        "use_reflection": use_reflection,
    }
    if model_name_or_path.startswith("gpt"):
        await openai_inference(**inference_args)
    else:
        raise ValueError(f"Invalid model name or path {model_name_or_path}")
    logger.info("Done!")


if __name__ == "__main__":
    fire.Fire(main)

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import os
from pathlib import Path
from tqdm.auto import tqdm
from benchmark.swe_bench.inference.const import TESTBED
from benchmark.swe_bench.make_datasets.make_instance import prompt_style_2_edits_only
from benchmark.swe_bench.utils.parse_diff import filter_changed_line
from benchmark.swe_bench.utils.repo_utils import EnvManager
from metagpt.logs import logger
def reset_task_env(instance: dict = None):
    """Reset a task repository to the instance's base commit.

    Args:
        instance: SWE-bench row with "patch", "repo" and "base_commit" keys.

    Returns:
        tuple: (patch, repo, repo_path); repo_path is None when the checkout
        directory is missing or the git reset fails.
    """
    # Fix: avoid the mutable default argument; also removed the no-op
    # `instance["version"]` expression that had no effect.
    instance = instance or {}
    # reset the env via git reset and git checkout
    env_manager = EnvManager(testbed=TESTBED)
    patch = instance["patch"]
    repo = instance["repo"]
    repo_prefix = repo.replace("/", "__")
    repo_path = os.path.join(env_manager.testbed, repo_prefix)
    if not os.path.exists(repo_path):
        return patch, repo, None
    # NOTE(review): changes the process-wide working directory as a side effect.
    os.chdir(repo_path)
    if not env_manager.reset_task_env(instance=instance):
        return patch, repo, None
    return patch, repo, repo_path
def reset_and_copy(instance: dict = None):
    """Reset the task repository, then copy it to a version-suffixed destination.

    Args:
        instance: SWE-bench row; must include "repo" and "version".
    """
    # Fix: avoid the mutable default argument.
    instance = instance or {}
    patch, repo, repo_path = reset_task_env(instance)
    if repo_path is None:
        return
    env_manager = EnvManager(testbed=TESTBED)
    repo_prefix = repo.replace("/", "__")
    version = instance["version"]
    destination_path = os.path.join(repo_path, f"{repo_prefix}__{version}")
    env_manager.copy_repo(source_path=repo_path, destination_path=destination_path)
def make_oracle_collapsed_instance(instance):
    """Build an edits-only ("collapsed oracle") prompt for a single instance.

    Returns:
        An empty dict placeholder (output saving is still TODO), or None when
        the task repository is unavailable.
    """
    # for each instance, reset task env
    patch, repo, repo_path = reset_task_env(instance)
    if repo_path is None:
        return
    # Keep only the files actually touched by the gold patch.
    file_changes = filter_changed_line(patch)
    prompt = prompt_style_2_edits_only(instance, Path(repo_path), list(file_changes.keys()))
    logger.info(prompt)
    # todo: save output
    return {}
def make_oracle_collapsed_dataset(dataset):
    """Build collapsed-oracle prompts for every instance in `dataset`."""
    for datum in tqdm(dataset, desc="Inference "):
        make_oracle_collapsed_instance(instance=datum)
    # todo: save output

View file

@ -0,0 +1,193 @@
from pathlib import Path
import unidiff
# Example diff shown to the model to illustrate the expected patch format.
# NOTE(review): inner indentation of these literals appears flattened by the
# commit-page scrape; confirm against the original file before relying on it.
PATCH_EXAMPLE = """--- a/file.py
+++ b/file.py
@@ -1,27 +1,35 @@
def euclidean(a, b):
- while b:
- a, b = b, a % b
- return a
+ if b == 0:
+ return a
+ return euclidean(b, a % b)
def bresenham(x0, y0, x1, y1):
points = []
dx = abs(x1 - x0)
dy = abs(y1 - y0)
- sx = 1 if x0 < x1 else -1
- sy = 1 if y0 < y1 else -1
- err = dx - dy
+ x, y = x0, y0
+ sx = -1 if x0 > x1 else 1
+ sy = -1 if y0 > y1 else 1
- while True:
- points.append((x0, y0))
- if x0 == x1 and y0 == y1:
- break
- e2 = 2 * err
- if e2 > -dy:
+ if dx > dy:
+ err = dx / 2.0
+ while x != x1:
+ points.append((x, y))
err -= dy
- x0 += sx
- if e2 < dx:
- err += dx
- y0 += sy
+ if err < 0:
+ y += sy
+ err += dx
+ x += sx
+ else:
+ err = dy / 2.0
+ while y != y1:
+ points.append((x, y))
+ err -= dx
+ if err < 0:
+ x += sx
+ err += dy
+ y += sy
+ points.append((x, y))
return points"""
# Example of a "full generation" answer (complete files rather than a diff).
FULL_GENERATION_EXAMPLE = """[start of /src/this_file.py]
import os
def euclidean(a, b):
if b == 0:
return a
return euclidean(b, a % b)
[end of /src/this_file.py]
[start of /src/another_file.py]
def bresenham(x0, y0, x1, y1):
points = []
dx = abs(x1 - x0)
dy = abs(y1 - y0)
x, y = x0, y0
sx = -1 if x0 > x1 else 1
sy = -1 if y0 > y1 else 1
if dx > dy:
err = dx / 2.0
while x != x1:
points.append((x, y))
err -= dy
if err < 0:
y += sy
err += dx
x += sx
else:
err = dy / 2.0
while y != y1:
points.append((x
err -= dx
if err < 0:
x += sx
err += dy
y += sy
points.append((x, y))
return points
[end of /src/another_file.py]"""
def add_lines_list(content):
    """Return `content` split on newlines, each line prefixed with its 1-based number."""
    return [f"{number} {text}" for number, text in enumerate(content.split("\n"), start=1)]
def add_lines(content):
    """Return `content` with every line prefixed by its 1-based line number."""
    numbered = add_lines_list(content)
    return "\n".join(numbered)
def make_code_text(files_dict, add_line_numbers=True):
    """Concatenate file contents wrapped in [start of <name>]/[end of <name>] markers.

    Args:
        files_dict: Mapping of filename -> file contents.
        add_line_numbers: Prefix each content line with its 1-based number.

    Returns:
        str: The combined text, sorted by filename, without surrounding blank lines.
    """
    all_text = ""
    for filename, contents in sorted(files_dict.items()):
        # Fix: the markers must embed the actual filename — it was lost in the
        # source (a literal "(unknown)" placeholder) and the loop variable went unused.
        all_text += f"[start of {filename}]\n"
        if add_line_numbers:
            all_text += add_lines(contents)
        else:
            all_text += contents
        all_text += f"\n[end of {filename}]\n"
    return all_text.strip("\n")
def make_code_text_edits_only(files_dict, patch, root_path, add_line_numbers=True):
    """Render only the patched regions (with ~15 lines of context) of each file.

    Args:
        files_dict: Mapping of resolved file path -> file contents.
        patch: Unified-diff text describing the gold edits.
        root_path: Repository root used to resolve patched file paths.
        add_line_numbers: Unused; line numbers are always added here.

    Returns:
        str: Excerpted, line-numbered file contents with "..." elision markers.
    """
    files = dict()
    patch = unidiff.PatchSet(patch)
    for patched_file in patch:
        source_file = root_path / patched_file.source_file.split("a/", 1)[-1]
        files[source_file] = list()
        for hunk in patched_file:
            # Show 15 lines of context around each hunk; clamp so a hunk near
            # the top of the file cannot yield a negative (wrap-around) slice start.
            start = max(hunk.source_start - 15, 0)
            end = start + hunk.source_length + 15
            files[source_file].append((start, end))
    all_text = ""
    for filename, content in files_dict.items():
        # Fix: embed the real filename in the markers (was a literal "(unknown)" placeholder).
        all_text += f"[start of {filename}]\n"
        content_with_lines = add_lines_list(content)
        for start, end in files[filename]:
            if start > 0:
                all_text += "...\n"
            all_text += "\n".join(content_with_lines[start:end])
            all_text += "\n"
            if end < len(content_with_lines):
                all_text += "...\n"
        all_text = all_text.strip("\n")
        all_text += f"\n[end of {filename}]\n"
    return all_text.strip("\n")
def prompt_style_2_edits_only(instance, root_path, filenames):
    """Compose an edits-only prompt: issue, READMEs, patched code regions, and instructions.

    Args:
        instance: SWE-bench row; mutated in place (adds "readmes" and "file_contents").
        root_path: Path to the checked-out repository root.
        filenames: Files (relative to root_path) whose edited regions to include.

    Returns:
        str: The full prompt text.
    """
    premise = "You will be provided with a partial code base and an issue statement explaining a problem to resolve."
    readmes = get_readme_files(root_path)
    instance["readmes"] = ingest_files([root_path / readme for readme in readmes])
    readmes_text = make_code_text(instance["readmes"])
    instance["file_contents"] = ingest_files([root_path / filename for filename in filenames])
    code_text = make_code_text_edits_only(instance["file_contents"], instance["patch"], root_path)
    instructions = (
        "I need you to solve this issue by generating a single patch file that I can apply "
        + "directly to this repository using git apply. Please respond with a single patch "
        + "file in the following format."
    )
    problem_statement = instance["problem_statement"]
    final_text = [
        premise,
        "<issue>",
        problem_statement,
        "</issue>",
        "<code>",
        readmes_text,
        code_text,
        "</code>",
        instructions,
        "<patch>",
        PATCH_EXAMPLE,
        "</patch>",
    ]
    final_text = "\n".join(final_text)
    return final_text
def ingest_files(file_paths):
    """Read each path as UTF-8 text; return a dict mapping path -> contents."""
    return {file_path: Path.read_text(file_path, encoding="utf-8") for file_path in file_paths}
def get_readme_files(repo_path):
    """Return names of regular files in `repo_path` whose name starts with 'readme'."""
    found = []
    # Match file names case-insensitively ("README", "readme.md", "Readme.rst", ...).
    for entry in Path(repo_path).iterdir():
        if entry.is_file() and entry.name.lower().startswith("readme"):
            found.append(entry.name)
    return found

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,12 @@
from enum import Enum, auto
class FileChangeMode(Enum):
    # How a file appears in a git diff: newly created, removed, or modified in place.
    create = auto()
    delete = auto()
    change = auto()
class LineChangeType(Enum):
    # Direction of a single changed line within a diff hunk ("+" vs "-").
    addition = auto()
    deletion = auto()

View file

@ -0,0 +1,116 @@
import re
from typing import Dict, List
from swe_bench.utils.enums import FileChangeMode, LineChangeType
from metagpt.logs import logger
def extract_changes_from_patch(diff: str) -> List[Dict[str, any]]:
    """Parses the patch text through the standard syntax of git diff, outputs the information of added and deleted lines.

    Extracts detailed information about file changes based on the output content of git diff.

    Args:
        diff: A string containing the output of git diff.

    Returns:
        A list of dictionaries containing information about each file change.
    """
    changes = []
    current_file = None
    file_pattern = re.compile(r"^diff --git a/(.+) b/(.+)$")
    line_change_pattern = re.compile(r"^@@ -(\d+),\d+ \+(\d+),\d+ @@.*$")
    new_file_flag_line = "--- /dev/null"
    deleted_file_flag_line = "+++ /dev/null"
    for line in diff.splitlines():
        file_section_start = file_pattern.match(line)
        if file_section_start:
            # A new "diff --git" header: flush the previous file's record, if any.
            if current_file:
                changes.append(current_file)
            file_a, file_b = file_section_start.groups()
            current_file = start_new_file_section(file_a, file_b)
            current_file["mode"] = FileChangeMode.change
        elif current_file:
            # A "--- /dev/null" source marks the current file as newly created.
            if line == new_file_flag_line:
                current_file["mode"] = FileChangeMode.create
            # A "+++ /dev/null" target marks the current file as deleted.
            elif line == deleted_file_flag_line:
                current_file["mode"] = FileChangeMode.delete
            update_file_changes(current_file, line, line_change_pattern)
    # Flush the record of the final file in the diff.
    if current_file:
        changes.append(current_file)
    return changes
def start_new_file_section(file_before_change: str, file_after_change: str) -> Dict[str, any]:
    """Create a fresh change-record dict for one file in a diff.

    Args:
        file_before_change: The file name before the change.
        file_after_change: The file name after the change ("/dev/null" for deletions).

    Returns:
        A dict holding both names and an empty ``changes`` list.
    """
    record = {
        "file_before_change": file_before_change,
        "file_after_change": file_after_change,
    }
    record["changes"] = []
    return record
def update_file_changes(current_file: Dict[str, any], line: str, line_change_pattern: re.Pattern):
    """Updates the current file change information

    Updates the current file's change record based on a line from the diff.

    Args:
        current_file: The current file information being processed
        line: The current line from the diff
        line_change_pattern: The regex pattern used to identify line changes
    """
    line_change_match = line_change_pattern.match(line)
    if line_change_match:
        # Hunk header: reset both counters to the hunk's start line numbers.
        current_file["base_line"], current_file["changed_line"] = map(int, line_change_match.groups())
    elif line.startswith("+"):
        current_file["changes"].append(
            {"type": LineChangeType.addition, "line": current_file.get("changed_line", 1), "content": line[1:]}
        )
        current_file["changed_line"] = current_file.get("changed_line", 0) + 1
    elif line.startswith("-"):
        current_file["changes"].append(
            {"type": LineChangeType.deletion, "line": current_file.get("base_line", 1), "content": line[1:]}
        )
        current_file["base_line"] = current_file.get("base_line", 0) + 1
    # NOTE(review): unchanged context lines advance neither counter, so recorded
    # line numbers drift within hunks containing context — confirm whether
    # downstream consumers rely on exact line numbers.
def filter_changed_line(patch):
    """Filters changed lines

    Filters the part of the change record of the current file that needs to be used.

    Args:
        patch: The git diff text

    Returns:
        dict: file name (before change) -> list of its deletion records;
        additions and newly created files are skipped.
    """
    parsed_changes = extract_changes_from_patch(patch)
    res = {}
    for change in parsed_changes:
        file_name = change["file_before_change"]
        res[file_name] = []
        # Skip newly created files: they have no pre-change lines to report.
        if change["mode"] is FileChangeMode.create:
            continue
        for c in change["changes"]:
            if c["type"] is LineChangeType.addition:
                continue
            logger.debug(f" {c['type']} at line {c['line']}: {c['content']}")
            res[file_name].append(c)
    return res

View file

@ -0,0 +1,93 @@
import os
import shutil
import subprocess
from pathlib import Path
from typing import Dict
import git
from git.exc import GitError
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
# Dataset key identifying a task instance.
KEY_INSTANCE_ID = "instance_id"
# Sentinel message for a failed repository reset (not referenced in this chunk).
RESET_FAILED = ">>>>> Reset Failed"
class ExecWrapper:
    """Thin wrapper around subprocess.run that applies preset keyword arguments.

    Attributes:
        subprocess_args: Default kwargs merged into every subprocess.run call.
    """

    def __init__(self, subprocess_args: Dict = None):
        self.subprocess_args = subprocess_args or {}

    @handle_exception(exception_type=subprocess.CalledProcessError)
    def __call__(self, cmd, raise_error=True, **kwargs):
        """Run `cmd` via subprocess.run with the preset arguments.

        Args:
            cmd: Command argument list.
            raise_error: When False, disable `check` so a non-zero exit status
                does not raise CalledProcessError. (Fix: this flag was
                previously accepted but ignored.)
            **kwargs: Per-call overrides merged over `subprocess_args`.

        Returns:
            The CompletedProcess from subprocess.run.
        """
        combined_args = {**self.subprocess_args, **kwargs}
        if not raise_error:
            combined_args["check"] = False
        output = subprocess.run(cmd, **combined_args)
        return output
class EnvManager:
    """Manages task repositories under a testbed directory: cloning, copying, and git resets."""

    def __init__(self, testbed):
        # Inherit the caller's environment so git picks up credentials/config.
        shellenv = os.environ.copy()
        self.testbed = testbed
        self.exec = ExecWrapper(
            subprocess_args={
                "check": True,
                "shell": False,
                "capture_output": True,
                "text": True,
                "env": shellenv,
            }
        )

    @handle_exception(exception_type=GitError)
    def clone_repo(self, repo_name: str, path: str, token: str = None):
        """Clone the swe-bench mirror of `repo_name` into `path`.

        Args:
            repo_name: "owner/repo" name; mirrored as "owner__repo" on GitHub.
            path: Destination directory (created if missing).
            token: GitHub token; falls back to the GITHUB_TOKEN env var.
        """
        if token is None:
            token = os.environ.get("GITHUB_TOKEN", "git")
        if not token:
            raise ValueError("GitHub token is required for cloning repositories.")
        repo_url = f"https://{token}@github.com/swe-bench/{repo_name.replace('/', '__')}.git"
        os.makedirs(path, exist_ok=True)
        # Clone the repository
        git.Repo.clone_from(repo_url, path)
        logger.info(f"Repository '{repo_name}' cloned successfully.")

    @handle_exception(exception_type=Exception)  # Using a broad exception type for the example
    def copy_repo(self, source_path: str, destination_path: str):
        """Copy the repository tree at `source_path` into `destination_path`."""
        if not os.path.isdir(source_path):
            raise ValueError("Source path does not exist or is not a directory.")
        os.makedirs(destination_path, exist_ok=True)
        # Copy the repository
        try:
            shutil.copytree(
                source_path, destination_path, dirs_exist_ok=True
            )  # For Python 3.8+, dirs_exist_ok handles existing directories
        except TypeError:
            # Fallback for Python < 3.8, where dirs_exist_ok is not available
            if os.listdir(destination_path):  # If destination is not empty
                raise ValueError("Destination directory is not empty and dirs_exist_ok is not supported.")
            shutil.copytree(source_path, destination_path)
        logger.info(f"Repository contents from '{source_path}' copied successfully to '{destination_path}'.")

    @handle_exception(exception_type=Exception, default_return=False)
    def reset_task_env(self, instance: Dict):
        """
        Reset task environment + testbed and checkout base commit of given task instance
        """
        # Assumes the current working directory is the task repo (callers chdir first).
        gitignore_path = Path(".gitignore")
        if gitignore_path.exists():
            # List ignored files; the result is currently unused because the
            # removal step below is disabled.
            self.exec(["git", "ls-files", "--ignored", "--exclude-standard", "-o", "-z"], raise_error=False)
            # fixme: need detect platform and change this cmd
            # self.exec(["xargs", "-0", "-r", "rm", "-rf"], input=gitignore_path.read_text())
        self.exec(["git", "restore", "."])
        self.exec(["git", "reset", "HEAD", "."])
        self.exec(["git", "clean", "-fdx"])
        self.exec(["git", "-c", "advice.detachedHead=false", "checkout", instance["base_commit"]])
        logger.info(f"[{instance['instance_id']}] Reset task environment to {instance['base_commit']}")
        return True

View file

@ -0,0 +1,81 @@
import json
import os
import re
from metagpt.logs import logger
def check_existing_ids(output_file):
    """Collect the instance_ids already written to the JSONL `output_file`.

    Returns:
        set: Ids found in the file; empty when the file does not exist yet.
    """
    existing_ids = set()
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            for line in f:
                record = json.loads(line)
                existing_ids.add(record["instance_id"])
    logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")
    return existing_ids
def extract_diff(response):
    """
    Extracts the diff from a response formatted in different ways
    """
    if response is None:
        return None
    diff_matches = []
    other_matches = []
    # Check XML-style tags first, then fenced code blocks, preserving that priority.
    tag_pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
    fence_pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
    for pattern in (tag_pattern, fence_pattern):
        for code, match in pattern.findall(response):
            if code in {"diff", "patch"}:
                diff_matches.append(match)
            else:
                other_matches.append(match)
    if diff_matches:
        return diff_matches[0]
    if other_matches:
        return other_matches[0]
    # No tagged content at all: return everything before an end-of-sequence token.
    return response.split("</s>")[0]
def extract_scripts_from_codetext(codetext: str):
    """Extract Python script file names from bracketed section markers.

    Sections look like '[start of <script_name>.py]' ... '[end of <script_name>.py]'.

    Parameters:
    - codetext (str): Text possibly containing multiple '[start of ...]' markers.

    Returns:
    - list: The extracted .py file names, in order of appearance.
    """
    # Match every occurrence of '[start of <script_name>.py]'.
    matches = re.findall(r"\[start of ([^\]]+\.py)\]", codetext)
    if not matches:
        print("No script names found in the text.")
        return []
    script_names = []
    for script_name in matches:
        print("Extracted script name:", script_name)
        script_names.append(script_name)
    return script_names