diff --git a/data/inference/run_api.py b/data/inference/run_api.py index 18fb48c40..1126fcbf3 100644 --- a/data/inference/run_api.py +++ b/data/inference/run_api.py @@ -29,10 +29,11 @@ async def call_chat(inputs, interpreter): inputs (str): The inputs to generate completions for. interpreter (DataInterpreter): The data interpreter to use for execution. """ + requirement = "Please rewrite the code to address the issues existing in the repository and generate the correct code. Then, use the `git diff` command to output the patch based on the correct code." system_messages = inputs.split("\n", 1)[0] user_message = inputs.split("\n", 1)[1] try: - await interpreter.run([system_messages, user_message]) + await interpreter.run([requirement, system_messages, user_message]) return interpreter.get_last_cell_source except Exception as e: logger.error(f"Error: {e}\nInputs: {inputs}") @@ -45,6 +46,7 @@ async def openai_inference( model_name_or_path, output_file, existing_ids, + use_reflection, ): """ Runs inference on a dataset using the openai API. @@ -66,7 +68,7 @@ async def openai_inference( print(f"Filtered to {len(test_dataset)} instances") with open(output_file, "a+") as f: for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"): - di = DataInterpreter() + di = DataInterpreter(use_reflection=use_reflection) instance_id = datum["instance_id"] if instance_id in existing_ids: continue @@ -89,6 +91,7 @@ async def main( split="test", model_name_or_path=config.llm.model, output_dir="outputs", + use_reflection=True, ): """ Performs inference on SWE-bench dataset using the Data Interpreter. @@ -99,11 +102,7 @@ async def main( model_name_or_path: Name of the model to use (default: config.llm.model) param output_dir: Path to the output directory (default: outputs) """ - if config.llm.api_type.value == "azure" and config.llm.model == "gpt-4": - # Actual model name is gpt-4-1106-preview for Azure - model_nickname = "gpt-4-1106-preview" - else: - model_nickname = Path(model_name_or_path).name + model_nickname = Path(model_name_or_path).name if isinstance(model_name_or_path, Path) else model_name_or_path output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}" output_file = Path(output_dir, output_file + ".jsonl") output_file.parent.mkdir(parents=True, exist_ok=True) @@ -136,6 +135,7 @@ async def main( "model_name_or_path": model_name_or_path, "output_file": output_file, "existing_ids": existing_ids, + "use_reflection": use_reflection, } if model_name_or_path.startswith("gpt"): await openai_inference(**inference_args)