diff --git a/data/inference/make_datasets/utils.py b/data/inference/make_datasets/utils.py
new file mode 100644
index 000000000..284f8d976
--- /dev/null
+++ b/data/inference/make_datasets/utils.py
@@ -0,0 +1,28 @@
+import re
+
+
+def extract_diff(response):
+    """
+    Extracts the diff from a response formatted in different ways
+    """
+    if response is None:
+        return None
+    diff_matches = []
+    other_matches = []
+    pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
+    for code, match in pattern.findall(response):
+        if code in {"diff", "patch"}:
+            diff_matches.append(match)
+        else:
+            other_matches.append(match)
+    pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
+    for code, match in pattern.findall(response):
+        if code in {"diff", "patch"}:
+            diff_matches.append(match)
+        else:
+            other_matches.append(match)
+    if diff_matches:
+        return diff_matches[0]
+    if other_matches:
+        return other_matches[0]
+    return response.split("</s>")[0]
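A quick illustration of the fallback order in `extract_diff` (a reviewer sketch, not part of the patch): `<diff>`/`<patch>` tags and fenced `diff`/`patch` blocks win, any other wrapped block comes next, and otherwise the text before the first `</s>` token is returned as-is.

```python
# Illustrative check only; assumes it is run from data/inference/ so that
# make_datasets is importable, mirroring the import in run_api.py below.
from make_datasets.utils import extract_diff

response = (
    "Here is the fix:\n"
    "```diff\n"
    "--- a/foo.py\n"
    "+++ b/foo.py\n"
    "@@ -1 +1 @@\n"
    "-x = 1\n"
    "+x = 2\n"
    "```\n"
)
# The fenced block is labeled "diff", so only its body is returned.
print(extract_diff(response))
```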
+ """ + test_dataset = test_dataset.filter( + lambda x: count_string_tokens(x["text"], model_name_or_path) <= MAX_TOKEN, + desc="Filtering", + load_from_cache_file=False, + ) + basic_args = { + "model_name_or_path": model_name_or_path, + } + print(f"Filtered to {len(test_dataset)} instances") + with open(output_file, "a+") as f: + for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"): + di = DataInterpreter() + instance_id = datum["instance_id"] + if instance_id in existing_ids: + continue + output_dict = {"instance_id": instance_id} + output_dict.update(basic_args) + output_dict["text"] = f"{datum['text']}\n\n" + response = await call_chat( + output_dict["text"], + di, + ) + logger.info(f"Final response: {response}") + output_dict["full_output"] = response + output_dict["model_patch"] = extract_diff(response) + print(json.dumps(output_dict), file=f, flush=True) + save_history(di) + + +async def main( + dataset_name_or_path, + split="test", + model_name_or_path=config.llm.model, + output_dir="outputs", +): + """ + Performs inference on SWE-bench dataset using the Data Interpreter. + + Args: + dataset_name_or_path: HuggingFace dataset name or local path + split: Dataset split to use (default: test) + model_name_or_path: Name of the model to use (default: config.llm.model) + param output_dir: Path to the output directory (default: outputs) + """ + if config.llm.api_type.value == "azure" and config.llm.model == "gpt-4": + # Actual model name is gpt-4-1106-preview for Azure + model_nickname = "gpt-4-1106-preview" + else: + model_nickname = Path(model_name_or_path).name + output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}" + output_file = Path(output_dir, output_file + ".jsonl") + logger.info(f"Will write to {output_file}") + existing_ids = set() + if os.path.exists(output_file): + with open(output_file, "r") as f: + for line in f: + data = json.loads(line) + instance_id = data["instance_id"] + existing_ids.add(instance_id) + logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}") + if Path(dataset_name_or_path).exists(): + dataset = load_from_disk(dataset_name_or_path) + else: + dataset = load_dataset(dataset_name_or_path) + if split not in dataset: + raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}") + dataset = dataset[split] + lens = np.array(list(map(len, dataset["text"]))) + dataset = dataset.select(np.argsort(lens)) + if len(existing_ids) > 0: + dataset = dataset.filter( + lambda x: x["instance_id"] not in existing_ids, + desc="Filtering out existing ids", + load_from_cache_file=False, + ) + inference_args = { + "test_dataset": dataset, + "model_name_or_path": model_name_or_path, + "output_file": output_file, + "existing_ids": existing_ids, + } + if model_name_or_path.startswith("gpt"): + await openai_inference(**inference_args) + else: + raise ValueError(f"Invalid model name or path {model_name_or_path}") + logger.info("Done!") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index a8534b710..c30d998e9 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -182,3 +182,6 @@ class DataInterpreter(Role): print(result) data_info = DATA_INFO.format(info=result) self.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData)) + + def get_last_cell_source(self): + return self.execute_code.nb.cells[-1].source diff --git a/requirements.txt 
diff --git a/requirements.txt b/requirements.txt
index 83565278b..54a500892 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -77,4 +77,5 @@ imap_tools==1.5.0  # Used by metagpt/tools/libs/email_login.py
 qianfan==0.3.2
 dashscope==1.14.1
 rank-bm25==0.2.2  # for tool recommendation
-jieba==0.42.1 # for tool recommendation
\ No newline at end of file
+jieba==0.42.1 # for tool recommendation
+datasets==2.18.0
\ No newline at end of file
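For completeness, a sketch of how a rerun resumes (the file name below is hypothetical; `main()` derives it as `{model_nickname}__{dataset}__{split}.jsonl`): each output line is a self-contained JSON object, so `existing_ids` can be rebuilt from the `.jsonl` and finished instances skipped.

```python
# Hypothetical resume check; the path is illustrative only.
import json

with open("outputs/gpt-4-1106-preview__SWE-bench__test.jsonl") as f:
    existing_ids = {json.loads(line)["instance_id"] for line in f}
print(f"Skipping {len(existing_ids)} already-completed instances")
```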