mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
commit
19ccec7080
16 changed files with 653 additions and 253 deletions
|
|
@ -28,22 +28,24 @@ SCIKIT_LEARN_IDS = [
|
|||
"scikit-learn__scikit-learn-10459",
|
||||
]
|
||||
|
||||
MATPLOTLIB_IDS = [
|
||||
"matplotlib__matplotlib-24362",
|
||||
"matplotlib__matplotlib-20584",
|
||||
"matplotlib__matplotlib-23188",
|
||||
"matplotlib__matplotlib-24403",
|
||||
# 'matplotlib__matplotlib-21443',
|
||||
# 'matplotlib__matplotlib-23047'
|
||||
]
|
||||
|
||||
def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
|
||||
|
||||
def read_subset_instance(path=SUBSET_DATASET, tag="scikit-learn"):
|
||||
try:
|
||||
df = pd.read_excel(path)
|
||||
# Filter for instances containing the tag in either column
|
||||
pass_filter = df["instance_id_pass"].str.contains(tag, na=False)
|
||||
fail_filter = df["instance_id_fail"].str.contains(tag, na=False)
|
||||
|
||||
# Combine the filters using | (OR operator) for efficiency
|
||||
combined_filter = pass_filter | fail_filter
|
||||
|
||||
# Apply combined filter and select the specific columns
|
||||
filtered_df = df[combined_filter][["instance_id_pass", "instance_id_fail"]]
|
||||
|
||||
# Flatten the DataFrame into a list and remove NaN values
|
||||
subset_instance = filtered_df.stack().dropna().tolist()
|
||||
pass_filters = df["instance_id_pass"].tolist()
|
||||
fail_filters = df["instance_id_fail"].tolist()
|
||||
pass_filters = [s for s in pass_filters if tag in s]
|
||||
fail_filters = [s for s in fail_filters if tag in s]
|
||||
subset_instance = pass_filters + fail_filters
|
||||
|
||||
return subset_instance
|
||||
except FileNotFoundError:
|
||||
|
|
@ -52,3 +54,7 @@ def read_sub_set_instance(path=SUBSET_DATASET, tag="scikit-learn"):
|
|||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return []
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(read_subset_instance(tag="matplotlib__matplotlib"))
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
import re
|
||||
|
||||
|
||||
def extract_scripts_from_codetext(codetext: str):
|
||||
"""
|
||||
Extracts Python script file names from a given text that contains multiple sections.
|
||||
Each section starts with '[start of <script_name>.py]' and ends with '[end of <script_name>.py]'.
|
||||
|
||||
Parameters:
|
||||
- codetext (str): A string that may contain multiple sections, each indicating the start of a Python script file.
|
||||
|
||||
Returns:
|
||||
- list: A list of extracted Python script file names.
|
||||
|
||||
Example of codetext:
|
||||
'''
|
||||
[end of README.rst]
|
||||
[start of sklearn/compose/_target.py]
|
||||
... file content ...
|
||||
[end of sklearn/compose/_target.py]
|
||||
[start of another_module/example.py]
|
||||
... file content ...
|
||||
[end of another_module/example.py]
|
||||
'''
|
||||
"""
|
||||
script_names = []
|
||||
|
||||
# Match all occurrences of '[start of <script_name>.py]'
|
||||
matches = re.findall(r"\[start of ([^\]]+\.py)\]", codetext)
|
||||
|
||||
if matches:
|
||||
for script_name in matches:
|
||||
print("Extracted script name:", script_name)
|
||||
script_names.append(script_name)
|
||||
else:
|
||||
print("No script names found in the text.")
|
||||
|
||||
return script_names
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
import re
|
||||
|
||||
|
||||
def extract_diff(response):
|
||||
"""
|
||||
Extracts the diff from a response formatted in different ways
|
||||
"""
|
||||
if response is None:
|
||||
return None
|
||||
diff_matches = []
|
||||
other_matches = []
|
||||
pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
|
||||
for code, match in pattern.findall(response):
|
||||
if code in {"diff", "patch"}:
|
||||
diff_matches.append(match)
|
||||
else:
|
||||
other_matches.append(match)
|
||||
pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
|
||||
for code, match in pattern.findall(response):
|
||||
if code in {"diff", "patch"}:
|
||||
diff_matches.append(match)
|
||||
else:
|
||||
other_matches.append(match)
|
||||
if diff_matches:
|
||||
return diff_matches[0]
|
||||
if other_matches:
|
||||
return other_matches[0]
|
||||
return response.split("</s>")[0]
|
||||
|
|
@ -1,173 +0,0 @@
|
|||
import json
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from make_datasets.utils import extract_diff
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from data.inference.const import SCIKIT_LEARN_IDS, TESTBED
|
||||
from data.inference.make_datasets.parse_diff import extract_scripts_from_codetext
|
||||
from data.inference.make_datasets.repo_utils import EnvManager
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.utils import count_string_tokens
|
||||
from metagpt.utils.recovery_util import save_history
|
||||
|
||||
# Replace with your own
|
||||
MAX_TOKEN = 128000
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
|
||||
async def call_chat(inputs, interpreter):
|
||||
"""
|
||||
Calls the openai API to generate completions for the given inputs.
|
||||
|
||||
Args:
|
||||
inputs (str): The inputs to generate completions for.
|
||||
interpreter (DataInterpreter): The data interpreter to use for execution.
|
||||
"""
|
||||
requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code."
|
||||
system_messages = inputs.split("\n", 1)[0]
|
||||
user_message = inputs.split("\n", 1)[1]
|
||||
cleaned_user_message = user_message.split(
|
||||
"I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply. Please respond with a single patch file in the following format."
|
||||
)[0]
|
||||
|
||||
try:
|
||||
await interpreter.run([requirement, system_messages, cleaned_user_message])
|
||||
return interpreter.get_last_cell_source()
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}\nInputs: {inputs}")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
|
||||
async def openai_inference(
|
||||
test_dataset,
|
||||
model_name_or_path,
|
||||
output_file,
|
||||
existing_ids,
|
||||
use_reflection,
|
||||
):
|
||||
"""
|
||||
Runs inference on a dataset using the openai API.
|
||||
|
||||
Args:
|
||||
test_dataset (datasets.Dataset): The dataset to run inference on.
|
||||
model_name_or_path (str): The name or path of the model to use.
|
||||
output_file (str): The path to the output file.
|
||||
existing_ids (set): A set of ids that have already been processed.
|
||||
"""
|
||||
test_dataset = test_dataset.filter(
|
||||
lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
|
||||
desc="Filtering",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
basic_args = {
|
||||
"model_name_or_path": model_name_or_path,
|
||||
}
|
||||
print(f"Filtered to {len(test_dataset)} instances")
|
||||
with open(output_file, "a+") as f:
|
||||
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
|
||||
di = DataInterpreter(use_reflection=use_reflection)
|
||||
env_manager = EnvManager(testbed=TESTBED)
|
||||
|
||||
instance_id = datum["instance_id"]
|
||||
script_names = extract_scripts_from_codetext(datum["text"])
|
||||
logger.info(script_names)
|
||||
repo = datum["repo"]
|
||||
repo_prefix = repo.replace("/", "__")
|
||||
repo_path = os.path.join(env_manager.testbed, repo_prefix)
|
||||
if not os.path.exists(repo_path):
|
||||
continue
|
||||
os.chdir(repo_path)
|
||||
env_manager.reset_task_env(instance=datum)
|
||||
if instance_id in existing_ids:
|
||||
continue
|
||||
output_dict = {"instance_id": instance_id}
|
||||
output_dict.update(basic_args)
|
||||
output_dict["text"] = f"{datum['text']}\n\n"
|
||||
response = await call_chat(
|
||||
output_dict["text"],
|
||||
di,
|
||||
)
|
||||
logger.info(f"Final response: {response}")
|
||||
save_history(di)
|
||||
output_dict["full_output"] = response
|
||||
output_dict["model_patch"] = extract_diff(response)
|
||||
print(json.dumps(output_dict), file=f, flush=True)
|
||||
|
||||
|
||||
async def main(
|
||||
dataset_name_or_path,
|
||||
split="test",
|
||||
model_name_or_path=config.llm.model,
|
||||
output_dir="outputs",
|
||||
use_reflection=True,
|
||||
):
|
||||
"""
|
||||
Performs inference on SWE-bench dataset using the Data Interpreter.
|
||||
|
||||
Args:
|
||||
dataset_name_or_path: HuggingFace dataset name or local path
|
||||
split: Dataset split to use (default: test)
|
||||
model_name_or_path: Name of the model to use (default: config.llm.model)
|
||||
param output_dir: Path to the output directory (default: outputs)
|
||||
"""
|
||||
model_nickname = Path(model_name_or_path).name if isinstance(model_name_or_path, Path) else model_name_or_path
|
||||
output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
|
||||
output_file = Path(output_dir, output_file + ".jsonl")
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Will write to {output_file}")
|
||||
existing_ids = set()
|
||||
if os.path.exists(output_file):
|
||||
with open(output_file, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
instance_id = data["instance_id"]
|
||||
existing_ids.add(instance_id)
|
||||
logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")
|
||||
if Path(dataset_name_or_path).exists():
|
||||
dataset = load_from_disk(dataset_name_or_path)
|
||||
else:
|
||||
dataset = load_dataset(dataset_name_or_path)
|
||||
if split not in dataset:
|
||||
raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
|
||||
dataset = dataset[split]
|
||||
lens = np.array(list(map(len, dataset["text"])))
|
||||
dataset = dataset.select(np.argsort(lens))
|
||||
|
||||
if len(existing_ids) > 0:
|
||||
dataset = dataset.filter(
|
||||
lambda x: x["instance_id"] not in existing_ids,
|
||||
desc="Filtering out existing ids",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
if len(SCIKIT_LEARN_IDS) > 0:
|
||||
dataset = dataset.filter(
|
||||
lambda x: x["instance_id"] in SCIKIT_LEARN_IDS,
|
||||
desc="Filtering out subset_instance_ids",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
inference_args = {
|
||||
"test_dataset": dataset,
|
||||
"model_name_or_path": model_name_or_path,
|
||||
"output_file": output_file,
|
||||
"existing_ids": existing_ids,
|
||||
"use_reflection": use_reflection,
|
||||
}
|
||||
if model_name_or_path.startswith("gpt"):
|
||||
await openai_inference(**inference_args)
|
||||
else:
|
||||
raise ValueError(f"Invalid model name or path {model_name_or_path}")
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
3
swe_bench/__init__.py
Normal file
3
swe_bench/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
99
swe_bench/gitagent.py
Normal file
99
swe_bench/gitagent.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from typing import Literal, Union
|
||||
|
||||
from metagpt.actions.di.ask_review import ReviewConst
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class GitAgent(DataInterpreter):
|
||||
name: str = "Jacky"
|
||||
profile: str = "Solve git issues proficiently"
|
||||
auto_run: bool = True
|
||||
use_plan: bool = True
|
||||
use_reflection: bool = False
|
||||
react_mode: Literal["plan_and_act", "react"] = "react"
|
||||
script_names: Union[str, list[str]] = []
|
||||
instance_id: str = ""
|
||||
|
||||
async def critique(self, result, review_format):
|
||||
review_result = (
|
||||
"Finally, return a boolean value (True or False) to indicate the result of the review. "
|
||||
"Note: If the result is good enough, return True; otherwise, return False."
|
||||
)
|
||||
status = await self.llm.aask(
|
||||
[
|
||||
Message(content=review_format, role="user"),
|
||||
Message(content=result, role="assistant"),
|
||||
Message(content=review_result, role="user"),
|
||||
]
|
||||
)
|
||||
logger.info(status)
|
||||
|
||||
return status
|
||||
|
||||
async def review_patch(self, code):
|
||||
review_format = (
|
||||
"Please ensure that the code {code} and original script {original_script} can fix the issue {memory} in patch format. "
|
||||
"If it is not in patch format, please convert it to patch format."
|
||||
)
|
||||
|
||||
results = []
|
||||
for script in self.script_names:
|
||||
with open(script, "r", encoding="utf-8") as fp:
|
||||
original_script = fp.read()
|
||||
|
||||
memory = self.get_memories()[0].content
|
||||
review_prompt = review_format.format(code=code, original_script=original_script, memory=memory)
|
||||
# todo: extract issue and remove image urls
|
||||
result = await self.llm.aask(review_prompt)
|
||||
|
||||
results.append(result)
|
||||
# fixme: merge results to a single patch file
|
||||
result = "\n".join(results)
|
||||
|
||||
return result, review_prompt
|
||||
|
||||
async def _write_and_exec_code(self, max_retry: int = 3):
|
||||
counter = 0
|
||||
success = False
|
||||
|
||||
# plan info
|
||||
plan_status = self.planner.get_plan_status() if self.use_plan else ""
|
||||
|
||||
# tool info
|
||||
if self.tools:
|
||||
context = (
|
||||
self.working_memory.get()[-1].content if self.working_memory.get() else ""
|
||||
) # thoughts from _think stage in 'react' mode
|
||||
plan = self.planner.plan if self.use_plan else None
|
||||
tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
|
||||
else:
|
||||
tool_info = ""
|
||||
|
||||
while not success and counter < max_retry:
|
||||
### write code ###
|
||||
code, cause_by = await self._write_code(counter, plan_status, tool_info)
|
||||
|
||||
self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
|
||||
|
||||
result, format_prompt = await self.review_patch(code)
|
||||
|
||||
success = await self.critique(result, format_prompt)
|
||||
await self.execute_code.run(code)
|
||||
### execute code ###
|
||||
# todo: execute: git apply
|
||||
|
||||
### process execution result ###
|
||||
counter += 1
|
||||
|
||||
if not success and counter >= max_retry:
|
||||
logger.info("coding failed!")
|
||||
review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
|
||||
if ReviewConst.CHANGE_WORDS[0] in review:
|
||||
counter = 0 # redo the task again with help of human suggestions
|
||||
|
||||
return code, result, success
|
||||
3
swe_bench/inference/__init__.py
Normal file
3
swe_bench/inference/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
|
@ -8,7 +8,8 @@ original_argv = sys.argv.copy()
|
|||
|
||||
try:
|
||||
# 设置你想要传递给脚本的命令行参数
|
||||
sys.argv = ["run_api.py", "--dataset_name_or_path", "princeton-nlp/SWE-bench_oracle", "--output_dir", "./outputs"]
|
||||
dataset_path = "SWE-bench_oracle" # "SWE-bench_bm25_27K" # "SWE-bench_13k"
|
||||
sys.argv = ["run_api.py", "--dataset_name_or_path", f"princeton-nlp/{dataset_path}", "--output_dir", "./outputs"]
|
||||
# 执行脚本
|
||||
runpy.run_path(path_name="run_api.py", run_name="__main__")
|
||||
finally:
|
||||
74
swe_bench/inference/run_agent.py
Normal file
74
swe_bench/inference/run_agent.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import re
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.recovery_util import save_history
|
||||
from swe_bench.gitagent import GitAgent
|
||||
from swe_bench.make_datasets.make_dataset import reset_task_env
|
||||
from swe_bench.utils.utils import extract_scripts_from_codetext
|
||||
|
||||
PATCH_FORMAT = """
|
||||
```diff
|
||||
--- original_file.py
|
||||
+++ modified_file.py
|
||||
@@ -line_number,context_lines +line_number,context_lines @@
|
||||
- original line of code to be replaced or removed
|
||||
+ new line of code to be added or to replace the original
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def _prepare(inputs):
|
||||
requirement = "Please rewrite the code to address the issues. "
|
||||
system_messages = inputs.split("\n", 1)[0]
|
||||
user_message = inputs.split("\n", 1)[1]
|
||||
cleaned_user_message = re.sub("<patch>.*?</patch>", "", user_message, flags=re.DOTALL)
|
||||
|
||||
issues = re.findall("<issue>(.*?)</issue>", user_message, flags=re.DOTALL)
|
||||
|
||||
return requirement, system_messages, cleaned_user_message, issues
|
||||
|
||||
|
||||
def construct_prompt(inputs, script_names):
|
||||
prompt = (
|
||||
f"You only need to modify the code file listed here {script_names}."
|
||||
f"Notice: "
|
||||
f"1. Analysis the issue, especially for the ValueError, and identify influence code lines.\n"
|
||||
f"2. Only change a few lines, and make sure I can use git diff and git apply to resolve the issue .\n"
|
||||
f"3. I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply.\n"
|
||||
f"4. use the format as : {PATCH_FORMAT}"
|
||||
)
|
||||
|
||||
requirement, system_messages, cleaned_user_message, issues = _prepare(inputs)
|
||||
return requirement, system_messages, cleaned_user_message, issues, prompt
|
||||
|
||||
|
||||
@handle_exception(exception_type=Exception)
|
||||
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
|
||||
async def run_agent(inputs, agent, **kwargs):
|
||||
script_names = kwargs.get("script_names", [])
|
||||
requirement, system_messages, cleaned_user_message, issues, prompt = construct_prompt(inputs, script_names)
|
||||
system_messages = system_messages.replace(" ", "")
|
||||
cleaned_user_message = cleaned_user_message.replace(" ", "")
|
||||
await agent.run([requirement, system_messages, cleaned_user_message, prompt])
|
||||
return agent.get_last_cell_source()
|
||||
|
||||
|
||||
async def run_instance(instance, use_reflection=True):
|
||||
ga = GitAgent(use_reflection=use_reflection)
|
||||
script_names = extract_scripts_from_codetext(instance["text"])
|
||||
ga.script_names = script_names
|
||||
|
||||
patch, repo, repo_path = reset_task_env(instance)
|
||||
if repo_path is None:
|
||||
return
|
||||
|
||||
response = await run_agent(f"{instance['text']}\n\n", agent=ga, script_names=script_names)
|
||||
logger.info(f"Final response: {response}")
|
||||
save_history(ga)
|
||||
return response
|
||||
112
swe_bench/inference/run_api.py
Normal file
112
swe_bench/inference/run_api.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
from data.load_dataset import load_oracle_dataset
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils import count_string_tokens
|
||||
from swe_bench.inference.run_agent import run_instance
|
||||
from swe_bench.utils.utils import check_existing_ids, extract_diff
|
||||
|
||||
# Replace with your own
|
||||
MAX_TOKEN = 128000
|
||||
|
||||
|
||||
async def openai_inference(
|
||||
test_dataset,
|
||||
model_name_or_path,
|
||||
output_file,
|
||||
existing_ids,
|
||||
use_reflection,
|
||||
):
|
||||
"""
|
||||
Runs inference on a dataset using the openai API.
|
||||
|
||||
Args:
|
||||
test_dataset (datasets.Dataset): The dataset to run inference on.
|
||||
model_name_or_path (str): The name or path of the model to use.
|
||||
output_file (str): The path to the output file.
|
||||
existing_ids (set): A set of ids that have already been processed.
|
||||
"""
|
||||
test_dataset = test_dataset.filter(
|
||||
lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
|
||||
desc="Filtering",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
basic_args = {
|
||||
"model_name_or_path": model_name_or_path,
|
||||
}
|
||||
logger.info(f"Filtered to {len(test_dataset)} instances")
|
||||
data = []
|
||||
with open(output_file, "a+") as f:
|
||||
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
|
||||
instance_id = datum["instance_id"]
|
||||
|
||||
if instance_id in existing_ids:
|
||||
continue
|
||||
version = datum["version"]
|
||||
repo = datum["repo"]
|
||||
repo_prefix = repo.replace("/", "__")
|
||||
output_dict = {"instance_id": instance_id}
|
||||
output_dict.update(basic_args)
|
||||
output_dict["text"] = f"{datum['text']}\n\n"
|
||||
logger.info(f"{repo_prefix}_{version}")
|
||||
data.append(f"{repo_prefix}_{version}")
|
||||
|
||||
response = await run_instance(instance=datum)
|
||||
if response is None:
|
||||
continue
|
||||
logger.info(f"Final response: {response}")
|
||||
|
||||
output_dict["full_output"] = response
|
||||
output_dict["model_patch"] = extract_diff(response)
|
||||
print(json.dumps(output_dict), file=f, flush=True)
|
||||
|
||||
|
||||
async def main(
|
||||
dataset_name_or_path,
|
||||
split="test",
|
||||
model_name_or_path=config.llm.model,
|
||||
output_dir="outputs",
|
||||
use_reflection=True,
|
||||
):
|
||||
"""
|
||||
Performs inference on SWE-bench dataset using the Data Interpreter.
|
||||
|
||||
Args:
|
||||
dataset_name_or_path: HuggingFace dataset name or local path
|
||||
split: Dataset split to use (default: test)
|
||||
model_name_or_path: Name of the model to use (default: config.llm.model)
|
||||
param output_dir: Path to the output directory (default: outputs)
|
||||
"""
|
||||
model_nickname = Path(model_name_or_path).name if isinstance(model_name_or_path, Path) else model_name_or_path
|
||||
output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
|
||||
output_file = Path(output_dir, output_file + ".jsonl")
|
||||
print(output_file.absolute())
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Will write to {output_file}")
|
||||
|
||||
# check existing results
|
||||
existing_ids = check_existing_ids(output_file)
|
||||
# load dataset
|
||||
dataset = load_oracle_dataset(dataset_name_or_path)
|
||||
|
||||
inference_args = {
|
||||
"test_dataset": dataset,
|
||||
"model_name_or_path": model_name_or_path,
|
||||
"output_file": output_file,
|
||||
"existing_ids": existing_ids,
|
||||
"use_reflection": use_reflection,
|
||||
}
|
||||
if model_name_or_path.startswith("gpt"):
|
||||
await openai_inference(**inference_args)
|
||||
else:
|
||||
raise ValueError(f"Invalid model name or path {model_name_or_path}")
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
3
swe_bench/make_datasets/__init__.py
Normal file
3
swe_bench/make_datasets/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
61
swe_bench/make_datasets/make_dataset.py
Normal file
61
swe_bench/make_datasets/make_dataset.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from data.inference.const import TESTBED
|
||||
from metagpt.logs import logger
|
||||
from swe_bench.make_datasets.make_instance import prompt_style_2_edits_only
|
||||
from swe_bench.utils.parse_diff import filter_changed_line
|
||||
from swe_bench.utils.repo_utils import EnvManager
|
||||
|
||||
|
||||
def reset_task_env(instance: dict = {}):
|
||||
# reset the env via git reset and git checkout
|
||||
env_manager = EnvManager(testbed=TESTBED)
|
||||
|
||||
patch = instance["patch"]
|
||||
repo = instance["repo"]
|
||||
instance["version"]
|
||||
repo_prefix = repo.replace("/", "__")
|
||||
repo_path = os.path.join(env_manager.testbed, repo_prefix)
|
||||
|
||||
if not os.path.exists(repo_path):
|
||||
return patch, repo, None
|
||||
os.chdir(repo_path)
|
||||
if not env_manager.reset_task_env(instance=instance):
|
||||
return patch, repo, None
|
||||
|
||||
return patch, repo, repo_path
|
||||
|
||||
|
||||
def reset_and_copy(instance: dict = {}):
|
||||
patch, repo, repo_path = reset_task_env(instance)
|
||||
if repo_path is None:
|
||||
return
|
||||
env_manager = EnvManager(testbed=TESTBED)
|
||||
repo_prefix = repo.replace("/", "__")
|
||||
version = instance["version"]
|
||||
destination_path = os.path.join(repo_path, f"{repo_prefix}__{version}")
|
||||
env_manager.copy_repo(source_path=repo_path, destination_path=destination_path)
|
||||
|
||||
|
||||
def make_oracle_collapsed_instance(instance):
|
||||
# for each instance, reset task env
|
||||
patch, repo, repo_path = reset_task_env(instance)
|
||||
if repo_path is None:
|
||||
return
|
||||
file_changes = filter_changed_line(patch)
|
||||
prompt = prompt_style_2_edits_only(instance, Path(repo_path), list(file_changes.keys()))
|
||||
logger.info(prompt)
|
||||
# todo: save output
|
||||
return {}
|
||||
|
||||
|
||||
def make_oracle_collapsed_dataset(dataset):
|
||||
for datum in tqdm(dataset, desc="Inference "):
|
||||
make_oracle_collapsed_instance(instance=datum)
|
||||
# todo: save output
|
||||
193
swe_bench/make_datasets/make_instance.py
Normal file
193
swe_bench/make_datasets/make_instance.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
from pathlib import Path
|
||||
|
||||
import unidiff
|
||||
|
||||
PATCH_EXAMPLE = """--- a/file.py
|
||||
+++ b/file.py
|
||||
@@ -1,27 +1,35 @@
|
||||
def euclidean(a, b):
|
||||
- while b:
|
||||
- a, b = b, a % b
|
||||
- return a
|
||||
+ if b == 0:
|
||||
+ return a
|
||||
+ return euclidean(b, a % b)
|
||||
|
||||
|
||||
def bresenham(x0, y0, x1, y1):
|
||||
points = []
|
||||
dx = abs(x1 - x0)
|
||||
dy = abs(y1 - y0)
|
||||
- sx = 1 if x0 < x1 else -1
|
||||
- sy = 1 if y0 < y1 else -1
|
||||
- err = dx - dy
|
||||
+ x, y = x0, y0
|
||||
+ sx = -1 if x0 > x1 else 1
|
||||
+ sy = -1 if y0 > y1 else 1
|
||||
|
||||
- while True:
|
||||
- points.append((x0, y0))
|
||||
- if x0 == x1 and y0 == y1:
|
||||
- break
|
||||
- e2 = 2 * err
|
||||
- if e2 > -dy:
|
||||
+ if dx > dy:
|
||||
+ err = dx / 2.0
|
||||
+ while x != x1:
|
||||
+ points.append((x, y))
|
||||
err -= dy
|
||||
- x0 += sx
|
||||
- if e2 < dx:
|
||||
- err += dx
|
||||
- y0 += sy
|
||||
+ if err < 0:
|
||||
+ y += sy
|
||||
+ err += dx
|
||||
+ x += sx
|
||||
+ else:
|
||||
+ err = dy / 2.0
|
||||
+ while y != y1:
|
||||
+ points.append((x, y))
|
||||
+ err -= dx
|
||||
+ if err < 0:
|
||||
+ x += sx
|
||||
+ err += dy
|
||||
+ y += sy
|
||||
|
||||
+ points.append((x, y))
|
||||
return points"""
|
||||
|
||||
FULL_GENERATION_EXAMPLE = """[start of /src/this_file.py]
|
||||
import os
|
||||
|
||||
def euclidean(a, b):
|
||||
if b == 0:
|
||||
return a
|
||||
return euclidean(b, a % b)
|
||||
[end of /src/this_file.py]
|
||||
[start of /src/another_file.py]
|
||||
def bresenham(x0, y0, x1, y1):
|
||||
points = []
|
||||
dx = abs(x1 - x0)
|
||||
dy = abs(y1 - y0)
|
||||
x, y = x0, y0
|
||||
sx = -1 if x0 > x1 else 1
|
||||
sy = -1 if y0 > y1 else 1
|
||||
if dx > dy:
|
||||
err = dx / 2.0
|
||||
while x != x1:
|
||||
points.append((x, y))
|
||||
err -= dy
|
||||
if err < 0:
|
||||
y += sy
|
||||
err += dx
|
||||
x += sx
|
||||
else:
|
||||
err = dy / 2.0
|
||||
while y != y1:
|
||||
points.append((x
|
||||
err -= dx
|
||||
if err < 0:
|
||||
x += sx
|
||||
err += dy
|
||||
y += sy
|
||||
points.append((x, y))
|
||||
return points
|
||||
[end of /src/another_file.py]"""
|
||||
|
||||
|
||||
def add_lines_list(content):
|
||||
content_with_lines = list()
|
||||
for ix, line in enumerate(content.split("\n"), start=1):
|
||||
content_with_lines.append(f"{ix} {line}")
|
||||
return content_with_lines
|
||||
|
||||
|
||||
def add_lines(content):
|
||||
return "\n".join(add_lines_list(content))
|
||||
|
||||
|
||||
def make_code_text(files_dict, add_line_numbers=True):
|
||||
all_text = ""
|
||||
for filename, contents in sorted(files_dict.items()):
|
||||
all_text += f"[start of {filename}]\n"
|
||||
if add_line_numbers:
|
||||
all_text += add_lines(contents)
|
||||
else:
|
||||
all_text += contents
|
||||
all_text += f"\n[end of {filename}]\n"
|
||||
return all_text.strip("\n")
|
||||
|
||||
|
||||
def make_code_text_edits_only(files_dict, patch, root_path, add_line_numbers=True):
|
||||
files = dict()
|
||||
patch = unidiff.PatchSet(patch)
|
||||
for patched_file in patch:
|
||||
source_file = root_path / patched_file.source_file.split("a/", 1)[-1]
|
||||
files[source_file] = list()
|
||||
for hunk in patched_file:
|
||||
start = hunk.source_start - 15
|
||||
end = start + hunk.source_length + 15
|
||||
files[source_file].append((start, end))
|
||||
all_text = ""
|
||||
for filename, content in files_dict.items():
|
||||
# filename = str(filename)
|
||||
all_text += f"[start of {filename}]\n"
|
||||
content_with_lines = add_lines_list(content)
|
||||
for start, end in files[filename]:
|
||||
if start > 0:
|
||||
all_text += "...\n"
|
||||
all_text += "\n".join(content_with_lines[start:end])
|
||||
all_text += "\n"
|
||||
if end < len(content_with_lines):
|
||||
all_text += "...\n"
|
||||
all_text = all_text.strip("\n")
|
||||
all_text += f"\n[end of {filename}]\n"
|
||||
return all_text.strip("\n")
|
||||
|
||||
|
||||
def prompt_style_2_edits_only(instance, root_path, filenames):
|
||||
premise = "You will be provided with a partial code base and an issue statement explaining a problem to resolve."
|
||||
|
||||
readmes = get_readme_files(root_path)
|
||||
instance["readmes"] = ingest_files([root_path / readme for readme in readmes])
|
||||
|
||||
readmes_text = make_code_text(instance["readmes"])
|
||||
instance["file_contents"] = ingest_files([root_path / filename for filename in filenames])
|
||||
code_text = make_code_text_edits_only(instance["file_contents"], instance["patch"], root_path)
|
||||
instructions = (
|
||||
"I need you to solve this issue by generating a single patch file that I can apply "
|
||||
+ "directly to this repository using git apply. Please respond with a single patch "
|
||||
+ "file in the following format."
|
||||
)
|
||||
problem_statement = instance["problem_statement"]
|
||||
final_text = [
|
||||
premise,
|
||||
"<issue>",
|
||||
problem_statement,
|
||||
"</issue>",
|
||||
"<code>",
|
||||
readmes_text,
|
||||
code_text,
|
||||
"</code>",
|
||||
instructions,
|
||||
"<patch>",
|
||||
PATCH_EXAMPLE,
|
||||
"</patch>",
|
||||
]
|
||||
final_text = "\n".join(final_text)
|
||||
return final_text
|
||||
|
||||
|
||||
def ingest_files(file_paths):
|
||||
files_dict = dict()
|
||||
for file_path in file_paths:
|
||||
files_dict[file_path] = Path.read_text(file_path, encoding="utf-8")
|
||||
return files_dict
|
||||
|
||||
|
||||
def get_readme_files(repo_path):
|
||||
path = Path(repo_path)
|
||||
# 检查文件名是否以 "readme" 开头,不区分大小写
|
||||
files = [file.name for file in path.iterdir() if file.is_file() and file.name.lower().startswith("readme")]
|
||||
return files
|
||||
3
swe_bench/utils/__init__.py
Normal file
3
swe_bench/utils/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
81
swe_bench/utils/utils.py
Normal file
81
swe_bench/utils/utils.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def check_existing_ids(output_file):
|
||||
existing_ids = set()
|
||||
if os.path.exists(output_file):
|
||||
with open(output_file, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
instance_id = data["instance_id"]
|
||||
existing_ids.add(instance_id)
|
||||
logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")
|
||||
return existing_ids
|
||||
|
||||
|
||||
def extract_diff(response):
|
||||
"""
|
||||
Extracts the diff from a response formatted in different ways
|
||||
"""
|
||||
if response is None:
|
||||
return None
|
||||
diff_matches = []
|
||||
other_matches = []
|
||||
pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
|
||||
for code, match in pattern.findall(response):
|
||||
if code in {"diff", "patch"}:
|
||||
diff_matches.append(match)
|
||||
else:
|
||||
other_matches.append(match)
|
||||
pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
|
||||
for code, match in pattern.findall(response):
|
||||
if code in {"diff", "patch"}:
|
||||
diff_matches.append(match)
|
||||
else:
|
||||
other_matches.append(match)
|
||||
if diff_matches:
|
||||
return diff_matches[0]
|
||||
if other_matches:
|
||||
return other_matches[0]
|
||||
return response.split("</s>")[0]
|
||||
|
||||
|
||||
def extract_scripts_from_codetext(codetext: str):
|
||||
"""
|
||||
Extracts Python script file names from a given text that contains multiple sections.
|
||||
Each section starts with '[start of <script_name>.py]' and ends with '[end of <script_name>.py]'.
|
||||
|
||||
Parameters:
|
||||
- codetext (str): A string that may contain multiple sections, each indicating the start of a Python script file.
|
||||
|
||||
Returns:
|
||||
- list: A list of extracted Python script file names.
|
||||
|
||||
Example of codetext:
|
||||
'''
|
||||
[end of README.rst]
|
||||
[start of sklearn/compose/_target.py]
|
||||
... file content ...
|
||||
[end of sklearn/compose/_target.py]
|
||||
[start of another_module/example.py]
|
||||
... file content ...
|
||||
[end of another_module/example.py]
|
||||
'''
|
||||
"""
|
||||
script_names = []
|
||||
|
||||
# Match all occurrences of '[start of <script_name>.py]'
|
||||
matches = re.findall(r"\[start of ([^\]]+\.py)\]", codetext)
|
||||
|
||||
if matches:
|
||||
for script_name in matches:
|
||||
print("Extracted script name:", script_name)
|
||||
script_names.append(script_name)
|
||||
else:
|
||||
print("No script names found in the text.")
|
||||
|
||||
return script_names
|
||||
Loading…
Add table
Add a link
Reference in a new issue