From 916185d82c36ab6136b584322f6c81b6dba209cc Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 19 Mar 2024 15:34:12 +0800 Subject: [PATCH] Fix get_last_cell_source method --- data/inference/run_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/data/inference/run_api.py b/data/inference/run_api.py index 1126fcbf3..66f229f85 100644 --- a/data/inference/run_api.py +++ b/data/inference/run_api.py @@ -29,12 +29,12 @@ async def call_chat(inputs, interpreter): inputs (str): The inputs to generate completions for. interpreter (DataInterpreter): The data interpreter to use for execution. """ - requirement = "Please rewrite the code to address the issues existing in the repository and generate the correct code. Then, use the `git diff` command to output the patch based on the correct code." + requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code." system_messages = inputs.split("\n", 1)[0] user_message = inputs.split("\n", 1)[1] try: await interpreter.run([requirement, system_messages, user_message]) - return interpreter.get_last_cell_source + return interpreter.get_last_cell_source() except Exception as e: logger.error(f"Error: {e}\nInputs: {inputs}") traceback.print_exc() @@ -80,10 +80,10 @@ async def openai_inference( di, ) logger.info(f"Final response: {response}") + save_history(di) output_dict["full_output"] = response output_dict["model_patch"] = extract_diff(response) print(json.dumps(output_dict), file=f, flush=True) - save_history(di) async def main(