1. Add a prompt telling DataInterpreter how to solve the issue

2. Add a `use_reflection` argument for DataInterpreter
This commit is contained in:
mannaandpoem 2024-03-18 18:39:56 +08:00
parent 1c8caee7a8
commit cab00d8530

View file

@ -29,10 +29,11 @@ async def call_chat(inputs, interpreter):
inputs (str): The inputs to generate completions for.
interpreter (DataInterpreter): The data interpreter to use for execution.
"""
requirement = "Please rewrite the code to address the issues existing in the repository and generate the correct code. Then, use the `git diff` command to output the patch based on the correct code."
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
try:
await interpreter.run([system_messages, user_message])
await interpreter.run([requirement, system_messages, user_message])
return interpreter.get_last_cell_source
except Exception as e:
logger.error(f"Error: {e}\nInputs: {inputs}")
@ -45,6 +46,7 @@ async def openai_inference(
model_name_or_path,
output_file,
existing_ids,
use_reflection,
):
"""
Runs inference on a dataset using the openai API.
@ -66,7 +68,7 @@ async def openai_inference(
print(f"Filtered to {len(test_dataset)} instances")
with open(output_file, "a+") as f:
for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
di = DataInterpreter()
di = DataInterpreter(use_reflection=use_reflection)
instance_id = datum["instance_id"]
if instance_id in existing_ids:
continue
@ -89,6 +91,7 @@ async def main(
split="test",
model_name_or_path=config.llm.model,
output_dir="outputs",
use_reflection=True,
):
"""
Performs inference on SWE-bench dataset using the Data Interpreter.
@ -99,11 +102,7 @@ async def main(
model_name_or_path: Name of the model to use (default: config.llm.model)
param output_dir: Path to the output directory (default: outputs)
"""
if config.llm.api_type.value == "azure" and config.llm.model == "gpt-4":
# Actual model name is gpt-4-1106-preview for Azure
model_nickname = "gpt-4-1106-preview"
else:
model_nickname = Path(model_name_or_path).name
model_nickname = Path(model_name_or_path).name if isinstance(model_name_or_path, Path) else model_name_or_path
output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
output_file = Path(output_dir, output_file + ".jsonl")
output_file.parent.mkdir(parents=True, exist_ok=True)
@ -136,6 +135,7 @@ async def main(
"model_name_or_path": model_name_or_path,
"output_file": output_file,
"existing_ids": existing_ids,
"use_reflection": use_reflection,
}
if model_name_or_path.startswith("gpt"):
await openai_inference(**inference_args)