diff --git a/data/inference/run_api.py b/data/inference/run_api.py index 1126fcbf3..66f229f85 100644 --- a/data/inference/run_api.py +++ b/data/inference/run_api.py @@ -29,12 +29,12 @@ async def call_chat(inputs, interpreter): inputs (str): The inputs to generate completions for. interpreter (DataInterpreter): The data interpreter to use for execution. """ - requirement = "Please rewrite the code to address the issues existing in the repository and generate the correct code. Then, use the `git diff` command to output the patch based on the correct code." + requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code." system_messages = inputs.split("\n", 1)[0] user_message = inputs.split("\n", 1)[1] try: await interpreter.run([requirement, system_messages, user_message]) - return interpreter.get_last_cell_source + return interpreter.get_last_cell_source() except Exception as e: logger.error(f"Error: {e}\nInputs: {inputs}") traceback.print_exc() @@ -80,10 +80,10 @@ async def openai_inference( di, ) logger.info(f"Final response: {response}") + save_history(di) output_dict["full_output"] = response output_dict["model_patch"] = extract_diff(response) print(json.dumps(output_dict), file=f, flush=True) - save_history(di) async def main(