mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 01:06:27 +02:00
263 lines
8.1 KiB
Python
263 lines
8.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
@Time : 2023/5/8 22:12
|
|
@Author : alexanderwu
|
|
@File : schema.py
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Type, TypedDict
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from metagpt.logs import logger
|
|
|
|
|
|
class RawMessage(TypedDict):
    """Plain-dict form of a chat message: exactly a ``content`` and a ``role`` key."""

    content: str  # message text
    role: str  # speaker role string
|
|
|
|
|
|
@dataclass
class Message:
    """A single message, rendered as ``<role>: <content>``.

    ``role`` defaults to ``'user'``; the inline comment lists system / user / assistant as the
    expected values. The routing fields (``cause_by``, ``sent_from``, ``send_to``,
    ``restricted_to``) and ``state`` are plain data here: none of the methods below read them.
    """

    content: str
    instruct_content: BaseModel = field(default=None)  # optional structured payload; None when absent
    role: str = field(default='user')  # system / user / assistant
    # presumably the Action class that produced this message; "" when unset (TODO confirm with callers)
    cause_by: Type["Action"] = field(default="")
    sent_from: str = field(default="")
    send_to: str = field(default="")
    restricted_to: str = field(default="")
    state: str = None  # None, done, todo, doing, error

    def __str__(self):
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def to_dict(self) -> dict:
        """Return only the ``role`` and ``content`` of this message as a plain dict."""
        return {
            "role": self.role,
            "content": self.content
        }


@dataclass
class UserMessage(Message):
    """Message with role ``'user'``, to ease building OpenAI-style message lists."""

    def __init__(self, content: str):
        # role must be passed by keyword: Message's second positional field is
        # instruct_content, so a positional 'user' would land there instead.
        super().__init__(content, role='user')


@dataclass
class SystemMessage(Message):
    """Message with role ``'system'``, to ease building OpenAI-style message lists."""

    def __init__(self, content: str):
        # role by keyword -- positionally the second argument would become instruct_content.
        super().__init__(content, role='system')


@dataclass
class AIMessage(Message):
    """Message with role ``'assistant'``, to ease building OpenAI-style message lists."""

    def __init__(self, content: str):
        # role by keyword -- positionally the second argument would become instruct_content.
        super().__init__(content, role='assistant')
|
|
|
|
|
|
class Task(BaseModel):
    """One step of a plan: an instruction, the code and result recorded for it, and status flags."""

    # --- identity / ordering ---
    task_id: str = ""
    # ids of the tasks prerequisite to this one (pydantic copies this default per instance)
    dependent_task_ids: list[str] = []

    # --- what to do ---
    instruction: str = ""
    task_type: str = ""
    code_steps: str = ""

    # --- what was produced ---
    code: str = ""
    result: str = ""

    # --- status ---
    is_success: bool = False
    is_finished: bool = False
|
|
|
|
|
|
class Plan(BaseModel):
    """A goal broken into dependent tasks, tracked in execution order.

    Attributes:
        goal: What the plan is trying to achieve.
        context: Free-form context for the plan ("" by default).
        tasks: The tasks in execution order; ``add_tasks`` stores them topologically sorted so a
            task comes after the tasks it depends on, and ``append_task`` only appends tasks
            whose dependencies are already present.
        task_map: The same Task objects indexed by ``task_id``.
        current_task_id: Id of the first unfinished task in ``tasks``, or "" when there is none.
    """

    goal: str
    context: str = ""
    tasks: list[Task] = []
    task_map: dict[str, Task] = {}
    # Explicitly annotated: pydantic v2 rejects non-annotated model attributes (v1 inferred `str`).
    current_task_id: str = ""

    def _topological_sort(self, tasks: list[Task]) -> list[Task]:
        """Return ``tasks`` ordered so each task follows the tasks it depends on.

        Depth-first: tasks are visited in input order and each is emitted after its dependencies,
        which are visited in the order listed in ``dependent_task_ids``. The list is iterated
        directly rather than through a ``set``: string-hash order differs between interpreter
        runs, which made the resulting order irreproducible.

        Raises:
            KeyError: if a task depends on a task_id that is not among ``tasks``.

        NOTE(review): dependency cycles are not detected; a cyclic input is returned in an order
        that necessarily breaks at least one dependency.
        """
        task_map = {task.task_id: task for task in tasks}
        dependencies = {task.task_id: task.dependent_task_ids for task in tasks}
        sorted_tasks = []
        visited = set()

        def visit(task_id):
            if task_id in visited:
                return
            visited.add(task_id)
            for dependent_id in dependencies.get(task_id, []):
                visit(dependent_id)
            sorted_tasks.append(task_map[task_id])

        for task in tasks:
            visit(task.task_id)

        return sorted_tasks

    def add_tasks(self, tasks: list[Task]):
        """
        Integrates new tasks into the existing plan, ensuring dependency order is maintained.

        This method performs two primary functions based on the current state of the task list:
        1. If there are no existing tasks, it topologically sorts the provided tasks to ensure
        correct execution order based on dependencies, and sets these as the current tasks.
        2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains
        any common prefix of tasks (based on task_id and instruction) and appends the remainder
        of the new tasks. The current task is updated to the first unfinished task in this merged list.

        An empty ``tasks`` leaves the plan unchanged.

        Args:
            tasks (list[Task]): A list of tasks (may be unordered) to add to the plan.

        Returns:
            None: The method updates the internal state of the plan but does not return anything.
        """
        if not tasks:
            return

        # Topologically sort the new tasks to ensure correct dependency order
        new_tasks = self._topological_sort(tasks)

        if not self.tasks:
            # No existing tasks: the sorted new tasks become the plan
            self.tasks = new_tasks
        else:
            # Length of the common prefix (same task_id and instruction) between old and new tasks
            prefix_length = 0
            for old_task, new_task in zip(self.tasks, new_tasks):
                if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
                    break
                prefix_length += 1

            # Keep the old Task objects for the shared prefix (so their code/result/status survive)
            # and take everything after it from the new tasks.
            self.tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]

        # Update current_task_id to the first unfinished task in the merged list
        self._update_current_task()

        # Update the task map for quick access to tasks by ID
        self.task_map = {task.task_id: task for task in self.tasks}

    def reset_task(self, task_id: str):
        """
        Clear code and result of the task based on task_id, and set the task as unfinished.

        Unknown ids are ignored. After a reset the current task is recomputed, so a reset task
        that precedes the previous current task becomes the current task again.

        Args:
            task_id (str): The ID of the task to be reset.

        Returns:
            None
        """
        if task_id in self.task_map:
            task = self.task_map[task_id]
            task.code = ""
            task.result = ""
            task.is_success = False
            task.is_finished = False
            # Keep current_task_id pointing at the first unfinished task in order
            self._update_current_task()

    def replace_task(self, new_task: Task):
        """
        Replace an existing task with the new input task based on task_id, and reset all tasks
        depending on it, directly or transitively. The current task is then recomputed.

        Args:
            new_task (Task): The new task that will replace an existing one.

        Raises:
            AssertionError: if the plan has no task with ``new_task.task_id``
                (the check is an ``assert``, so it is skipped under ``python -O``).

        Returns:
            None
        """
        assert new_task.task_id in self.task_map
        # Replace the task in the task map and the task list
        self.task_map[new_task.task_id] = new_task
        for i, task in enumerate(self.tasks):
            if task.task_id == new_task.task_id:
                self.tasks[i] = new_task
                break

        # Downstream tasks were built on the old task's output, so they are stale as well
        for task_id in self._downstream_task_ids(new_task.task_id):
            self.reset_task(task_id)

        # The replacement itself may be unfinished and precede the current task
        self._update_current_task()

    def _downstream_task_ids(self, task_id: str) -> list[str]:
        """Return ids of tasks depending on ``task_id`` directly or transitively, in ``self.tasks`` order.

        Grows the affected set to a fixed point instead of recursing, so it terminates even if the
        dependencies contain a cycle. ``task_id`` itself is never included.
        """
        affected = {task_id}
        grew = True
        while grew:
            grew = False
            for task in self.tasks:
                if task.task_id not in affected and affected.intersection(task.dependent_task_ids):
                    affected.add(task.task_id)
                    grew = True
        return [task.task_id for task in self.tasks if task.task_id != task_id and task.task_id in affected]

    def append_task(self, new_task: Task):
        """
        Append a new task to the end of existing task sequences

        Args:
            new_task (Task): The new task to be appended to the existing task sequence

        Raises:
            AssertionError: if the task_id already exists or a dependency is unknown
                (``assert`` checks, skipped under ``python -O``).

        Returns:
            None
        """
        assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead"

        assert all(self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids), \
            "New task has unknown dependencies"

        # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence
        self.tasks.append(new_task)
        self.task_map[new_task.task_id] = new_task
        self._update_current_task()

    def has_task_id(self, task_id: str) -> bool:
        """Return True if a task with ``task_id`` is in the plan."""
        return task_id in self.task_map

    def _update_current_task(self):
        """Point current_task_id at the first unfinished task in ``tasks`` ("" if all are finished)."""
        self.current_task_id = next((task.task_id for task in self.tasks if not task.is_finished), "")

    @property
    def current_task(self) -> Task | None:
        """Find current task to execute

        Returns:
            Task | None: the current task to be executed, or None when there is no current task
        """
        return self.task_map.get(self.current_task_id, None)

    def finish_current_task(self):
        """Finish current task, set Task.is_finished=True, set current task to next task.

        Does nothing when there is no current task.
        """
        if self.current_task_id:
            self.current_task.is_finished = True
            self._update_current_task()  # set to next task

    def get_finished_tasks(self) -> list[Task]:
        """return all finished tasks in correct linearized order

        Returns:
            list[Task]: list of finished tasks
        """
        return [task for task in self.tasks if task.is_finished]
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: build one message of each convenience kind plus a custom-role message, then log them.
    test_content = "test_message"
    msgs = [factory(test_content) for factory in (UserMessage, SystemMessage, AIMessage)]
    msgs.append(Message(test_content, role="QA"))
    logger.info(msgs)
|