task utils etc.

This commit is contained in:
yzlin 2024-02-01 20:07:44 +08:00
parent afb702c3f3
commit 2897981e63
4 changed files with 16 additions and 18 deletions

View file

@ -72,12 +72,8 @@ CODE_REFLECTION = {
}
def message_to_str(message: Message) -> str:
return f"{message.role}: {message.content}"
def messages_to_str(messages: List[Message]) -> str:
return "\n".join([message_to_str(message) for message in messages])
return "\n".join([str(message) for message in messages])
class DebugCode(BaseWriteAnalysisCode):

View file

@ -111,7 +111,7 @@ class Planner(BaseModel):
return "", confirmed
async def confirm_task(self, task: Task, task_result: TaskResult, review: str):
self.plan.update_task_result(task=task, task_result=task_result)
task.update_task_result(task_result=task_result)
self.plan.finish_current_task()
self.working_memory.clear()

View file

@ -341,6 +341,18 @@ class Task(BaseModel):
is_success: bool = False
is_finished: bool = False
def reset(self):
self.code = ""
self.result = ""
self.is_success = False
self.is_finished = False
def update_task_result(self, task_result: TaskResult):
self.code_steps = task_result.code_steps
self.code = task_result.code
self.result = task_result.result
self.is_success = task_result.is_success
class TaskResult(BaseModel):
"""Result of taking a task, with result and is_success required to be filled"""
@ -434,10 +446,7 @@ class Plan(BaseModel):
"""
if task_id in self.task_map:
task = self.task_map[task_id]
task.code = ""
task.result = ""
task.is_success = False
task.is_finished = False
task.reset()
def replace_task(self, new_task: Task):
"""
@ -483,12 +492,6 @@ class Plan(BaseModel):
self.task_map[new_task.task_id] = new_task
self._update_current_task()
def update_task_result(self, task: Task, task_result: TaskResult):
task.code_steps = task_result.code_steps
task.code = task_result.code
task.result = task_result.result
task.is_success = task_result.is_success
def has_task_id(self, task_id: str) -> bool:
return task_id in self.task_map

View file

@ -29,8 +29,7 @@ def save_code_file(name: str, code_context: str, file_format: str = "py") -> Non
# Choose to save as a Python file or a JSON file based on the file format
file_path = DATA_PATH / "output" / f"{name}/code.{file_format}"
if file_format == "py":
with open(file_path, "w", encoding="utf-8") as fp:
fp.write(code_context + "\n\n")
file_path.write_text(code_context + "\n\n", encoding="utf-8")
elif file_format == "json":
# Parse the code content as JSON and save
data = {"code": code_context}