Add SWE-bench inference for Data Interpreter

mannaandpoem 2024-03-18 14:30:53 +08:00
parent e783e5b208
commit ef3be47f28
4 changed files with 180 additions and 1 deletion

data/inference/make_datasets/utils.py Normal file

@@ -0,0 +1,28 @@
import re


def extract_diff(response):
    """
    Extracts the diff from a response formatted in different ways.
    """
    if response is None:
        return None
    diff_matches = []
    other_matches = []
    # First look for XML-style tags such as <diff>...</diff> or <patch>...</patch>
    pattern = re.compile(r"<([\w-]+)>(.*?)</\1>", re.DOTALL)
    for code, match in pattern.findall(response):
        if code in {"diff", "patch"}:
            diff_matches.append(match)
        else:
            other_matches.append(match)
    # Then look for fenced code blocks, optionally tagged with a language
    pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
    for code, match in pattern.findall(response):
        if code in {"diff", "patch"}:
            diff_matches.append(match)
        else:
            other_matches.append(match)
    # Prefer blocks explicitly marked as diff/patch, then any other block
    if diff_matches:
        return diff_matches[0]
    if other_matches:
        return other_matches[0]
    # Fall back to the raw response, truncated at the end-of-sequence token
    return response.split("</s>")[0]
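
For illustration, here is a quick check of the tag-based branch; the response string is hypothetical and not part of the commit:

response = "<patch>--- a/foo.py\n+++ b/foo.py\n@@ -1 +1 @@\n-x = 1\n+x = 2\n</patch>"
patch = extract_diff(response)
# patch == "--- a/foo.py\n+++ b/foo.py\n@@ -1 +1 @@\n-x = 1\n+x = 2\n"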

data/inference/run_api.py Normal file

@@ -0,0 +1,147 @@
import json
import os
import traceback
from pathlib import Path

import fire
import numpy as np
from datasets import load_dataset, load_from_disk
from make_datasets.utils import extract_diff
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm

from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.utils import count_string_tokens
from metagpt.utils.recovery_util import save_history

# Replace with your own model's context window size
MAX_TOKEN = 128000
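
MAX_TOKEN caps the prompt size: any instance whose text exceeds it is filtered out before inference (see openai_inference below). A minimal sketch of the same check, assuming a gpt-4 tokenizer name:

text = "SWE-bench issue description and code context ..."
if count_string_tokens(text, "gpt-4") <= MAX_TOKEN:
    pass  # instance is small enough to attempt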
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
async def call_chat(inputs, interpreter):
    """
    Runs the DataInterpreter on the given inputs and returns the source of
    the last generated cell, retrying with exponential backoff on failure.

    Args:
        inputs (str): The inputs to generate completions for; the first line
            is the system message, the remainder is the user message.
        interpreter (DataInterpreter): The data interpreter to use for execution.
    """
    system_messages, user_message = inputs.split("\n", 1)
    try:
        await interpreter.run([system_messages, user_message])
        return interpreter.get_last_cell_source
    except Exception as e:
        logger.error(f"Error: {e}\nInputs: {inputs}")
        traceback.print_exc()
        raise e
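
call_chat assumes the SWE-bench prompt packs both roles into one string: the first line is the system message and everything after it is the user message. A hypothetical illustration of that convention:

inputs = "You are a software engineer.\nFix the failing test described below ..."
system_messages, user_message = inputs.split("\n", 1)
# system_messages == "You are a software engineer."
# user_message    == "Fix the failing test described below ..."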
async def openai_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    existing_ids,
):
    """
    Runs inference on a dataset using the OpenAI API.

    Args:
        test_dataset (datasets.Dataset): The dataset to run inference on.
        model_name_or_path (str): The name or path of the model to use.
        output_file (str): The path to the output file.
        existing_ids (set): A set of ids that have already been processed.
    """
    # Drop instances whose prompt would not fit in the model's context window
    test_dataset = test_dataset.filter(
        lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
        desc="Filtering",
        load_from_cache_file=False,
    )
    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    print(f"Filtered to {len(test_dataset)} instances")
    with open(output_file, "a+") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            # A fresh interpreter per instance, so runs do not share state
            di = DataInterpreter()
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue
            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            output_dict["text"] = f"{datum['text']}\n\n"
            response = await call_chat(
                output_dict["text"],
                di,
            )
            logger.info(f"Final response: {response}")
            output_dict["full_output"] = response
            output_dict["model_patch"] = extract_diff(response)
            print(json.dumps(output_dict), file=f, flush=True)
            save_history(di)
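
Each processed instance appends one JSON line to the output file, carrying the instance id, the model name, the prompt, the raw model output, and the extracted patch. A hypothetical record (all values invented):

{"instance_id": "astropy__astropy-12907", "model_name_or_path": "gpt-4", "text": "...", "full_output": "...", "model_patch": "--- a/astropy/..."}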
async def main(
    dataset_name_or_path,
    split="test",
    model_name_or_path=config.llm.model,
    output_dir="outputs",
):
    """
    Performs inference on a SWE-bench dataset using the Data Interpreter.

    Args:
        dataset_name_or_path: HuggingFace dataset name or local path.
        split: Dataset split to use (default: test).
        model_name_or_path: Name of the model to use (default: config.llm.model).
        output_dir: Path to the output directory (default: outputs).
    """
    if config.llm.api_type.value == "azure" and config.llm.model == "gpt-4":
        # Actual model name is gpt-4-1106-preview for Azure
        model_nickname = "gpt-4-1106-preview"
    else:
        model_nickname = Path(model_name_or_path).name
    output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
    output_file = Path(output_dir, output_file + ".jsonl")
    logger.info(f"Will write to {output_file}")

    # Resume support: collect the instance ids already present in the output file
    existing_ids = set()
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            for line in f:
                data = json.loads(line)
                instance_id = data["instance_id"]
                existing_ids.add(instance_id)
    logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")

    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
    if split not in dataset:
        raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
    dataset = dataset[split]

    # Process shorter prompts first by sorting instances by text length
    lens = np.array(list(map(len, dataset["text"])))
    dataset = dataset.select(np.argsort(lens))
    if len(existing_ids) > 0:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in existing_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )
    inference_args = {
        "test_dataset": dataset,
        "model_name_or_path": model_name_or_path,
        "output_file": output_file,
        "existing_ids": existing_ids,
    }
    if model_name_or_path.startswith("gpt"):
        await openai_inference(**inference_args)
    else:
        raise ValueError(f"Invalid model name or path {model_name_or_path}")
    logger.info("Done!")
if __name__ == "__main__":
    fire.Fire(main)
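
Since the entry point is wrapped in fire.Fire, the script can be driven from the command line. A hypothetical invocation (the dataset name is an assumption, not taken from the commit):

python data/inference/run_api.py --dataset_name_or_path princeton-nlp/SWE-bench_oracle --split test --output_dir outputs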