update code

change dir, add new role
This commit is contained in:
stellahsr 2024-03-22 10:22:49 +08:00
parent 3fac156d66
commit 7bf4505d90
11 changed files with 338 additions and 158 deletions

View file

@ -28,22 +28,38 @@ SCIKIT_LEARN_IDS = [
"scikit-learn__scikit-learn-10459",
]
# Hand-picked matplotlib instance ids; the two commented-out entries are
# currently excluded. NOTE(review): presumably SWE-bench instance ids - confirm.
MATPLOTLIB_IDS = [
"matplotlib__matplotlib-24362",
"matplotlib__matplotlib-20584",
"matplotlib__matplotlib-23188",
"matplotlib__matplotlib-24403",
# 'matplotlib__matplotlib-21443',
# 'matplotlib__matplotlib-23047'
]
def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
try:
df = pd.read_excel(path)
pass_filters = df["instance_id_pass"].tolist()
fail_filters = df["instance_id_fail"].tolist()
pass_filters = [s for s in pass_filters if tag in s]
fail_filters = [s for s in fail_filters if tag in s]
print(pass_filters)
print(fail_filters)
# Filter for instances containing the tag in either column
pass_filter = df["instance_id_pass"].str.contains(tag, na=False)
fail_filter = df["instance_id_fail"].str.contains(tag, na=False)
# pass_filter = df["instance_id_pass"].str.contains(tag, na=False)
# fail_filter = df["instance_id_fail"].str.contains(tag, na=False)
# Combine the filters using | (OR operator) for efficiency
combined_filter = pass_filter | fail_filter
# combined_filter = pass_filters | fail_filters
# print(df[combined_filter])
# Apply combined filter and select the specific columns
filtered_df = df[combined_filter][["instance_id_pass", "instance_id_fail"]]
# filtered_df = df[combined_filter][["instance_id_pass", "instance_id_fail"]]
# Flatten the DataFrame into a list and remove NaN values
subset_instance = filtered_df.stack().dropna().tolist()
subset_instance = pass_filters + fail_filters
return subset_instance
except FileNotFoundError:
@ -52,3 +68,7 @@ def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
except Exception as e:
print(f"An error occurred: {e}")
return []
# Manual smoke check: print the subset entries whose id contains the matplotlib tag.
if __name__ == "__main__":
    print(read_sub_set_instance(tag="matplotlib__matplotlib"))

View file

@ -1,38 +0,0 @@
import re
def extract_scripts_from_codetext(codetext: str):
    """
    Collect the Python file paths announced by '[start of <path>.py]' markers.

    The text is expected to be a dump of several files, where every file is
    wrapped in '[start of <path>]' / '[end of <path>]' marker lines. Only the
    start markers whose path ends in '.py' are reported; each hit is also
    echoed to stdout.

    Parameters:
    - codetext (str): Text that may contain any number of file sections.

    Returns:
    - list: The matched paths, in order of appearance (empty when none match).

    Example of codetext:
    '''
    [end of README.rst]
    [start of sklearn/compose/_target.py]
    ... file content ...
    [end of sklearn/compose/_target.py]
    [start of another_module/example.py]
    ... file content ...
    [end of another_module/example.py]
    '''
    """
    found = re.findall(r"\[start of ([^\]]+\.py)\]", codetext)

    # Guard clause: nothing matched, report it and hand back an empty list.
    if not found:
        print("No script names found in the text.")
        return []

    script_names = []
    for path in found:
        print("Extracted script name:", path)
        script_names.append(path)
    return script_names

View file

@ -1,87 +0,0 @@
import os
import subprocess
from pathlib import Path
from traceback import format_exc
from typing import Dict
import git
from git.exc import GitError
from metagpt.logs import logger
KEY_INSTANCE_ID = "instance_id"
RESET_FAILED = ">>>>> Reset Failed"
class ExecWrapper:
    """Callable that forwards to ``subprocess.run`` with preset keyword arguments."""

    def __init__(self, subprocess_args: Dict = None):
        """Store the keyword arguments applied to every call (empty when omitted)."""
        self.subprocess_args = subprocess_args if subprocess_args else {}

    def __call__(self, cmd, raise_error=True, **kwargs):
        """
        Run ``cmd`` through ``subprocess.run``.

        Args:
            cmd: Command passed straight to ``subprocess.run``.
            raise_error: When True, a ``CalledProcessError`` is logged and
                re-raised; when False it is swallowed and ``None`` is returned.
            **kwargs: Per-call overrides for the preset keyword arguments.

        Returns:
            The ``subprocess.run`` result, or ``None`` when the command failed
            and ``raise_error`` is False.
        """
        # Per-call keyword arguments take precedence over the instance defaults.
        run_kwargs = dict(self.subprocess_args)
        run_kwargs.update(kwargs)
        try:
            return subprocess.run(cmd, **run_kwargs)
        except subprocess.CalledProcessError as e:
            if not raise_error:
                return None
            error_message = (
                f"Error: {e}\nError stdout: {e.stdout}\nError stderr: {e.stderr}\nError traceback: {format_exc()}"
            )
            logger.error(error_message)
            raise e
class EnvManager:
    """Git helpers for cloning a task repository and resetting it to a base commit."""

    def __init__(self, testbed):
        """Remember the testbed and build the command runner used by the helpers."""
        # Copy of the current process environment handed to every subprocess.
        shellenv = os.environ.copy()
        self.testbed = testbed
        default_run_args = {
            "check": True,
            "shell": False,
            "capture_output": True,
            "text": True,
            "env": shellenv,
        }
        self.exec = ExecWrapper(subprocess_args=default_run_args)

    def clone_repo(self, repo_name: str, path: str, token: str = None):
        """
        Clone ``swe-bench/<repo_name with '/' replaced by '__'>`` from GitHub into ``path``.

        Args:
            repo_name: Repository name, e.g. ``owner/project``.
            path: Destination directory; created when missing.
            token: Credential embedded in the clone URL. Defaults to the
                ``GITHUB_TOKEN`` environment variable, or ``"git"`` when unset.

        Raises:
            ValueError: If the resolved token is empty.

        Git failures are not raised; they are printed to stdout.
        """
        if token is None:
            token = os.environ.get("GITHUB_TOKEN", "git")
        if not token:
            raise ValueError("GitHub token is required for cloning repositories.")

        mirror_name = repo_name.replace("/", "__")
        repo_url = f"https://{token}@github.com/swe-bench/{mirror_name}.git"
        try:
            # Ensure the destination directory exists before cloning into it.
            os.makedirs(path, exist_ok=True)
            git.Repo.clone_from(repo_url, path)
        except GitError as e:
            print(f"Failed to clone repository '{repo_name}': {e}")
        else:
            print(f"Repository '{repo_name}' cloned successfully.")

    def reset_task_env(self, instance: Dict):
        """
        Reset task environment + testbed and checkout base commit of given task instance.

        The git commands are issued without an explicit ``cwd`` argument.

        Args:
            instance: Task record; ``base_commit`` and ``instance_id`` are read.

        Returns:
            True when every command succeeded, False when any step raised.
        """
        try:
            if Path(".gitignore").exists():
                self.exec(["git", "ls-files", "--ignored", "--exclude-standard", "-o", "-z"], raise_error=False)
                # fixme: need detect platform and change this cmd
                # self.exec(["xargs", "-0", "-r", "rm", "-rf"], input=gitignore_path.read_text())

            # Drop local modifications, unstage everything, delete untracked files.
            for cleanup_cmd in (
                ["git", "restore", "."],
                ["git", "reset", "HEAD", "."],
                ["git", "clean", "-fdx"],
            ):
                self.exec(cleanup_cmd)

            self.exec(["git", "-c", "advice.detachedHead=false", "checkout", instance["base_commit"]])
            logger.info(f"[{instance['instance_id']}] Reset task environment to {instance['base_commit']}")
        except Exception as e:
            err_msg = f"{RESET_FAILED}; Failed to reset task environment to {instance['base_commit']}: {e}"
            logger.error(err_msg)
            return False
        return True

View file

@ -1,28 +0,0 @@
import re
def extract_diff(response):
    """
    Pull the most diff-like snippet out of a model response.

    Two wrappings are recognised: ``<tag>...</tag>`` pairs and triple-backtick
    fenced blocks. Snippets labelled ``diff`` or ``patch`` win; otherwise the
    first snippet with any other (or no) label is used; otherwise the response
    text up to the first ``</s>`` is returned.

    Returns ``None`` when ``response`` is ``None``.
    """
    if response is None:
        return None

    diff_labels = ("diff", "patch")
    preferred = []
    fallback = []

    tag_pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
    fence_pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)

    # Tagged snippets are collected before fenced ones, so they rank first.
    for matcher in (tag_pattern, fence_pattern):
        for label, snippet in matcher.findall(response):
            bucket = preferred if label in diff_labels else fallback
            bucket.append(snippet)

    for candidates in (preferred, fallback):
        if candidates:
            return candidates[0]
    return response.split("</s>")[0]

3
swe_bench/__init__.py Normal file
View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

99
swe_bench/gitagent.py Normal file
View file

@ -0,0 +1,99 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from typing import Literal, Union
from metagpt.actions.di.ask_review import ReviewConst
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.schema import Message
class GitAgent(DataInterpreter):
    """
    DataInterpreter variant that writes code for an issue, asks the LLM to turn
    it into a patch per target script, and asks the LLM to judge the result.
    """

    name: str = "Jacky"
    profile: str = "Solve git issues proficiently"
    auto_run: bool = True
    use_plan: bool = True
    use_reflection: bool = False
    react_mode: Literal["plan_and_act", "react"] = "react"
    # Path(s) of the scripts that review_patch reads back as the "original script".
    script_names: Union[str, list[str]] = []
    instance_id: str = ""

    async def critique(self, result, review_format):
        """
        Ask the LLM whether ``result`` is good enough.

        Args:
            result: The reviewed patch text, sent as the assistant turn.
            review_format: The review prompt that produced ``result``.

        Returns:
            The raw LLM reply. It is free text that should contain True or
            False; use ``_is_approved`` to turn it into a real boolean.
        """
        review_result = (
            "Finally, return a boolean value (True or False) to indicate the result of the review. "
            "Note: If the result is good enough, return True; otherwise, return False."
        )
        status = await self.llm.aask(
            [
                Message(content=review_format, role="user"),
                Message(content=result, role="assistant"),
                Message(content=review_result, role="user"),
            ]
        )
        logger.info(status)
        return status

    def _is_approved(self, status) -> bool:
        """
        Interpret a critique reply as a boolean verdict.

        Any non-empty string is truthy, so the reply "False" must not be used
        directly as a success flag. The last standalone true/false word decides,
        because the critique prompt asks for the verdict at the end. A reply
        without such a word is treated as a rejection.
        """
        import re

        if isinstance(status, bool):
            return status
        verdicts = re.findall(r"\b(true|false)\b", str(status), flags=re.IGNORECASE)
        return bool(verdicts) and verdicts[-1].lower() == "true"

    async def review_patch(self, code):
        """
        Ask the LLM to check ``code`` against each target script and emit a patch.

        Args:
            code: The code produced by the write-code step.

        Returns:
            ``(result, review_prompt)`` where ``result`` is the per-script LLM
            replies joined by newlines and ``review_prompt`` is the last prompt
            sent. With no target scripts, ``code`` is returned unreviewed
            together with the prompt built for an empty original script.
        """
        review_format = (
            "Please ensure that the code {code} and original script {original_script} can fix the issue {memory} in patch format. "
            "If it is not in patch format, please convert it to patch format."
        )
        # script_names may be a single path; wrap it so we never iterate a
        # string character by character.
        scripts = self.script_names
        if isinstance(scripts, str):
            scripts = [scripts] if scripts else []

        # The first memory does not depend on the script, so read it once.
        memory = self.get_memories()[0].content

        if not scripts:
            # Nothing to compare against: avoid returning an unbound prompt.
            logger.warning("No script names configured; returning the code without patch review.")
            return code, review_format.format(code=code, original_script="", memory=memory)

        results = []
        review_prompt = ""
        for script in scripts:
            with open(script, "r", encoding="utf-8") as fp:
                original_script = fp.read()
            review_prompt = review_format.format(code=code, original_script=original_script, memory=memory)
            # todo: extract issue and remove image urls
            result = await self.llm.aask(review_prompt)
            results.append(result)
        # fixme: merge results to a single patch file
        return "\n".join(results), review_prompt

    async def _write_and_exec_code(self, max_retry: int = 3):
        """
        Write code, have it reviewed and judged, and retry until approved.

        Args:
            max_retry: Maximum number of write/review rounds before asking a
                human reviewer.

        Returns:
            ``(code, result, success)``: the last code written, the last review
            result and whether the critique approved it. ``code`` and ``result``
            are empty strings when no round ran.
        """
        counter = 0
        success = False
        # Defined up front so the return below is safe even if no round runs.
        code = ""
        result = ""

        # plan info
        plan_status = self.planner.get_plan_status() if self.use_plan else ""

        # tool info
        if self.tools:
            context = (
                self.working_memory.get()[-1].content if self.working_memory.get() else ""
            )  # thoughts from _think stage in 'react' mode
            plan = self.planner.plan if self.use_plan else None
            tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
        else:
            tool_info = ""

        while not success and counter < max_retry:
            ### write code ###
            code, cause_by = await self._write_code(counter, plan_status, tool_info)
            self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))

            result, format_prompt = await self.review_patch(code)
            # The critique reply is text; parse it instead of relying on truthiness.
            success = self._is_approved(await self.critique(result, format_prompt))
            await self.execute_code.run(code)

            ### execute code ###
            # todo: execute: git apply

            ### process execution result ###
            counter += 1

            if not success and counter >= max_retry:
                logger.info("coding failed!")
                review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
                if ReviewConst.CHANGE_WORDS[0] in review:
                    counter = 0  # redo the task again with help of human suggestions

        return code, result, success

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import runpy
import sys
# Keep the caller's argv so it can be restored once the script has run.
original_argv = list(sys.argv)

try:
    # Command-line arguments handed to the script.
    dataset_path = "SWE-bench_oracle"  # "SWE-bench_bm25_27K" # "SWE-bench_13k"
    script_options = [
        "--dataset_name_or_path",
        f"princeton-nlp/{dataset_path}",
        "--output_dir",
        "./outputs",
    ]
    sys.argv = ["run_api.py", *script_options]

    # Execute the script as if it had been started as the main program.
    runpy.run_path(path_name="run_api.py", run_name="__main__")
finally:
    # Restore the original sys.argv so later code is not affected.
    sys.argv = original_argv

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.recovery_util import save_history
from swe_bench.gitagent import GitAgent
from swe_bench.make_datasets.make_dataset import reset_task_env
from swe_bench.utils.utils import extract_scripts_from_codetext
# Fenced unified-diff skeleton that is embedded verbatim at the end of the
# prompt as the output format the model is asked to follow.
PATCH_FORMAT = """
```diff
--- original_file.py
+++ modified_file.py
@@ -line_number,context_lines +line_number,context_lines @@
- original line of code to be replaced or removed
+ new line of code to be added or to replace the original
```
"""
def _prepare(inputs):
requirement = "Please rewrite the code to address the issues. "
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
cleaned_user_message = re.sub("<patch>.*?</patch>", "", user_message, flags=re.DOTALL)
issues = re.findall("<issue>(.*?)</issue>", user_message, flags=re.DOTALL)
return requirement, system_messages, cleaned_user_message, issues
def construct_prompt(inputs, script_names):
    """
    Build the instruction prompt and split ``inputs`` into its message parts.

    Args:
        inputs (str): Raw prompt text, forwarded to ``_prepare``.
        script_names: Files the model is told to restrict its changes to;
            interpolated into the prompt as-is.

    Returns:
        tuple: ``(requirement, system_messages, cleaned_user_message, issues,
        prompt)`` - the four values from ``_prepare`` followed by the
        instruction prompt, which ends with ``PATCH_FORMAT``.
    """
    headline = f"You only need to modify the code file listed here {script_names}."
    numbered_rules = (
        "1. Analysis the issue, especially for the ValueError, and identify influence code lines.\n",
        "2. Only change a few lines, and make sure I can use git diff and git apply to resolve the issue .\n",
        "3. I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.\n",
        f"4. use the format as : {PATCH_FORMAT}",
    )
    prompt = headline + "Notice: " + "".join(numbered_rules)

    parts = _prepare(inputs)
    return (*parts, prompt)
@handle_exception(exception_type=Exception)
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
async def run_agent(inputs, agent, **kwargs):
    """
    Run ``agent`` on the prompt derived from ``inputs`` and return its last cell source.

    The call is wrapped in a tenacity retry (at most 5 attempts, random
    exponential wait configured with min=30 and max=600) and in
    ``handle_exception`` for ``Exception``.

    Args:
        inputs (str): Raw prompt text.
        agent: Object exposing an awaitable ``run`` and ``get_last_cell_source``.
        **kwargs: ``script_names`` (default ``[]``) is forwarded to
            ``construct_prompt``; other keys are ignored.

    Returns:
        Whatever ``agent.get_last_cell_source()`` returns after the run.
    """
    requirement, system_messages, cleaned_user_message, _, prompt = construct_prompt(
        inputs, kwargs.get("script_names", [])
    )
    # NOTE(review): this strips every single space from both messages; it may
    # have been meant to collapse indentation only - confirm against the repository.
    messages = [
        requirement,
        system_messages.replace(" ", ""),
        cleaned_user_message.replace(" ", ""),
        prompt,
    ]
    await agent.run(messages)
    return agent.get_last_cell_source()
async def run_instance(instance, use_reflection=True):
    """
    Solve one dataset instance with a fresh ``GitAgent``.

    Args:
        instance: Dataset record; its ``text`` field is the prompt and the whole
            record is handed to ``reset_task_env``.
        use_reflection (bool): Forwarded to the ``GitAgent`` constructor.

    Returns:
        The agent's response, or ``None`` when ``reset_task_env`` reports no
        repository path.
    """
    ga = GitAgent(use_reflection=use_reflection)
    ga.script_names = script_names = extract_scripts_from_codetext(instance["text"])

    # Only the repository path is needed here; the other two values are unused.
    _, _, repo_path = reset_task_env(instance)
    if repo_path is None:
        return None

    prompt_text = f"{instance['text']}\n\n"
    response = await run_agent(prompt_text, agent=ga, script_names=script_names)
    logger.info(f"Final response: {response}")
    save_history(ga)
    return response

View file

@ -0,0 +1,114 @@
import json
from pathlib import Path
import fire
from tqdm.auto import tqdm
from data.load_dataset import load_oracle_dataset
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils import count_string_tokens
from swe_bench.inference.run_agent import run_instance
from swe_bench.utils.utils import check_existing_ids, extract_diff
# Replace with your own
MAX_TOKEN = 128000
async def openai_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    existing_ids,
    use_reflection,
):
    """
    Runs inference on a dataset using the openai API.

    Instances whose ``text`` exceeds ``MAX_TOKEN`` tokens are filtered out, and
    instances whose id is in ``existing_ids`` are skipped. Each produced result
    is appended to ``output_file`` as one JSON line and flushed immediately.

    Args:
        test_dataset (datasets.Dataset): The dataset to run inference on.
        model_name_or_path (str): The name or path of the model to use.
        output_file (str): The path to the output file.
        existing_ids (set): A set of ids that have already been processed.
        use_reflection (bool): Forwarded to ``run_instance`` for every instance.
    """
    test_dataset = test_dataset.filter(
        lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
        desc="Filtering",
        load_from_cache_file=False,
    )
    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    logger.info(f"Filtered to {len(test_dataset)} instances")

    with open(output_file, "a+", encoding="utf-8") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue

            version = datum["version"]
            repo_prefix = datum["repo"].replace("/", "__")
            logger.info(f"{repo_prefix}_{version}")

            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            output_dict["text"] = f"{datum['text']}\n\n"

            # Forward the caller's setting; previously the flag was accepted
            # but dropped, so run_instance always used its own default.
            response = await run_instance(instance=datum, use_reflection=use_reflection)
            if response is None:
                continue

            logger.info(f"Final response: {response}")
            output_dict["full_output"] = response
            output_dict["model_patch"] = extract_diff(response)
            # Flush per record so finished work survives an interrupted run.
            print(json.dumps(output_dict), file=f, flush=True)
async def main(
    dataset_name_or_path,
    split="test",
    model_name_or_path=config.llm.model,
    output_dir="outputs",
    use_reflection=True,
):
    """
    Performs inference on SWE-bench dataset using the Data Interpreter.

    Results go to ``<output_dir>/<model>__<dataset>__<split>.jsonl``; ids
    already present in that file are passed on so they can be skipped.

    Args:
        dataset_name_or_path: HuggingFace dataset name or local path
        split: Dataset split to use (default: test)
        model_name_or_path: Name of the model to use (default: config.llm.model)
        output_dir: Path to the output directory (default: outputs)
        use_reflection: Passed through to the inference routine (default: True)

    Raises:
        ValueError: If ``model_name_or_path`` does not start with "gpt".
    """
    if isinstance(model_name_or_path, Path):
        model_nickname = model_name_or_path.name
    else:
        model_nickname = model_name_or_path
    dataset_tag = dataset_name_or_path.split("/")[-1]

    output_file = Path(output_dir, f"{model_nickname}__{dataset_tag}__{split}" + ".jsonl")
    print(output_file.absolute())
    output_file.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Will write to {output_file}")

    # check existing results
    existing_ids = check_existing_ids(output_file)

    # load dataset
    dataset = load_oracle_dataset(dataset_name_or_path)

    # Only "gpt"-prefixed model names have an inference backend here.
    if not model_name_or_path.startswith("gpt"):
        raise ValueError(f"Invalid model name or path {model_name_or_path}")

    await openai_inference(
        test_dataset=dataset,
        model_name_or_path=model_name_or_path,
        output_file=output_file,
        existing_ids=existing_ids,
        use_reflection=use_reflection,
    )
    logger.info("Done!")
# Script entry point: python-fire maps the command-line arguments onto main().
if __name__ == "__main__":
    fire.Fire(main)

View file

@ -0,0 +1,3 @@
# -*- coding: utf-8 -*-
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :