Fix get_last_cell_source method

This commit is contained in:
Your Name 2024-03-19 15:34:12 +08:00
parent cab00d8530
commit 916185d82c

View file

@ -29,12 +29,12 @@ async def call_chat(inputs, interpreter):
inputs (str): The inputs to generate completions for.
interpreter (DataInterpreter): The data interpreter to use for execution.
"""
requirement = "Please rewrite the code to address the issues existing in the repository and generate the correct code. Then, use the `git diff` command to output the patch based on the correct code."
requirement = "Please rewrite the code and generate test case to address the issues existing in the repository. If the test code passes, it is considered that the execution code has passed and use the `git diff` command to output the patch based on the correct code."
system_messages = inputs.split("\n", 1)[0]
user_message = inputs.split("\n", 1)[1]
try:
await interpreter.run([requirement, system_messages, user_message])
return interpreter.get_last_cell_source
return interpreter.get_last_cell_source()
except Exception as e:
logger.error(f"Error: {e}\nInputs: {inputs}")
traceback.print_exc()
@ -80,10 +80,10 @@ async def openai_inference(
di,
)
logger.info(f"Final response: {response}")
save_history(di)
output_dict["full_output"] = response
output_dict["model_patch"] = extract_diff(response)
print(json.dumps(output_dict), file=f, flush=True)
save_history(di)
async def main(