#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/5/8 22:12
@Author  : alexanderwu
@File    : schema.py
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Type, TypedDict

from pydantic import BaseModel

from metagpt.logs import logger


class RawMessage(TypedDict):
    content: str
    role: str


@dataclass
class Message:
    """A single chat message; ``str(message)`` renders as ``"<role>: <content>"``."""
    content: str
    instruct_content: Optional[BaseModel] = field(default=None)
    role: str = field(default='user')  # system / user / assistant
    # NOTE(review): annotated as an Action class but defaults to "" — kept as-is
    # because callers may compare against the empty string.
    cause_by: Type["Action"] = field(default="")
    sent_from: str = field(default="")
    send_to: str = field(default="")
    restricted_to: str = field(default="")
    state: Optional[str] = None  # None, done, todo, doing, error

    def __str__(self):
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def to_dict(self) -> dict:
        """Return a dict holding only this message's ``role`` and ``content``."""
        return {
            "role": self.role,
            "content": self.content
        }


@dataclass
class UserMessage(Message):
    """Convenience message with role 'user' (facilitates OpenAI-style messages)."""

    def __init__(self, content: str):
        super().__init__(content, 'user')


@dataclass
class SystemMessage(Message):
    """Convenience message with role 'system' (facilitates OpenAI-style messages)."""

    def __init__(self, content: str):
        super().__init__(content, 'system')


@dataclass
class AIMessage(Message):
    """Convenience message with role 'assistant' (facilitates OpenAI-style messages)."""

    def __init__(self, content: str):
        super().__init__(content, 'assistant')


class Task(BaseModel):
    task_id: str = ""
    dependent_task_ids: list[str] = []  # Tasks prerequisite to this Task
    instruction: str = ""
    task_type: str = ""
    code_steps: str = ""
    code: str = ""
    result: str = ""
    is_success: bool = False
    is_finished: bool = False


class Plan(BaseModel):
    """A goal, its tasks in execution order, and a pointer to the next task to run."""

    goal: str
    context: str = ""
    tasks: list[Task] = []
    task_map: dict[str, Task] = {}
    # Annotated explicitly: pydantic v2 rejects non-annotated class attributes
    # that carry a default value, and v1 accepts the annotation too.
    current_task_id: str = ""

    def _topological_sort(self, tasks: list[Task]):
        """Return ``tasks`` reordered so every task follows its dependencies.

        Tasks are visited in input order with a depth-first post-order walk, so each
        task's dependencies are emitted before it. Dependencies are walked in their
        declared order, which keeps the result deterministic across processes.
        A dependency id not present in ``tasks`` raises KeyError; cycles are not
        reported (the visited set simply stops the recursion).
        """
        task_map = {task.task_id: task for task in tasks}
        # Keep the declared list (not a set): set iteration order of strings
        # depends on per-process hash randomization.
        dependencies = {task.task_id: task.dependent_task_ids for task in tasks}
        sorted_tasks = []
        visited = set()

        def visit(task_id):
            if task_id in visited:
                return
            visited.add(task_id)
            for dependent_id in dependencies.get(task_id, []):
                visit(dependent_id)
            sorted_tasks.append(task_map[task_id])

        for task in tasks:
            visit(task.task_id)

        return sorted_tasks

    def add_tasks(self, tasks: list[Task]):
        """
        Integrate new tasks into the existing plan, keeping dependency order.

        1. If the plan has no tasks yet, the provided tasks are topologically sorted
           and become the plan's tasks.
        2. Otherwise the sorted new tasks are merged with the existing ones: the
           longest common prefix (same task_id and instruction, position by position)
           keeps the existing Task objects, and the rest of the new tasks is appended
           after it.
        In both cases the current task becomes the first unfinished task and the
        task map is rebuilt.

        Args:
            tasks (list[Task]): Tasks (possibly unordered) to add to the plan.

        Returns:
            None: the plan's state is updated in place.
        """
        if not tasks:
            return

        # Topologically sort the new tasks to ensure correct dependency order
        new_tasks = self._topological_sort(tasks)

        if not self.tasks:
            self.tasks = new_tasks
        else:
            # Length of the common prefix between existing and new tasks
            prefix_length = 0
            for old_task, new_task in zip(self.tasks, new_tasks):
                if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
                    break
                prefix_length += 1

            # Existing objects in the prefix keep their progress (code/result/flags)
            self.tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]

        self._update_current_task()
        # Rebuild the id -> task index so it matches the merged list
        self.task_map = {task.task_id: task for task in self.tasks}

    def reset_task(self, task_id: str):
        """
        Clear code and result of the task with ``task_id`` and mark it unfinished.

        Unknown ids are ignored. The current-task pointer is re-derived, since the
        reset task may come before the task that was current.

        Args:
            task_id (str): The ID of the task to be reset.

        Returns:
            None
        """
        task = self.task_map.get(task_id)
        if task is None:
            return
        task.code = ""
        task.result = ""
        task.is_success = False
        task.is_finished = False
        self._update_current_task()

    def _downstream_task_ids(self, task_id: str) -> set[str]:
        """Return ids of all tasks depending on ``task_id`` directly or transitively."""
        # Reverse edges: dependency id -> ids of tasks that list it as a prerequisite
        dependents: dict[str, list[str]] = {}
        for task in self.tasks:
            for dep_id in task.dependent_task_ids:
                dependents.setdefault(dep_id, []).append(task.task_id)

        downstream: set[str] = set()
        stack = list(dependents.get(task_id, []))
        while stack:
            current = stack.pop()
            if current in downstream:
                continue  # already handled; also makes cyclic plans terminate
            downstream.add(current)
            stack.extend(dependents.get(current, []))
        downstream.discard(task_id)  # never wipe the task being replaced
        return downstream

    def replace_task(self, new_task: Task):
        """
        Replace the task with the same task_id and reset every task depending on it.

        Dependents are reset transitively (a task depending on a reset task is reset
        too), and the current-task pointer is re-derived afterwards.

        Args:
            new_task (Task): The new task that will replace an existing one.

        Returns:
            None
        """
        assert new_task.task_id in self.task_map
        # Replace the task in the task map and the task list
        self.task_map[new_task.task_id] = new_task
        for i, task in enumerate(self.tasks):
            if task.task_id == new_task.task_id:
                self.tasks[i] = new_task
                break

        # Downstream results were produced from the old task and are now stale
        for dependent_id in self._downstream_task_ids(new_task.task_id):
            self.reset_task(dependent_id)

        # The replacement itself may be unfinished and precede the current task
        self._update_current_task()

    def append_task(self, new_task: Task):
        """
        Append a new task to the end of the existing task sequence.

        Args:
            new_task (Task): The new task to be appended to the existing task sequence

        Returns:
            None
        """
        assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead"

        assert all(self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids), \
            "New task has unknown dependencies"

        # Existing tasks cannot depend on the new task, so appending keeps the order valid
        self.tasks.append(new_task)
        self.task_map[new_task.task_id] = new_task
        self._update_current_task()

    def has_task_id(self, task_id: str) -> bool:
        return task_id in self.task_map

    def _update_current_task(self):
        """Point current_task_id at the first unfinished task, or "" if all are finished."""
        self.current_task_id = next(
            (task.task_id for task in self.tasks if not task.is_finished), ""
        )

    @property
    def current_task(self) -> Optional[Task]:
        """Find the current task to execute.

        Returns:
            Optional[Task]: the current task, or None if there is no current task.
        """
        return self.task_map.get(self.current_task_id)

    def finish_current_task(self):
        """Mark the current task finished (if any) and advance to the next unfinished task."""
        task = self.current_task
        if task is not None:
            task.is_finished = True
        self._update_current_task()  # set to next task

    def get_finished_tasks(self) -> list[Task]:
        """Return all finished tasks in plan order.

        Returns:
            list[Task]: list of finished tasks
        """
        return [task for task in self.tasks if task.is_finished]


if __name__ == '__main__':
    test_content = 'test_message'
    msgs = [
        UserMessage(test_content),
        SystemMessage(test_content),
        AIMessage(test_content),
        Message(test_content, role='QA')
    ]
    logger.info(msgs)