Merge pull request #1074 from stellaHSR/swebench_di

Update agent flow
This commit is contained in:
Alexander Wu 2024-03-22 14:40:06 +08:00 committed by GitHub
commit 19ccec7080
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 653 additions and 253 deletions

View file

@ -28,22 +28,24 @@ SCIKIT_LEARN_IDS = [
"scikit-learn__scikit-learn-10459",
]
# Instance ids of the matplotlib tasks in this subset; the two commented-out
# entries at the end are currently excluded.
MATPLOTLIB_IDS = [
    "matplotlib__matplotlib-24362",
    "matplotlib__matplotlib-20584",
    "matplotlib__matplotlib-23188",
    "matplotlib__matplotlib-24403",
    # 'matplotlib__matplotlib-21443',
    # 'matplotlib__matplotlib-23047'
]
def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
def read_subset_instance(path=SUBSET_DATASET, tag="scikit-learn"):
try:
df = pd.read_excel(path)
# Filter for instances containing the tag in either column
pass_filter = df["instance_id_pass"].str.contains(tag, na=False)
fail_filter = df["instance_id_fail"].str.contains(tag, na=False)
# Combine the filters using | (OR operator) for efficiency
combined_filter = pass_filter | fail_filter
# Apply combined filter and select the specific columns
filtered_df = df[combined_filter][["instance_id_pass", "instance_id_fail"]]
# Flatten the DataFrame into a list and remove NaN values
subset_instance = filtered_df.stack().dropna().tolist()
pass_filters = df["instance_id_pass"].tolist()
fail_filters = df["instance_id_fail"].tolist()
pass_filters = [s for s in pass_filters if tag in s]
fail_filters = [s for s in fail_filters if tag in s]
subset_instance = pass_filters + fail_filters
return subset_instance
except FileNotFoundError:
@ -52,3 +54,7 @@ def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
except Exception as e:
print(f"An error occurred: {e}")
return []
if __name__ == "__main__":
print(read_subset_instance(tag="matplotlib__matplotlib"))

View file

@ -1,38 +0,0 @@
import re
def extract_scripts_from_codetext(codetext: str):
    """Collect the ``.py`` paths announced by ``[start of <path>]`` markers.

    Every marker of the form ``[start of <something>.py]`` found in
    ``codetext`` contributes its path, in order of appearance. Markers for
    non-Python files (for example ``[start of README.rst]``) are ignored.
    Each extracted name is also printed; a notice is printed when nothing
    matches.

    Parameters:
    - codetext (str): Text that may contain ``[start of ...]`` sections.

    Returns:
    - list: The extracted script paths (empty when there are none).
    """
    found = re.findall(r"\[start of ([^\]]+\.py)\]", codetext)
    if not found:
        print("No script names found in the text.")
        return []

    script_names = []
    for name in found:
        print("Extracted script name:", name)
        script_names.append(name)
    return script_names

View file

@ -1,28 +0,0 @@
import re
def extract_diff(response):
    """Pull the most likely diff out of a model response.

    Two container styles are scanned, in this order: ``<tag>...</tag>`` pairs
    and triple-backtick fenced blocks. The first block labelled ``diff`` or
    ``patch`` wins; failing that, the first block of any other label; failing
    that, the response text up to the first ``</s>``.

    Returns ``None`` when ``response`` is ``None``.
    """
    if response is None:
        return None

    tag_re = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
    fence_re = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)

    preferred, fallback = [], []
    for regex in (tag_re, fence_re):
        for label, body in regex.findall(response):
            bucket = preferred if label in {"diff", "patch"} else fallback
            bucket.append(body)

    for candidates in (preferred, fallback):
        if candidates:
            return candidates[0]
    return response.split("</s>")[0]

View file

@ -1,173 +0,0 @@
import json
import os
import traceback
from pathlib import Path
import fire
import numpy as np
from datasets import load_dataset, load_from_disk
from make_datasets.utils import extract_diff
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm
from data.inference.const import SCIKIT_LEARN_IDS, TESTBED
from data.inference.make_datasets.parse_diff import extract_scripts_from_codetext
from data.inference.make_datasets.repo_utils import EnvManager
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils import count_string_tokens
from metagpt.utils.recovery_util import save_history
# Largest prompt size (in tokens) kept by the dataset filter below;
# replace with the limit of your own model.
MAX_TOKEN = 128_000
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
async def call_chat(inputs, interpreter):
    """Run the interpreter on one prompt and return the source of its last cell.

    The first line of ``inputs`` is forwarded as the system message. From the
    remainder, only the text before the fixed "single patch file" instruction
    sentence is kept. Any exception raised while the interpreter runs is
    logged (with the inputs), its traceback printed, and then re-raised.

    Args:
        inputs (str): Prompt text; it must contain at least one newline,
            otherwise an ``IndexError`` is raised before the interpreter runs.
        interpreter (DataInterpreter): The data interpreter to use for execution.

    Returns:
        Whatever ``interpreter.get_last_cell_source()`` returns.
    """
    requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code."
    head_and_rest = inputs.split("\n", 1)
    system_messages = head_and_rest[0]
    user_message = head_and_rest[1]

    patch_instruction = "I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply. Please respond with a single patch file in the following format."
    cleaned_user_message = user_message.split(patch_instruction)[0]

    try:
        await interpreter.run([requirement, system_messages, cleaned_user_message])
        return interpreter.get_last_cell_source()
    except Exception as e:
        logger.error(f"Error: {e}\nInputs: {inputs}")
        traceback.print_exc()
        raise e
async def openai_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    existing_ids,
    use_reflection,
):
    """Run the Data Interpreter over a dataset and append results as JSON lines.

    Instances whose text exceeds ``MAX_TOKEN`` tokens are filtered out first.
    For each remaining instance the repository directory under the testbed is
    looked up; instances without one are skipped. The task environment is
    reset before the ``existing_ids`` check, so already-processed instances
    are skipped only after that reset.

    Args:
        test_dataset (datasets.Dataset): The dataset to run inference on.
        model_name_or_path (str): The name or path of the model to use.
        output_file (str): The path to the output file (opened in append mode).
        existing_ids (set): A set of ids that have already been processed.
        use_reflection (bool): Passed to each ``DataInterpreter`` instance.
    """
    test_dataset = test_dataset.filter(
        lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
        desc="Filtering",
        load_from_cache_file=False,
    )
    basic_args = {"model_name_or_path": model_name_or_path}
    print(f"Filtered to {len(test_dataset)} instances")

    with open(output_file, "a+") as out:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            di = DataInterpreter(use_reflection=use_reflection)
            env_manager = EnvManager(testbed=TESTBED)
            instance_id = datum["instance_id"]

            script_names = extract_scripts_from_codetext(datum["text"])
            logger.info(script_names)

            # Repositories are stored as "<owner>__<name>" under the testbed.
            repo_path = os.path.join(env_manager.testbed, datum["repo"].replace("/", "__"))
            if not os.path.exists(repo_path):
                continue
            os.chdir(repo_path)
            env_manager.reset_task_env(instance=datum)

            if instance_id in existing_ids:
                continue

            output_dict = {
                "instance_id": instance_id,
                **basic_args,
                "text": f"{datum['text']}\n\n",
            }
            response = await call_chat(output_dict["text"], di)
            logger.info(f"Final response: {response}")
            save_history(di)

            output_dict["full_output"] = response
            output_dict["model_patch"] = extract_diff(response)
            print(json.dumps(output_dict), file=out, flush=True)
async def main(
    dataset_name_or_path,
    split="test",
    model_name_or_path=config.llm.model,
    output_dir="outputs",
    use_reflection=True,
):
    """Perform inference on a SWE-bench dataset using the Data Interpreter.

    Results are appended to
    ``<output_dir>/<model>__<dataset basename>__<split>.jsonl``; instance ids
    already present in that file are not processed again. The dataset is
    ordered by text length and restricted to ``SCIKIT_LEARN_IDS`` when that
    list is non-empty.

    Args:
        dataset_name_or_path: HuggingFace dataset name or local path.
        split: Dataset split to use (default: test).
        model_name_or_path: Name of the model to use (default: config.llm.model).
        output_dir: Path to the output directory (default: outputs).
        use_reflection: Forwarded to ``openai_inference``.

    Raises:
        ValueError: If ``split`` is not in the dataset, or the model name does
            not start with ``"gpt"``.
    """
    if isinstance(model_name_or_path, Path):
        model_nickname = Path(model_name_or_path).name
    else:
        model_nickname = model_name_or_path
    dataset_tag = dataset_name_or_path.split("/")[-1]
    output_file = Path(output_dir, f"{model_nickname}__{dataset_tag}__{split}" + ".jsonl")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Will write to {output_file}")

    # Collect ids that a previous run already wrote to the output file.
    existing_ids = set()
    if os.path.exists(output_file):
        with open(output_file, "r") as done:
            existing_ids.update(json.loads(line)["instance_id"] for line in done)
    logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")

    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
    if split not in dataset:
        raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
    dataset = dataset[split]

    # Shortest texts first.
    lens = np.array([len(text) for text in dataset["text"]])
    dataset = dataset.select(np.argsort(lens))

    if existing_ids:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in existing_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )
    if SCIKIT_LEARN_IDS:
        dataset = dataset.filter(
            lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
            desc="Filtering out subset_instance_ids",
            load_from_cache_file=False,
        )

    if not model_name_or_path.startswith("gpt"):
        raise ValueError(f"Invalid model name or path {model_name_or_path}")
    await openai_inference(
        test_dataset=dataset,
        model_name_or_path=model_name_or_path,
        output_file=output_file,
        existing_ids=existing_ids,
        use_reflection=use_reflection,
    )
    logger.info("Done!")


if __name__ == "__main__":
    fire.Fire(main)

3
swe_bench/__init__.py Normal file
View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

99
swe_bench/gitagent.py Normal file
View file

@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from typing import Literal, Union
from metagpt.actions.di.ask_review import ReviewConst
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.schema import Message
class GitAgent(DataInterpreter):
    """Data-interpreter role that writes a fix and has the LLM review it as a patch.

    Each attempt writes code, asks the LLM to check it against the original
    scripts and the first memory entry, asks for a True/False verdict on
    that review, and then executes the code.
    """

    name: str = "Jacky"
    profile: str = "Solve git issues proficiently"
    auto_run: bool = True
    use_plan: bool = True
    use_reflection: bool = False
    react_mode: Literal["plan_and_act", "react"] = "react"
    # Paths of the scripts under review; a single path may be given as a plain string.
    script_names: Union[str, list[str]] = []
    instance_id: str = ""

    def _is_approved(self, status) -> bool:
        """Turn the reviewer's free-text answer into a boolean verdict.

        The answer is searched case-insensitively; whichever of "true" /
        "false" is mentioned last wins. An answer mentioning neither counts
        as not approved.
        """
        text = str(status).lower()
        return text.rfind("true") > text.rfind("false")

    async def critique(self, result, review_format):
        """Ask the LLM for a True/False verdict on a review result.

        Args:
            result: The review output, sent as the assistant turn.
            review_format: The prompt that produced ``result``, sent as the
                first user turn.

        Returns:
            The raw LLM answer (free text); it is also logged.
        """
        review_result = (
            "Finally, return a boolean value (True or False) to indicate the result of the review. "
            "Note: If the result is good enough, return True; otherwise, return False."
        )
        status = await self.llm.aask(
            [
                Message(content=review_format, role="user"),
                Message(content=result, role="assistant"),
                Message(content=review_result, role="user"),
            ]
        )
        logger.info(status)
        return status

    async def review_patch(self, code):
        """Ask the LLM to check ``code`` against every script in ``script_names``.

        Args:
            code: The generated code to review.

        Returns:
            tuple: ``(result, review_prompt)`` where ``result`` is the
            per-script answers joined by newlines and ``review_prompt`` is
            the prompt used for the last script. Both are empty strings when
            there are no scripts to review.
        """
        review_format = (
            "Please ensure that the code {code} and original script {original_script} can fix the issue {memory} in patch format. "
            "If it is not in patch format, please convert it to patch format."
        )
        # A single path given as a string must not be iterated character by character.
        scripts = [self.script_names] if isinstance(self.script_names, str) else self.script_names

        results = []
        # Defined up front so an empty script list no longer raises UnboundLocalError.
        review_prompt = ""
        for script in scripts:
            with open(script, "r", encoding="utf-8") as fp:
                original_script = fp.read()
            memory = self.get_memories()[0].content
            review_prompt = review_format.format(code=code, original_script=original_script, memory=memory)
            # todo: extract issue and remove image urls
            result = await self.llm.aask(review_prompt)
            results.append(result)
        # fixme: merge results to a single patch file
        result = "\n".join(results)
        return result, review_prompt

    async def _write_and_exec_code(self, max_retry: int = 3):
        """Write, review and execute code until the review approves it.

        Args:
            max_retry: Maximum number of attempts before a human review is
                requested.

        Returns:
            tuple: ``(code, result, success)`` from the last attempt, where
            ``success`` is the boolean verdict parsed from the critique.
        """
        counter = 0
        success = False
        # Defaults keep the final return valid even if the loop body never runs.
        code, result = "", ""

        # plan info
        plan_status = self.planner.get_plan_status() if self.use_plan else ""

        # tool info
        if self.tools:
            context = (
                self.working_memory.get()[-1].content if self.working_memory.get() else ""
            )  # thoughts from _think stage in 'react' mode
            plan = self.planner.plan if self.use_plan else None
            tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
        else:
            tool_info = ""

        while not success and counter < max_retry:
            ### write code ###
            code, cause_by = await self._write_code(counter, plan_status, tool_info)
            self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))

            result, format_prompt = await self.review_patch(code)
            # The critique answer is free text; any non-empty string (even
            # "False") is truthy, so it is parsed into a real boolean here.
            status = await self.critique(result, format_prompt)
            success = self._is_approved(status)

            ### execute code ###
            await self.execute_code.run(code)
            # todo: execute: git apply

            ### process execution result ###
            counter += 1

            if not success and counter >= max_retry:
                logger.info("coding failed!")
                review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
                if ReviewConst.CHANGE_WORDS[0] in review:
                    counter = 0  # redo the task again with help of human suggestions

        return code, result, success

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -8,7 +8,8 @@ original_argv = sys.argv.copy()
try:
# 设置你想要传递给脚本的命令行参数
sys.argv = ["run_api.py", "--dataset_name_or_path", "princeton-nlp/SWE-bench_oracle", "--output_dir", "./outputs"]
dataset_path = "SWE-bench_oracle" # "SWE-bench_bm25_27K" # "SWE-bench_13k"
sys.argv = ["run_api.py", "--dataset_name_or_path", f"princeton-nlp/{dataset_path}", "--output_dir", "./outputs"]
# 执行脚本
runpy.run_path(path_name="run_api.py", run_name="__main__")
finally:

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.recovery_util import save_history
from swe_bench.gitagent import GitAgent
from swe_bench.make_datasets.make_dataset import reset_task_env
from swe_bench.utils.utils import extract_scripts_from_codetext
# Diff-style template interpolated into the prompt built by
# ``construct_prompt`` to show the expected layout of the patch.
PATCH_FORMAT = """
```diff
--- original_file.py
+++ modified_file.py
@@ -line_number,context_lines +line_number,context_lines @@
- original line of code to be replaced or removed
+ new line of code to be added or to replace the original
```
"""
def _prepare(inputs):
requirement = "Please rewrite the code to address the issues. "
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
cleaned_user_message = re.sub("<patch>.*?</patch>", "", user_message, flags=re.DOTALL)
issues = re.findall("<issue>(.*?)</issue>", user_message, flags=re.DOTALL)
return requirement, system_messages, cleaned_user_message, issues
def construct_prompt(inputs, script_names):
    """Build the agent's instruction prompt alongside the prepared inputs.

    Args:
        inputs (str): Raw prompt text, passed unchanged to ``_prepare``.
        script_names: Files the agent is told to restrict its changes to;
            interpolated into the prompt as-is.

    Returns:
        tuple: The four values from ``_prepare(inputs)`` followed by the
        instruction prompt.
    """
    requirement, system_messages, cleaned_user_message, issues = _prepare(inputs)
    prompt = "".join(
        [
            f"You only need to modify the code file listed here {script_names}.",
            "Notice: ",
            "1. Analysis the issue, especially for the ValueError, and identify influence code lines.\n",
            "2. Only change a few lines, and make sure I can use git diff and git apply to resolve the issue .\n",
            "3. I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.\n",
            f"4. use the format as : {PATCH_FORMAT}",
        ]
    )
    return requirement, system_messages, cleaned_user_message, issues, prompt
@handle_exception(exception_type=Exception)
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
async def run_agent(inputs, agent, **kwargs):
    """Run ``agent`` on one prompt and return the source of its last cell.

    Args:
        inputs (str): Raw prompt text for the instance.
        agent: Object providing ``run`` and ``get_last_cell_source``.
        **kwargs: ``script_names`` (default ``[]``) lists the files named in
            the instruction prompt.

    Returns:
        Whatever ``agent.get_last_cell_source()`` returns.
    """
    script_names = kwargs.get("script_names", [])
    requirement, system_messages, cleaned_user_message, _issues, prompt = construct_prompt(inputs, script_names)
    # NOTE(review): this strips every single space from both messages -
    # confirm that is intended rather than collapsing runs of spaces.
    messages = [
        requirement,
        system_messages.replace(" ", ""),
        cleaned_user_message.replace(" ", ""),
        prompt,
    ]
    await agent.run(messages)
    return agent.get_last_cell_source()
async def run_instance(instance, use_reflection=True):
    """Run a fresh ``GitAgent`` on one dataset instance.

    Args:
        instance: Dataset record; its ``"text"`` field is the prompt.
        use_reflection (bool): Passed to the ``GitAgent`` constructor.

    Returns:
        The response from ``run_agent``, or ``None`` when ``reset_task_env``
        reports no repository path for the instance.
    """
    agent = GitAgent(use_reflection=use_reflection)
    script_names = extract_scripts_from_codetext(instance["text"])
    agent.script_names = script_names

    _patch, _repo, repo_path = reset_task_env(instance)
    if repo_path is None:
        return

    prompt_text = f"{instance['text']}\n\n"
    response = await run_agent(prompt_text, agent=agent, script_names=script_names)
    logger.info(f"Final response: {response}")
    save_history(agent)
    return response

View file

@ -0,0 +1,112 @@
import json
from pathlib import Path
import fire
from data.load_dataset import load_oracle_dataset
from tqdm.auto import tqdm
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils import count_string_tokens
from swe_bench.inference.run_agent import run_instance
from swe_bench.utils.utils import check_existing_ids, extract_diff
# Largest prompt size (in tokens) kept by the dataset filter below;
# replace with the limit of your own model.
MAX_TOKEN = 128_000
async def openai_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    existing_ids,
    use_reflection,
):
    """Run the agent over a dataset and append results as JSON lines.

    Instances whose text exceeds ``MAX_TOKEN`` tokens are filtered out, and
    instances whose id is in ``existing_ids`` are skipped. Nothing is written
    for an instance when ``run_instance`` returns ``None``.

    Args:
        test_dataset (datasets.Dataset): The dataset to run inference on.
        model_name_or_path (str): The name or path of the model to use.
        output_file (str): The path to the output file (opened in append mode).
        existing_ids (set): A set of ids that have already been processed.
        use_reflection (bool): Forwarded to ``run_instance``.
    """
    test_dataset = test_dataset.filter(
        lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
        desc="Filtering",
        load_from_cache_file=False,
    )
    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    logger.info(f"Filtered to {len(test_dataset)} instances")

    with open(output_file, "a+") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue

            version = datum["version"]
            repo_prefix = datum["repo"].replace("/", "__")
            logger.info(f"{repo_prefix}_{version}")

            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            output_dict["text"] = f"{datum['text']}\n\n"

            # Forward the caller's flag; previously it was silently dropped
            # and run_instance always ran with its default.
            response = await run_instance(instance=datum, use_reflection=use_reflection)
            if response is None:
                continue

            logger.info(f"Final response: {response}")
            output_dict["full_output"] = response
            output_dict["model_patch"] = extract_diff(response)
            print(json.dumps(output_dict), file=f, flush=True)
async def main(
    dataset_name_or_path,
    split="test",
    model_name_or_path=config.llm.model,
    output_dir="outputs",
    use_reflection=True,
):
    """Perform inference on a SWE-bench dataset using the Data Interpreter.

    Results are appended to
    ``<output_dir>/<model>__<dataset basename>__<split>.jsonl``; the absolute
    path of that file is printed before the run starts.

    Args:
        dataset_name_or_path: HuggingFace dataset name or local path.
        split: Dataset split name; used here only in the output file name.
        model_name_or_path: Name of the model to use (default: config.llm.model).
        output_dir: Path to the output directory (default: outputs).
        use_reflection: Forwarded to ``openai_inference``.

    Raises:
        ValueError: If the model name does not start with ``"gpt"``.
    """
    if isinstance(model_name_or_path, Path):
        model_nickname = Path(model_name_or_path).name
    else:
        model_nickname = model_name_or_path
    dataset_tag = dataset_name_or_path.split("/")[-1]
    output_file = Path(output_dir, f"{model_nickname}__{dataset_tag}__{split}" + ".jsonl")
    print(output_file.absolute())
    output_file.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Will write to {output_file}")

    # check existing results
    existing_ids = check_existing_ids(output_file)
    # load dataset
    dataset = load_oracle_dataset(dataset_name_or_path)

    if not model_name_or_path.startswith("gpt"):
        raise ValueError(f"Invalid model name or path {model_name_or_path}")
    await openai_inference(
        test_dataset=dataset,
        model_name_or_path=model_name_or_path,
        output_file=output_file,
        existing_ids=existing_ids,
        use_reflection=use_reflection,
    )
    logger.info("Done!")


if __name__ == "__main__":
    fire.Fire(main)

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import os
from pathlib import Path
from tqdm.auto import tqdm
from data.inference.const import TESTBED
from metagpt.logs import logger
from swe_bench.make_datasets.make_instance import prompt_style_2_edits_only
from swe_bench.utils.parse_diff import filter_changed_line
from swe_bench.utils.repo_utils import EnvManager
def reset_task_env(instance: "dict | None" = None):
    """Reset the checked-out repository for ``instance`` and change into it.

    Args:
        instance: Dataset record with at least ``"patch"``, ``"repo"`` and
            ``"version"`` keys. Omitting it is equivalent to passing an empty
            dict and therefore raises ``KeyError``.

    Returns:
        tuple: ``(patch, repo, repo_path)``. ``repo_path`` is ``None`` when
        the repository directory does not exist under the testbed or when
        ``EnvManager.reset_task_env`` returns a falsy value.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    instance = {} if instance is None else instance

    # reset the env via git reset and git checkout
    env_manager = EnvManager(testbed=TESTBED)

    patch = instance["patch"]
    repo = instance["repo"]
    # The value is unused; the lookup only fails fast when "version" is missing.
    instance["version"]

    # Repositories are stored as "<owner>__<name>" under the testbed.
    repo_prefix = repo.replace("/", "__")
    repo_path = os.path.join(env_manager.testbed, repo_prefix)
    if not os.path.exists(repo_path):
        return patch, repo, None

    os.chdir(repo_path)
    if not env_manager.reset_task_env(instance=instance):
        return patch, repo, None

    return patch, repo, repo_path
def reset_and_copy(instance: "dict | None" = None):
    """Reset the repository for ``instance`` and copy it to a versioned directory.

    The copy destination is ``<repo_path>/<owner>__<name>__<version>``.
    Nothing is copied (and ``None`` is returned) when ``reset_task_env``
    reports no repository path.

    Args:
        instance: Dataset record with at least ``"patch"``, ``"repo"`` and
            ``"version"`` keys.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    instance = {} if instance is None else instance

    patch, repo, repo_path = reset_task_env(instance)
    if repo_path is None:
        return

    env_manager = EnvManager(testbed=TESTBED)
    repo_prefix = repo.replace("/", "__")
    version = instance["version"]
    # NOTE(review): the destination lies inside the source directory -
    # confirm that copy_repo handles this as intended.
    destination_path = os.path.join(repo_path, f"{repo_prefix}__{version}")
    env_manager.copy_repo(source_path=repo_path, destination_path=destination_path)
def make_oracle_collapsed_instance(instance):
    """Build and log the edits-only prompt for a single instance.

    The task environment is reset first; when that yields no repository path
    the function returns ``None``. Otherwise the prompt is built for the
    files touched by the instance's patch, logged, and an empty dict is
    returned (saving the output is still a todo).
    """
    # for each instance, reset task env
    patch, _repo, repo_path = reset_task_env(instance)
    if repo_path is None:
        return

    changed_files = list(filter_changed_line(patch).keys())
    prompt = prompt_style_2_edits_only(instance, Path(repo_path), changed_files)
    logger.info(prompt)
    # todo: save output
    return {}
def make_oracle_collapsed_dataset(dataset):
    """Run ``make_oracle_collapsed_instance`` on every record of ``dataset``.

    Progress is shown with a tqdm bar; per-instance results are discarded
    for now.
    """
    progress = tqdm(dataset, desc="Inference ")
    for record in progress:
        make_oracle_collapsed_instance(instance=record)
    # todo: save output

View file

@ -0,0 +1,193 @@
from pathlib import Path
import unidiff
# Sample patch for a single file, placed between <patch> tags at the end of
# the prompt built by ``prompt_style_2_edits_only``.
PATCH_EXAMPLE = """--- a/file.py
+++ b/file.py
@@ -1,27 +1,35 @@
def euclidean(a, b):
- while b:
- a, b = b, a % b
- return a
+ if b == 0:
+ return a
+ return euclidean(b, a % b)
def bresenham(x0, y0, x1, y1):
points = []
dx = abs(x1 - x0)
dy = abs(y1 - y0)
- sx = 1 if x0 < x1 else -1
- sy = 1 if y0 < y1 else -1
- err = dx - dy
+ x, y = x0, y0
+ sx = -1 if x0 > x1 else 1
+ sy = -1 if y0 > y1 else 1
- while True:
- points.append((x0, y0))
- if x0 == x1 and y0 == y1:
- break
- e2 = 2 * err
- if e2 > -dy:
+ if dx > dy:
+ err = dx / 2.0
+ while x != x1:
+ points.append((x, y))
err -= dy
- x0 += sx
- if e2 < dx:
- err += dx
- y0 += sy
+ if err < 0:
+ y += sy
+ err += dx
+ x += sx
+ else:
+ err = dy / 2.0
+ while y != y1:
+ points.append((x, y))
+ err -= dx
+ if err < 0:
+ x += sx
+ err += dy
+ y += sy
+ points.append((x, y))
return points"""
# Sample of the "full file" answer format: each file's complete content is
# wrapped in "[start of <path>]" / "[end of <path>]" markers.
FULL_GENERATION_EXAMPLE = """[start of /src/this_file.py]
import os
def euclidean(a, b):
if b == 0:
return a
return euclidean(b, a % b)
[end of /src/this_file.py]
[start of /src/another_file.py]
def bresenham(x0, y0, x1, y1):
points = []
dx = abs(x1 - x0)
dy = abs(y1 - y0)
x, y = x0, y0
sx = -1 if x0 > x1 else 1
sy = -1 if y0 > y1 else 1
if dx > dy:
err = dx / 2.0
while x != x1:
points.append((x, y))
err -= dy
if err < 0:
y += sy
err += dx
x += sx
else:
err = dy / 2.0
while y != y1:
points.append((x, y))
err -= dx
if err < 0:
x += sx
err += dy
y += sy
points.append((x, y))
return points
[end of /src/another_file.py]"""
def add_lines_list(content):
    """Return the lines of ``content``, each prefixed with its 1-based number and a space."""
    return [f"{number} {text}" for number, text in enumerate(content.split("\n"), start=1)]
def add_lines(content):
    """Return ``content`` with every line prefixed by its 1-based number and a space."""
    numbered = (f"{number} {text}" for number, text in enumerate(content.split("\n"), start=1))
    return "\n".join(numbered)
def make_code_text(files_dict, add_line_numbers=True):
    """Render files as ``[start of <name>]`` ... ``[end of <name>]`` sections.

    Args:
        files_dict: Mapping of file name to file content; sections are
            emitted in sorted item order.
        add_line_numbers (bool): When true, each content line is prefixed
            with its line number via ``add_lines``.

    Returns:
        str: The concatenated sections with leading and trailing newlines
        stripped.
    """
    sections = []
    for filename, contents in sorted(files_dict.items()):
        body = add_lines(contents) if add_line_numbers else contents
        sections.append(f"[start of {filename}]\n{body}\n[end of {filename}]\n")
    return "".join(sections).strip("\n")
def make_code_text_edits_only(files_dict, patch, root_path, add_line_numbers=True):
    """Render only the regions of each file that the patch touches.

    For every hunk of ``patch`` a window starting 15 lines before the hunk's
    source start is recorded for the patched file. Each file in
    ``files_dict`` is then rendered as a ``[start of ...]`` /
    ``[end of ...]`` section containing its numbered lines for those
    windows, with ``...`` marking omitted text before and after a window.

    Args:
        files_dict: Mapping of file path to content; keys must match
            ``root_path / <path from the patch>``.
        patch (str): Unified diff text, parsed with ``unidiff.PatchSet``.
        root_path: Repository root joined with the paths found in the patch.
        add_line_numbers: Accepted but not consulted; lines are always
            numbered.

    Returns:
        str: The concatenated sections with surrounding newlines stripped.

    Raises:
        KeyError: If a file in ``files_dict`` is not touched by the patch.
    """
    windows = dict()
    for patched_file in unidiff.PatchSet(patch):
        source_file = root_path / patched_file.source_file.split("a/", 1)[-1]
        windows[source_file] = list()
        for hunk in patched_file:
            raw_start = hunk.source_start - 15
            end = raw_start + hunk.source_length + 15
            # Clamp at the top of the file: a negative slice start would be
            # counted from the END of the list and select the wrong lines.
            windows[source_file].append((max(raw_start, 0), end))

    all_text = ""
    for filename, content in files_dict.items():
        all_text += f"[start of {filename}]\n"
        content_with_lines = add_lines_list(content)
        for start, end in windows[filename]:
            if start > 0:
                all_text += "...\n"
            all_text += "\n".join(content_with_lines[start:end])
            all_text += "\n"
            if end < len(content_with_lines):
                all_text += "...\n"
        all_text = all_text.strip("\n")
        all_text += f"\n[end of {filename}]\n"
    return all_text.strip("\n")
def prompt_style_2_edits_only(instance, root_path, filenames):
    """Assemble the edits-only prompt for one instance.

    README files found directly under ``root_path`` and the given source
    files are read from disk and stored on ``instance`` under ``"readmes"``
    and ``"file_contents"``. The prompt consists of a premise, the issue
    statement, the code section (READMEs in full, sources restricted to the
    patched regions), the instructions, and ``PATCH_EXAMPLE``.

    Args:
        instance: Record providing ``"patch"`` and ``"problem_statement"``;
            it is mutated as described above.
        root_path: Repository root (a ``Path``-like supporting ``/``).
        filenames: Paths, relative to ``root_path``, of the files to include.

    Returns:
        str: The prompt, with its parts joined by newlines.
    """
    premise = "You will be provided with a partial code base and an issue statement explaining a problem to resolve."

    readme_paths = [root_path / readme for readme in get_readme_files(root_path)]
    instance["readmes"] = ingest_files(readme_paths)
    readmes_text = make_code_text(instance["readmes"])

    source_paths = [root_path / filename for filename in filenames]
    instance["file_contents"] = ingest_files(source_paths)
    code_text = make_code_text_edits_only(instance["file_contents"], instance["patch"], root_path)

    instructions = (
        "I need you to solve this issue by generating a single patch file that I can apply "
        "directly to this repository using git apply. Please respond with a single patch "
        "file in the following format."
    )
    parts = [
        premise,
        "<issue>",
        instance["problem_statement"],
        "</issue>",
        "<code>",
        readmes_text,
        code_text,
        "</code>",
        instructions,
        "<patch>",
        PATCH_EXAMPLE,
        "</patch>",
    ]
    return "\n".join(parts)
def ingest_files(file_paths):
    """Read each path as UTF-8 text and return a ``{path: content}`` mapping."""
    return {file_path: Path.read_text(file_path, encoding="utf-8") for file_path in file_paths}
def get_readme_files(repo_path):
    """List the names of README files located directly in ``repo_path``.

    A README is any regular file whose name starts with "readme", compared
    case-insensitively. Sub-directories are not searched.
    """
    readme_names = []
    for entry in Path(repo_path).iterdir():
        # Keep regular files whose name begins with "readme", ignoring case.
        if entry.is_file() and entry.name.lower().startswith("readme"):
            readme_names.append(entry.name)
    return readme_names

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

81
swe_bench/utils/utils.py Normal file
View file

@ -0,0 +1,81 @@
import json
import os
import re
from metagpt.logs import logger
def check_existing_ids(output_file):
    """Return the set of ``instance_id`` values already stored in ``output_file``.

    The file is read as JSON lines when it exists; a missing file yields an
    empty set. The number of ids found is logged either way.
    """
    existing_ids = set()
    if os.path.exists(output_file):
        with open(output_file, "r") as handle:
            existing_ids.update(json.loads(record)["instance_id"] for record in handle)
    logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")
    return existing_ids
def extract_diff(response):
    """Pull the most likely diff out of a model response.

    Two container styles are scanned, in this order: ``<tag>...</tag>`` pairs
    and triple-backtick fenced blocks. The first block labelled ``diff`` or
    ``patch`` wins; failing that, the first block of any other label; failing
    that, the response text up to the first ``</s>``.

    Returns ``None`` when ``response`` is ``None``.
    """
    if response is None:
        return None

    patterns = (
        re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL),
        re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL),
    )
    labelled_diffs = []
    everything_else = []
    for pattern in patterns:
        for label, body in pattern.findall(response):
            target = labelled_diffs if label in {"diff", "patch"} else everything_else
            target.append(body)

    if labelled_diffs:
        return labelled_diffs[0]
    if everything_else:
        return everything_else[0]
    return response.split("</s>")[0]
def extract_scripts_from_codetext(codetext: str):
    """Collect the ``.py`` paths announced by ``[start of <path>]`` markers.

    Every marker of the form ``[start of <something>.py]`` found in
    ``codetext`` contributes its path, in order of appearance. Markers for
    non-Python files (for example ``[start of README.rst]``) are ignored.
    Each extracted name is also printed; a notice is printed when nothing
    matches.

    Parameters:
    - codetext (str): Text that may contain ``[start of ...]`` sections.

    Returns:
    - list: The extracted script paths (empty when there are none).
    """
    found = re.findall(r"\[start of ([^\]]+\.py)\]", codetext)
    if not found:
        print("No script names found in the text.")
        return []

    script_names = []
    for name in found:
        print("Extracted script name:", name)
        script_names.append(name)
    return script_names