1. add testbed path

2. update repo parse and git ops
This commit is contained in:
stellahsr 2024-03-19 23:19:58 +08:00
parent 9a7279bf91
commit 631a2642e8
2 changed files with 21 additions and 4 deletions

View file

@ -3,10 +3,11 @@
# @Desc :
import pandas as pd
from metagpt.const import METAGPT_ROOT
from metagpt.const import DATA_PATH, METAGPT_ROOT
SUBSET_DATASET = METAGPT_ROOT / "sub_swebench_dataset" / "sub_swebench.csv"
SUBSET_DATASET_SKLERARN = METAGPT_ROOT / "sub_swebench_dataset" / "scikit-learn-68.csv"
TESTBED = DATA_PATH / "repos"
# SCIKIT_LEARN_IDS: A list of instance identifiers from 'sub_swebench.csv' within SUBSET_DATASET.
# This collection represents a subset specifically related to scikit-learn content.

View file

@ -10,7 +10,9 @@ from make_datasets.utils import extract_diff
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm
from data.inference.const import SCIKIT_LEARN_IDS
from data.inference.const import SCIKIT_LEARN_IDS, TESTBED
from data.inference.make_datasets.parse_diff import extract_scripts_from_codetext
from data.inference.make_datasets.repo_utils import EnvManager
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
@ -33,8 +35,12 @@ async def call_chat(inputs, interpreter):
requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code."
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
cleaned_user_message = user_message.split(
"I need you to solve this issue by generating a single patch file that I can apply directly to this repository using git apply. Please respond with a single patch file in the following format."
)[0]
try:
await interpreter.run([requirement, system_messages, user_message])
await interpreter.run([requirement, system_messages, cleaned_user_message])
return interpreter.get_last_cell_source()
except Exception as e:
logger.error(f"Error: {e}\nInputs: {inputs}")
@ -70,8 +76,18 @@ async def openai_inference(
with open(output_file, "a+") as f:
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
di = DataInterpreter(use_reflection=use_reflection)
instance_id = datum["instance_id"]
env_manager = EnvManager(testbed=TESTBED)
instance_id = datum["instance_id"]
script_names = extract_scripts_from_codetext(datum["text"])
logger.info(script_names)
repo = datum["repo"]
repo_prefix = repo.replace("/", "__")
repo_path = os.path.join(env_manager.testbed, repo_prefix)
if not os.path.exists(repo_path):
continue
os.chdir(repo_path)
env_manager.reset_task_env(instance=datum)
if instance_id in existing_ids:
continue
output_dict = {"instance_id": instance_id}