mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 02:23:52 +02:00
Add Inference of SWE-bench for Data Interpreter
This commit is contained in:
parent
e783e5b208
commit
ef3be47f28
4 changed files with 180 additions and 1 deletions
28
data/inference/make_datasets/utils.py
Normal file
28
data/inference/make_datasets/utils.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import re
|
||||
|
||||
|
||||
def extract_diff(response):
|
||||
"""
|
||||
Extracts the diff from a response formatted in different ways
|
||||
"""
|
||||
if response is None:
|
||||
return None
|
||||
diff_matches = []
|
||||
other_matches = []
|
||||
pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
|
||||
for code, match in pattern.findall(response):
|
||||
if code in {"diff", "patch"}:
|
||||
diff_matches.append(match)
|
||||
else:
|
||||
other_matches.append(match)
|
||||
pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
|
||||
for code, match in pattern.findall(response):
|
||||
if code in {"diff", "patch"}:
|
||||
diff_matches.append(match)
|
||||
else:
|
||||
other_matches.append(match)
|
||||
if diff_matches:
|
||||
return diff_matches[0]
|
||||
if other_matches:
|
||||
return other_matches[0]
|
||||
return response.split("</s>")[0]
|
||||
147
data/inference/run_api.py
Normal file
147
data/inference/run_api.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
import json
|
||||
import os
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from make_datasets.utils import extract_diff
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.di.data_interpreter import DataInterpreter
|
||||
from metagpt.utils import count_string_tokens
|
||||
from metagpt.utils.recovery_util import save_history
|
||||
|
||||
# Replace with your own
|
||||
MAX_TOKEN = 128000
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(5))
|
||||
async def call_chat(inputs, interpreter):
|
||||
"""
|
||||
Calls the openai API to generate completions for the given inputs.
|
||||
|
||||
Args:
|
||||
inputs (str): The inputs to generate completions for.
|
||||
interpreter (DataInterpreter): The data interpreter to use for execution.
|
||||
"""
|
||||
system_messages = inputs.split("\n", 1)[0]
|
||||
user_message = inputs.split("\n", 1)[1]
|
||||
try:
|
||||
await interpreter.run([system_messages, user_message])
|
||||
return interpreter.get_last_cell_source
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}\nInputs: {inputs}")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
|
||||
async def openai_inference(
|
||||
test_dataset,
|
||||
model_name_or_path,
|
||||
output_file,
|
||||
existing_ids,
|
||||
):
|
||||
"""
|
||||
Runs inference on a dataset using the openai API.
|
||||
|
||||
Args:
|
||||
test_dataset (datasets.Dataset): The dataset to run inference on.
|
||||
model_name_or_path (str): The name or path of the model to use.
|
||||
output_file (str): The path to the output file.
|
||||
existing_ids (set): A set of ids that have already been processed.
|
||||
"""
|
||||
test_dataset = test_dataset.filter(
|
||||
lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN,
|
||||
desc="Filtering",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
basic_args = {
|
||||
"model_name_or_path": model_name_or_path,
|
||||
}
|
||||
print(f"Filtered to {len(test_dataset)} instances")
|
||||
with open(output_file, "a+") as f:
|
||||
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
|
||||
di = DataInterpreter()
|
||||
instance_id = datum["instance_id"]
|
||||
if instance_id in existing_ids:
|
||||
continue
|
||||
output_dict = {"instance_id": instance_id}
|
||||
output_dict.update(basic_args)
|
||||
output_dict["text"] = f"{datum['text']}\n\n"
|
||||
response = await call_chat(
|
||||
output_dict["text"],
|
||||
di,
|
||||
)
|
||||
logger.info(f"Final response: {response}")
|
||||
output_dict["full_output"] = response
|
||||
output_dict["model_patch"] = extract_diff(response)
|
||||
print(json.dumps(output_dict), file=f, flush=True)
|
||||
save_history(di)
|
||||
|
||||
|
||||
async def main(
|
||||
dataset_name_or_path,
|
||||
split="test",
|
||||
model_name_or_path=config.llm.model,
|
||||
output_dir="outputs",
|
||||
):
|
||||
"""
|
||||
Performs inference on SWE-bench dataset using the Data Interpreter.
|
||||
|
||||
Args:
|
||||
dataset_name_or_path: HuggingFace dataset name or local path
|
||||
split: Dataset split to use (default: test)
|
||||
model_name_or_path: Name of the model to use (default: config.llm.model)
|
||||
param output_dir: Path to the output directory (default: outputs)
|
||||
"""
|
||||
if config.llm.api_type.value == "azure" and config.llm.model == "gpt-4":
|
||||
# Actual model name is gpt-4-1106-preview for Azure
|
||||
model_nickname = "gpt-4-1106-preview"
|
||||
else:
|
||||
model_nickname = Path(model_name_or_path).name
|
||||
output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
|
||||
output_file = Path(output_dir, output_file + ".jsonl")
|
||||
logger.info(f"Will write to {output_file}")
|
||||
existing_ids = set()
|
||||
if os.path.exists(output_file):
|
||||
with open(output_file, "r") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
instance_id = data["instance_id"]
|
||||
existing_ids.add(instance_id)
|
||||
logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")
|
||||
if Path(dataset_name_or_path).exists():
|
||||
dataset = load_from_disk(dataset_name_or_path)
|
||||
else:
|
||||
dataset = load_dataset(dataset_name_or_path)
|
||||
if split not in dataset:
|
||||
raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
|
||||
dataset = dataset[split]
|
||||
lens = np.array(list(map(len, dataset["text"])))
|
||||
dataset = dataset.select(np.argsort(lens))
|
||||
if len(existing_ids) > 0:
|
||||
dataset = dataset.filter(
|
||||
lambda x: x["instance_id"] not in existing_ids,
|
||||
desc="Filtering out existing ids",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
inference_args = {
|
||||
"test_dataset": dataset,
|
||||
"model_name_or_path": model_name_or_path,
|
||||
"output_file": output_file,
|
||||
"existing_ids": existing_ids,
|
||||
}
|
||||
if model_name_or_path.startswith("gpt"):
|
||||
await openai_inference(**inference_args)
|
||||
else:
|
||||
raise ValueError(f"Invalid model name or path {model_name_or_path}")
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
Loading…
Add table
Add a link
Reference in a new issue