Merge pull request #1048 from stellaHSR/swebench_di

Add basic Git options for data interpreter and target file name extraction utilities
This commit is contained in:
Alexander Wu 2024-03-21 15:40:40 +08:00 committed by GitHub
commit d02ea95abd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 152 additions and 4 deletions

View file

@ -3,10 +3,11 @@
# @Desc :
import pandas as pd
from metagpt.const import METAGPT_ROOT
from metagpt.const import DATA_PATH, METAGPT_ROOT
SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"
TESTBED = DATA_PATH / "repos"
# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
# This collection represents a subset specifically related to scikit-learn content.

View file

@ -0,0 +1,38 @@
import re
def extract_scripts_from_codetext(codetext: str):
"""
Extracts Python script file names from a given text that contains multiple sections.
Each section starts with '[start of <script_name>.py]' and ends with '[end of <script_name>.py]'.
Parameters:
- codetext (str): A string that may contain multiple sections, each indicating the start of a Python script file.
Returns:
- list: A list of extracted Python script file names.
Example of codetext:
'''
[end of README.rst]
[start of sklearn/compose/_target.py]
... file content ...
[end of sklearn/compose/_target.py]
[start of another_module/example.py]
... file content ...
[end of another_module/example.py]
'''
"""
script_names = []
# Match all occurrences of '[start of <script_name>.py]'
matches = re.findall(r"\[start of ([^\]]+\.py)\]", codetext)
if matches:
for script_name in matches:
print("Extracted script name:", script_name)
script_names.append(script_name)
else:
print("No script names found in the text.")
return script_names

View file

@ -0,0 +1,93 @@
import os
import shutil
import subprocess
from pathlib import Path
from typing import Dict
import git
from git.exc import GitError
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
KEY_INSTANCE_ID = "instance_id"
RESET_FAILED = ">>>>> Reset Failed"
class ExecWrapper:
def __init__(self, subprocess_args: Dict = None):
self.subprocess_args = subprocess_args or {}
@handle_exception(exception_type=subprocess.CalledProcessError)
def __call__(self, cmd, raise_error=True, **kwargs):
combined_args = {**self.subprocess_args, **kwargs}
output = subprocess.run(cmd, **combined_args)
return output
class EnvManager:
def __init__(self, testbed):
shellenv = os.environ.copy()
self.testbed = testbed
self.exec = ExecWrapper(
subprocess_args={
"check": True,
"shell": False,
"capture_output": True,
"text": True,
"env": shellenv,
}
)
@handle_exception(exception_type=GitError)
def clone_repo(self, repo_name: str, path: str, token: str = None):
if token is None:
token = os.environ.get("GITHUB_TOKEN", "git")
if not token:
raise ValueError("GitHub token is required for cloning repositories.")
repo_url = f"https://{token}@github.com/swe-bench/{repo_name.replace('/', '__')}.git"
os.makedirs(path, exist_ok=True)
# Clone the repository
git.Repo.clone_from(repo_url, path)
logger.info(f"Repository '{repo_name}' cloned successfully.")
@handle_exception(exception_type=Exception) # Using a broad exception type for the example
def copy_repo(self, source_path: str, destination_path: str):
if not os.path.isdir(source_path):
raise ValueError("Source path does not exist or is not a directory.")
os.makedirs(destination_path, exist_ok=True)
# Copy the repository
try:
shutil.copytree(
source_path, destination_path, dirs_exist_ok=True
) # For Python 3.8+, dirs_exist_ok handles existing directories
except TypeError:
# Fallback for Python < 3.8, where dirs_exist_ok is not available
if os.listdir(destination_path): # If destination is not empty
raise ValueError("Destination directory is not empty and dirs_exist_ok is not supported.")
shutil.copytree(source_path, destination_path)
logger.info(f"Repository contents from '{source_path}' copied successfully to '{destination_path}'.")
@handle_exception(exception_type=Exception, default_return=False)
def reset_task_env(self, instance: Dict):
"""
Reset task environment + testbed and checkout base commit of given task instance
"""
gitignore_path = Path(".gitignore")
if gitignore_path.exists():
self.exec(["git", "ls-files", "--ignored", "--exclude-standard", "-o", "-z"], raise_error=False)
# fixme: need detect platform and change this cmd
# self.exec(["xargs", "-0", "-r", "rm", "-rf"], input=gitignore_path.read_text())
self.exec(["git", "restore", "."])
self.exec(["git", "reset", "HEAD", "."])
self.exec(["git", "clean", "-fdx"])
self.exec(["git", "-c", "advice.detachedHead=false", "checkout", instance["base_commit"]])
logger.info(f"[{instance['instance_id']}] Reset task environment to {instance['base_commit']}")
return True

View file

@ -10,7 +10,9 @@ from make_datasets.utils import extract_diff
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm
from data.inference.const import SCIKIT_LEARN_IDS
from data.inference.const import SCIKIT_LEARN_IDS, TESTBED
from data.inference.make_datasets.parse_diff import extract_scripts_from_codetext
from data.inference.make_datasets.repo_utils import EnvManager
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
@ -33,8 +35,12 @@ async def call_chat(inputs, interpreter):
requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code."
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
cleaned_user_message = user_message.split(
"I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply. Please respond with a single patch file in the following format."
)[0]
try:
await interpreter.run([requirement, system_messages, user_message])
await interpreter.run([requirement, system_messages, cleaned_user_message])
return interpreter.get_last_cell_source()
except Exception as e:
logger.error(f"Error: {e}\nInputs: {inputs}")
@ -70,8 +76,18 @@ async def openai_inference(
with open(output_file, "a+") as f:
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
di = DataInterpreter(use_reflection=use_reflection)
instance_id = datum["instance_id"]
env_manager = EnvManager(testbed=TESTBED)
instance_id = datum["instance_id"]
script_names = extract_scripts_from_codetext(datum["text"])
logger.info(script_names)
repo = datum["repo"]
repo_prefix = repo.replace("/", "__")
repo_path = os.path.join(env_manager.testbed, repo_prefix)
if not os.path.exists(repo_path):
continue
os.chdir(repo_path)
env_manager.reset_task_env(instance=datum)
if instance_id in existing_ids:
continue
output_dict = {"instance_id": instance_id}