diff --git a/data/inference/const.py b/data/inference/const.py
index 42e63ec0e..62d96fe69 100644
--- a/data/inference/const.py
+++ b/data/inference/const.py
@@ -3,10 +3,11 @@
 # @Desc :
 import pandas as pd
 
-from metagpt.const import METAGPT_ROOT
+from metagpt.const import DATA_PATH, METAGPT_ROOT
 
 SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
 SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"
+TESTBED = DATA_PATH / "repos"
 
 # SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
 # This collection represents a subset specifically related to scikit-learn content.
diff --git a/data/inference/run_api.py b/data/inference/run_api.py
index 7882f13e7..00dbdd0e0 100644
--- a/data/inference/run_api.py
+++ b/data/inference/run_api.py
@@ -10,7 +10,9 @@ from make_datasets.utils import extract_diff
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 from tqdm.auto import tqdm
 
-from data.inference.const import SCIKIT_LEARN_IDS
+from data.inference.const import SCIKIT_LEARN_IDS, TESTBED
+from data.inference.make_datasets.parse_diff import extract_scripts_from_codetext
+from data.inference.make_datasets.repo_utils import EnvManager
 from metagpt.config2 import config
 from metagpt.logs import logger
 from metagpt.roles.di.data_interpreter import DataInterpreter
@@ -33,8 +35,12 @@ async def call_chat(inputs, interpreter):
     requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code."
     system_messages = inputs.split("\n", 1)[0]
     user_message = inputs.split("\n", 1)[1]
+    cleaned_user_message = user_message.split(
+        "I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply. Please respond with a single patch file in the following format."
+    )[0]
+
     try:
-        await interpreter.run([requirement, system_messages, user_message])
+        await interpreter.run([requirement, system_messages, cleaned_user_message])
         return interpreter.get_last_cell_source()
     except Exception as e:
         logger.error(f"Error: {e}\nInputs: {inputs}")
@@ -70,8 +76,18 @@ async def openai_inference(
     with open(output_file, "a+") as f:
         for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
             di = DataInterpreter(use_reflection=use_reflection)
 
-            instance_id = datum["instance_id"]
+            env_manager = EnvManager(testbed=TESTBED)
+            instance_id = datum["instance_id"]
+            script_names = extract_scripts_from_codetext(datum["text"])
+            logger.info(script_names)
+            repo = datum["repo"]
+            repo_prefix = repo.replace("/", "__")
+            repo_path = os.path.join(env_manager.testbed, repo_prefix)
+            if not os.path.exists(repo_path):
+                continue
+            os.chdir(repo_path)
+            env_manager.reset_task_env(instance=datum)
             if instance_id in existing_ids:
                 continue
             output_dict = {"instance_id": instance_id}