2023-06-30 17:10:48 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2023/5/8 22:12
@Author  : alexanderwu
@File    : schema.py
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
        Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
@Modified By: mashenquan, 2023/11/22.
        1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
        2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing
        between actions.
        3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
"""
2023-07-22 11:28:22 +08:00
2023-06-30 17:10:48 +08:00
from __future__ import annotations
2023-07-22 11:28:22 +08:00
2023-11-01 20:08:58 +08:00
import asyncio
2023-10-31 15:23:37 +08:00
import json
2023-11-22 17:08:00 +08:00
import os . path
2023-11-29 10:14:04 +08:00
import uuid
2023-12-19 23:53:04 +08:00
from abc import ABC
2023-11-01 20:08:58 +08:00
from asyncio import Queue , QueueEmpty , wait_for
2024-03-31 10:53:08 +08:00
from enum import Enum
2023-10-31 15:23:37 +08:00
from json import JSONDecodeError
2023-11-28 18:16:50 +08:00
from pathlib import Path
2024-01-15 16:37:42 +08:00
from typing import Any , Dict , Iterable , List , Optional , Type , TypeVar , Union
2023-12-27 14:00:54 +08:00
from pydantic import (
BaseModel ,
ConfigDict ,
Field ,
PrivateAttr ,
2024-05-18 14:36:22 +08:00
create_model ,
2023-12-27 14:00:54 +08:00
field_serializer ,
field_validator ,
2024-01-08 22:15:56 +08:00
model_serializer ,
model_validator ,
2023-12-27 14:00:54 +08:00
)
2023-06-30 17:10:48 +08:00
2023-11-01 20:08:58 +08:00
from metagpt . const import (
2024-04-30 12:06:56 +08:00
AGENT ,
2023-11-01 20:08:58 +08:00
MESSAGE_ROUTE_CAUSE_BY ,
MESSAGE_ROUTE_FROM ,
MESSAGE_ROUTE_TO ,
2023-11-06 22:38:43 +08:00
MESSAGE_ROUTE_TO_ALL ,
2023-11-28 18:16:50 +08:00
SYSTEM_DESIGN_FILE_REPO ,
TASK_FILE_REPO ,
2023-11-01 20:08:58 +08:00
)
2024-04-25 20:14:18 +08:00
from metagpt . logs import logger
2024-01-26 19:39:06 +08:00
from metagpt . repo_parser import DotClassInfo
2024-06-03 17:57:24 +08:00
from metagpt . tools . tool_registry import register_tool
2024-05-18 14:36:22 +08:00
from metagpt . utils . common import (
CodeParser ,
any_to_str ,
any_to_str_set ,
aread ,
import_class ,
)
2023-12-19 16:16:52 +08:00
from metagpt . utils . exceptions import handle_exception
2024-04-26 21:38:07 +08:00
from metagpt . utils . report import TaskReporter
2023-12-21 10:48:46 +08:00
from metagpt . utils . serialize import (
actionoutout_schema_to_mapping ,
actionoutput_mapping_to_str ,
actionoutput_str_to_mapping ,
)
2023-06-30 17:10:48 +08:00
2024-01-08 22:15:56 +08:00
class SerializationMixin(BaseModel, extra="forbid"):
    """
    Polymorphic subclass serialization/deserialization mixin.

    - Pydantic is not designed for polymorphism: if `Engineer` is a subclass of `Role`, a value typed as `Role`
      is serialized as `Role`. To round-trip it as `Engineer`, the dumped data must carry the concrete class
      name, so `Engineer` inherits `SerializationMixin`.
    - Serialization appends a `__module_class_name` key ("<module>.<qualname>") to the dumped dict.
    - Validation removes that key again; when validating through a polymorphic base (declared with
      `is_polymorphic_base=True`), the registered concrete subclass is instantiated instead.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 discusses avoiding `__get_pydantic_core_schema__`
    """

    # Name-mangled per class body; overwritten on every subclass by `__init_subclass__`.
    __is_polymorphic_base = False
    # Registry shared by all subclasses: "<module>.<qualname>" -> class.
    __subclasses_map__ = {}

    @model_serializer(mode="wrap")
    def __serialize_with_class_type__(self, default_serializer) -> Any:
        # Run the default serializer, then append the `__module_class_name` field and return.
        ret = default_serializer(self)
        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return ret

    @model_validator(mode="wrap")
    @classmethod
    def __convert_to_real_type__(cls, value: Any, handler):
        """Strip `__module_class_name` and dispatch construction to the concrete registered subclass.

        Raises:
            ValueError: If a polymorphic base is validated without `__module_class_name`.
            TypeError: If `__module_class_name` names a class that has not been registered.
        """
        if not isinstance(value, dict):
            return handler(value)

        # It is a dict, so remove `__module_class_name` because extra keywords are forbidden, while still
        # letting e.g. `Cat.model_validate(cat.model_dump())` work. Work on a shallow copy so the caller's
        # dict is not mutated (re-validating the same data would otherwise lose the class name).
        value = dict(value)
        class_full_name = value.pop("__module_class_name", None)

        # If it's not the polymorphic base, construct via the default handler when the data matches this class.
        if not cls.__is_polymorphic_base:
            if class_full_name is None or class_full_name == f"{cls.__module__}.{cls.__qualname__}":
                return handler(value)
            # Otherwise the data names another class: fall through to the registry lookup below.

        # Otherwise look up the correct polymorphic type and construct that instead.
        if class_full_name is None:
            raise ValueError("Missing __module_class_name field")

        class_type = cls.__subclasses_map__.get(class_full_name)
        if class_type is None:
            # TODO could try dynamic import
            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        # Record whether this subclass is a polymorphic base and register it under its full name.
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
2023-12-25 22:39:03 +08:00
class SimpleMessage(BaseModel):
    """A minimal message made of a text `content` and a `role` string."""

    content: str  # message text (required)
    role: str  # speaker role (required); no validation restricts its values here
2023-11-22 17:08:00 +08:00
class Document(BaseModel):
    """
    Represents a document.

    Attributes:
        root_path: Directory part of the document path ("" when unknown). `load` sets it relative to the
            project path when one is given.
        filename: File name of the document (the full given path when `load` has no project path).
        content: Text content of the document.
    """

    root_path: str = ""
    filename: str = ""
    content: str = ""

    def get_meta(self) -> Document:
        """Get metadata of the document.

        :return: A new Document instance with the same root path and filename and empty content.
        """
        return Document(root_path=self.root_path, filename=self.filename)

    @property
    def root_relative_path(self):
        """Get relative path from root of git repository.

        :return: `root_path` joined with `filename`.
        """
        return os.path.join(self.root_path, self.filename)

    def __str__(self):
        return self.content

    def __repr__(self):
        return self.content

    @classmethod
    async def load(
        cls, filename: Union[str, Path], project_path: Optional[Union[str, Path]] = None
    ) -> Optional["Document"]:
        """
        Load a document from a file.

        Args:
            filename (Union[str, Path]): The path to the file to load.
            project_path (Optional[Union[str, Path]], optional): The path to the project. Defaults to None.
                When given and `filename` lies inside it, `root_path` becomes the file's directory relative to
                the project and `filename` becomes the bare file name.

        Returns:
            Optional[Document]: The loaded document, or None if `filename` is empty or the file does not exist.
        """
        if not filename or not Path(filename).exists():
            return None
        content = await aread(filename=filename)
        doc = cls(content=content, filename=str(filename))
        path = Path(filename)
        if project_path and path.is_relative_to(project_path):
            # `root_path` is declared `str`; pydantic does not validate plain attribute assignment,
            # so convert explicitly instead of storing a `Path` object.
            doc.root_path = str(path.relative_to(project_path).parent)
            doc.filename = path.name
        return doc
2023-11-22 17:08:00 +08:00
class Documents(BaseModel):
    """A class representing a collection of documents.

    Attributes:
        docs (Dict[str, Document]): A dictionary mapping document filenames to Document instances.
    """

    docs: Dict[str, Document] = Field(default_factory=dict)

    @classmethod
    def from_iterable(cls, documents: Iterable[Document]) -> Documents:
        """Create a Documents instance from an iterable of Document instances.

        Documents are keyed by `filename`; if several share a filename, the last one wins.

        :param documents: An iterable of Document instances.
        :return: An instance of the calling class (so subclasses get their own type back).
        """
        docs = {doc.filename: doc for doc in documents}
        return cls(docs=docs)

    def to_action_output(self) -> "ActionOutput":
        """Wrap this collection in an `ActionOutput`.

        :return: An ActionOutput whose `content` is this object's JSON dump and whose
            `instruct_content` is this object itself.
        """
        # Imported locally, presumably to avoid a circular import with metagpt.actions.
        from metagpt.actions.action_output import ActionOutput

        return ActionOutput(content=self.model_dump_json(), instruct_content=self)
2023-11-22 17:08:00 +08:00
2024-04-28 22:49:43 +08:00
class Resource(BaseModel):
    """Used by `Message`.`parse_resources`

    One resource entry parsed from the "resources" list of the LLM's JSON answer.
    """

    resource_type: str  # the type of resource
    value: str  # a string type of resource content
    description: str  # explanation of the resource, as returned by the LLM
2023-10-31 15:23:37 +08:00
class Message(BaseModel):
    """list[<role>: <content>]

    A message exchanged between roles. The routing fields `cause_by` and `sent_from` are normalized to
    strings and `send_to` to a set of strings: by the field validators at construction time, and by
    `__setattr__` when the attributes named by `MESSAGE_ROUTE_CAUSE_BY` / `MESSAGE_ROUTE_FROM` /
    `MESSAGE_ROUTE_TO` are assigned later.
    """

    id: str = Field(default="", validate_default=True)  # According to Section 2.2.3.1.1 of RFC 135
    content: str  # natural language for user or agent
    instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
    role: str = "user"  # system / user / assistant
    cause_by: str = Field(default="", validate_default=True)
    sent_from: str = Field(default="", validate_default=True)
    send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
    metadata: Dict[str, str] = Field(default_factory=dict)  # metadata for `content` and `instruct_content`

    @field_validator("id", mode="before")
    @classmethod
    def check_id(cls, value: str) -> str:
        # Generate a random hex id when none (or an empty one) is supplied.
        return value if value else uuid.uuid4().hex

    @field_validator("instruct_content", mode="before")
    @classmethod
    def check_instruct_content(cls, ic: Any) -> BaseModel:
        """Rebuild `instruct_content` from the dict form produced by `ser_instruct_content`.

        Inputs that are not a dict containing "class" are returned unchanged.

        Raises:
            KeyError: If the dict has "class" but neither "mapping" nor "module".
        """
        if ic and isinstance(ic, dict) and "class" in ic:
            if "mapping" in ic:
                # compatible with custom-defined ActionOutput
                mapping = actionoutput_str_to_mapping(ic["mapping"])
                actionnode_class = import_class("ActionNode", "metagpt.actions.action_node")  # avoid circular import
                ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
            elif "module" in ic:
                # subclasses of BaseModel
                ic_obj = import_class(ic["class"], ic["module"])
            else:
                raise KeyError("missing required key to init Message.instruct_content from dict")
            ic = ic_obj(**ic["value"])
        return ic

    @field_validator("cause_by", mode="before")
    @classmethod
    def check_cause_by(cls, cause_by: Any) -> str:
        # An empty cause defaults to `UserRequirement`, imported lazily by name.
        return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))

    @field_validator("sent_from", mode="before")
    @classmethod
    def check_sent_from(cls, sent_from: Any) -> str:
        return any_to_str(sent_from if sent_from else "")

    @field_validator("send_to", mode="before")
    @classmethod
    def check_send_to(cls, send_to: Any) -> set:
        # An empty recipient set means "broadcast to all".
        return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})

    @field_serializer("send_to", mode="plain")
    def ser_send_to(self, send_to: set) -> list:
        # Sets are not JSON-serializable; dump as a list.
        return list(send_to)

    @field_serializer("instruct_content", mode="plain")
    def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
        """Serialize `instruct_content` into a dict that `check_instruct_content` can rebuild.

        Models whose type lives under `metagpt.actions.action_node` become {"class", "mapping", "value"};
        any other BaseModel becomes {"class", "module", "value"}. Returns None when `ic` is falsy.
        """
        ic_dict = None
        if ic:
            # compatible with custom-defined ActionOutput
            schema = ic.model_json_schema()
            ic_type = str(type(ic))
            if "<class 'metagpt.actions.action_node" in ic_type:
                # instruct_content from AutoNode.create_model_class, for now, it's single level structure.
                mapping = actionoutout_schema_to_mapping(schema)
                mapping = actionoutput_mapping_to_str(mapping)

                ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
            else:
                # due to instruct_content can be assigned by subclasses of BaseModel
                ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
        return ic_dict

    def __init__(self, content: str = "", **data: Any):
        # Allow `content` to be passed positionally.
        data["content"] = data.get("content", content)
        super().__init__(**data)

    def __setattr__(self, key, val):
        """Override `@property.setter`, convert non-string parameters into string parameters."""
        if key == MESSAGE_ROUTE_CAUSE_BY:
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_FROM:
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_TO:
            new_val = any_to_str_set(val)
        else:
            new_val = val
        super().__setattr__(key, new_val)

    def __str__(self):
        # prefix = '-'.join([self.role, str(self.cause_by)])
        if self.instruct_content:
            return f"{self.role}: {self.instruct_content.model_dump()}"
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def rag_key(self) -> str:
        """For search"""
        return self.content

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

    def dump(self) -> str:
        """Convert the object to json string"""
        return self.model_dump_json(exclude_none=True, warnings=False)

    @staticmethod
    @handle_exception(exception_type=JSONDecodeError, default_return=None)
    def load(val):
        """Convert the json string to object.

        The serialized `id` (when truthy) is applied after construction, so the original id is kept.
        Returns None if `val` is not valid JSON.
        """
        try:
            m = json.loads(val)
            msg_id = m.pop("id", None)
            msg = Message(**m)
            if msg_id:
                msg.id = msg_id
            return msg
        except JSONDecodeError as err:
            logger.error(f"parse json failed: {val}, error:{err}")
        return None

    async def parse_resources(self, llm: "BaseLLM", key_descriptions: Optional[Dict[str, str]] = None) -> Dict:
        """
        `parse_resources` corresponds to the in-context adaptation capability of the input of the atomic action,
        which will be migrated to the context builder later.

        Args:
            llm (BaseLLM): The instance of the BaseLLM class.
            key_descriptions (Dict[str, str], optional): A dictionary containing descriptions for each key,
                if provided. Defaults to None.

        Returns:
            Dict: A dictionary containing parsed resources ({} when `content` is empty). Its "resources"
                entry is converted to a list of `Resource` objects.
        """
        if not self.content:
            return {}
        content = f"## Original Requirement\n```text\n{self.content}\n```\n"
        return_format = (
            "Return a markdown JSON object with:\n"
            '- a "resources" key contain a list of objects. Each object with:\n'
            '  - a "resource_type" key explain the type of resource;\n'
            '  - a "value" key containing a string type of resource content;\n'
            '  - a "description" key explaining why;\n'
        )
        key_descriptions = key_descriptions or {}
        for k, v in key_descriptions.items():
            return_format += f'- a "{k}" key containing {v};\n'
        return_format += '- a "reason" key explaining why;\n'
        instructions = ['Lists all the resources contained in the "Original Requirement".', return_format]
        rsp = await llm.aask(msg=content, system_msgs=instructions)
        json_data = CodeParser.parse_code(text=rsp, lang="json")
        m = json.loads(json_data)
        m["resources"] = [Resource(**i) for i in m.get("resources", [])]
        return m

    def add_metadata(self, key: str, value: str):
        """Store `value` under `key` in `metadata`, overwriting any previous value."""
        self.metadata[key] = value

    @staticmethod
    def create_instruct_value(kvs: Dict[str, Any], class_name: str = "") -> BaseModel:
        """
        Dynamically creates a Pydantic BaseModel subclass based on a given dictionary and returns an instance of it.

        Each key becomes a required field typed as the class of its value in `kvs`.

        Parameters:
        - kvs: A dictionary from which to create the BaseModel subclass and its instance.
        - class_name: Name of the generated class; defaults to "DM" followed by 8 random hex characters.

        Returns:
        - A Pydantic BaseModel subclass instance populated with the given data.
        """
        if not class_name:
            class_name = "DM" + uuid.uuid4().hex[0:8]
        dynamic_class = create_model(class_name, **{key: (value.__class__, ...) for key, value in kvs.items()})
        return dynamic_class.model_validate(kvs)
2023-06-30 17:10:48 +08:00
class UserMessage(Message):
    """Convenience message whose role is always "user" (eases support for OpenAI-style messages).

    Any `role` keyword supplied by the caller is overridden.
    """

    def __init__(self, content: str, **kwargs):
        # Force the role regardless of what the caller passed.
        kwargs["role"] = "user"
        super().__init__(content=content, **kwargs)
2023-06-30 17:10:48 +08:00
class SystemMessage(Message):
    """Convenience message whose role is always "system" (eases support for OpenAI-style messages).

    Any `role` keyword supplied by the caller is overridden.
    """

    def __init__(self, content: str, **kwargs):
        # Force the role regardless of what the caller passed.
        kwargs["role"] = "system"
        super().__init__(content=content, **kwargs)
2023-06-30 17:10:48 +08:00
class AIMessage(Message):
    """Convenience message whose role is always "assistant" (eases support for OpenAI-style messages).

    Any `role` keyword supplied by the caller is overridden.
    """

    def __init__(self, content: str, **kwargs):
        # Force the role regardless of what the caller passed.
        kwargs["role"] = "assistant"
        super().__init__(content=content, **kwargs)

    def with_agent(self, name: str):
        """Record `name` in `metadata` under the `AGENT` key and return `self` so calls can be chained."""
        self.add_metadata(key=AGENT, value=name)
        return self

    @property
    def agent(self) -> str:
        """The name stored by `with_agent`, or "" when none has been recorded."""
        return self.metadata[AGENT] if AGENT in self.metadata else ""
2023-11-01 20:08:58 +08:00
2023-11-23 21:59:25 +08:00
class Task(BaseModel):
    """A single step of a plan, together with its execution state."""

    task_id: str = ""
    dependent_task_ids: list[str] = []  # Tasks prerequisite to this Task
    instruction: str = ""
    task_type: str = ""
    code: str = ""
    result: str = ""
    is_success: bool = False
    is_finished: bool = False
    assignee: str = ""

    def reset(self):
        """Clear the execution output and status so the task can be redone.

        Identity fields (id, dependencies, instruction, type, assignee) are left untouched.
        """
        self.code, self.result = "", ""
        self.is_success = self.is_finished = False

    def update_task_result(self, task_result: TaskResult):
        """Copy `code`, `result` and `is_success` from `task_result`; `is_finished` is not modified."""
        for attr_name in ("code", "result", "is_success"):
            setattr(self, attr_name, getattr(task_result, attr_name))
2023-11-23 21:59:25 +08:00
2024-01-09 16:54:36 +08:00
class TaskResult(BaseModel):
    """Result of taking a task, with result and is_success required to be filled"""

    code: str = ""  # code associated with the result; copied into `Task.code` by `Task.update_task_result`
    result: str  # required; copied into `Task.result`
    is_success: bool  # required; copied into `Task.is_success`
2024-06-03 17:57:24 +08:00
@register_tool(
    include_functions=[
        "append_task",
        "reset_task",
        "replace_task",
        "finish_current_task",
    ]
)
class Plan(BaseModel):
    """Plan is a sequence of tasks towards a goal."""

    goal: str
    context: str = ""
    tasks: list[Task] = []
    task_map: dict[str, Task] = {}
    current_task_id: str = ""

    def _topological_sort(self, tasks: list[Task]):
        """Return `tasks` ordered so that each task comes after the tasks it depends on.

        Depth-first post-order traversal. A task id is marked visited before its dependencies are explored,
        so a dependency cycle terminates (with an arbitrary order inside the cycle). A dependency id that is
        not among `tasks` raises KeyError.
        """
        task_map = {task.task_id: task for task in tasks}
        dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks}
        sorted_tasks = []
        visited = set()

        def visit(task_id):
            if task_id in visited:
                return
            visited.add(task_id)
            for dependent_id in dependencies.get(task_id, []):
                visit(dependent_id)
            sorted_tasks.append(task_map[task_id])

        for task in tasks:
            visit(task.task_id)

        return sorted_tasks

    def add_tasks(self, tasks: list[Task]):
        """
        Integrates new tasks into the existing plan, ensuring dependency order is maintained.

        This method performs two primary functions based on the current state of the task list:
        1. If there are no existing tasks, it topologically sorts the provided tasks to ensure
        correct execution order based on dependencies, and sets these as the current tasks.
        2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains
        any common prefix of tasks (based on task_id and instruction) and appends the remainder
        of the new tasks. The current task is updated to the first unfinished task in this merged list.

        Args:
            tasks (list[Task]): A list of tasks (may be unordered) to add to the plan.

        Returns:
            None: The method updates the internal state of the plan but does not return anything.
        """
        if not tasks:
            return

        # Topologically sort the new tasks to ensure correct dependency order
        new_tasks = self._topological_sort(tasks)

        if not self.tasks:
            # If there are no existing tasks, set the new tasks as the current tasks
            self.tasks = new_tasks
        else:
            # Find the length of the common prefix between existing and new tasks
            prefix_length = 0
            for old_task, new_task in zip(self.tasks, new_tasks):
                if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
                    break
                prefix_length += 1

            # Combine the common prefix (keeping the existing task objects and their state)
            # with the remainder of the new tasks
            self.tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]

        # Re-sort, rebuild `task_map` for quick access by id, and move current_task_id
        # to the first unfinished task in the merged list.
        self._update_current_task()

    def reset_task(self, task_id: str):
        """
        Reset a task based on task_id, i.e. set Task.is_finished=False and request redo. This also resets all tasks depending on it.

        Args:
            task_id (str): The ID of the task to be reset.
        """
        # Unknown ids are ignored. Visited ids are tracked so that cyclic dependencies
        # (e.g. produced by an LLM) cannot cause infinite recursion.
        if task_id in self.task_map:
            self._reset_task_and_dependents(task_id, set())
            self._update_current_task()

    def _reset_task_and_dependents(self, task_id: str, visited: set):
        """Reset `task_id` and, transitively, every task depending on it; each id is reset at most once."""
        if task_id in visited or task_id not in self.task_map:
            return
        visited.add(task_id)
        self.task_map[task_id].reset()
        # reset all downstream tasks that are dependent on the reset task
        for dep_task in self.tasks:
            if task_id in dep_task.dependent_task_ids:
                self._reset_task_and_dependents(dep_task.task_id, visited)

    def _replace_task(self, new_task: Task):
        """
        Replace an existing task with the new input task based on task_id, and reset all tasks depending on it.

        Args:
            new_task (Task): The new task that will replace an existing one.

        Returns:
            None
        """
        assert new_task.task_id in self.task_map, f"Task {new_task.task_id} is not in the current plan"
        # Replace the task in the task map and the task list
        self.task_map[new_task.task_id] = new_task
        for i, task in enumerate(self.tasks):
            if task.task_id == new_task.task_id:
                self.tasks[i] = new_task
                break

        # Reset dependent tasks
        for task in self.tasks:
            if new_task.task_id in task.dependent_task_ids:
                self.reset_task(task.task_id)

        self._update_current_task()

    def _append_task(self, new_task: Task):
        """
        Append a new task to the end of existing task sequences

        Args:
            new_task (Task): The new task to be appended to the existing task sequence

        Returns:
            None
        """
        # assert not self.has_task_id(new_task.task_id), "Task already in current plan, use replace_task instead"
        if self.has_task_id(new_task.task_id):
            logger.warning(
                "Task already in current plan, should use replace_task instead. Overwriting the existing task."
            )

        assert all(
            self.has_task_id(dep_id) for dep_id in new_task.dependent_task_ids
        ), "New task has unknown dependencies"

        # Existing tasks do not depend on the new task, it's fine to put it to the end of the sorted task sequence
        self.tasks.append(new_task)
        self.task_map[new_task.task_id] = new_task
        # The re-sort in `_update_current_task` also drops a duplicate entry for an overwritten id.
        self._update_current_task()

    def has_task_id(self, task_id: str) -> bool:
        """Return True if a task with `task_id` is in the plan."""
        return task_id in self.task_map

    def _update_current_task(self):
        """Re-sort tasks, rebuild `task_map`, point `current_task_id` at the first unfinished task
        ("" if every task is finished) and report the task list through `TaskReporter`."""
        self.tasks = self._topological_sort(self.tasks)
        # Update the task map for quick access to tasks by ID
        self.task_map = {task.task_id: task for task in self.tasks}

        current_task_id = ""
        for task in self.tasks:
            if not task.is_finished:
                current_task_id = task.task_id
                break
        self.current_task_id = current_task_id

        TaskReporter().report({"tasks": [i.model_dump() for i in self.tasks], "current_task_id": current_task_id})

    @property
    def current_task(self) -> Optional[Task]:
        """Find current task to execute

        Returns:
            Task: the current task to be executed, or None if there is none
        """
        return self.task_map.get(self.current_task_id, None)

    def finish_current_task(self):
        """Finish current task, set Task.is_finished=True, set current task to next task"""
        if self.current_task_id:
            self.current_task.is_finished = True
            self._update_current_task()  # set to next task

    def is_plan_finished(self) -> bool:
        """Check if all tasks are finished"""
        return all(task.is_finished for task in self.tasks)

    def get_finished_tasks(self) -> list[Task]:
        """return all finished tasks in correct linearized order

        Returns:
            list[Task]: list of finished tasks
        """
        return [task for task in self.tasks if task.is_finished]

    def append_task(self, task_id: str, dependent_task_ids: list[str], instruction: str, assignee: str):
        """Append a new task with task_id (number) to the end of existing task sequences. If dependent_task_ids is not empty, the task will depend on the tasks with the ids in the list."""
        new_task = Task(
            task_id=task_id, dependent_task_ids=dependent_task_ids, instruction=instruction, assignee=assignee
        )
        return self._append_task(new_task)

    def replace_task(self, task_id: str, new_dependent_task_ids: list[str], new_instruction: str, new_assignee: str):
        """Replace an existing task (can be current task) based on task_id, and reset all tasks depending on it."""
        new_task = Task(
            task_id=task_id,
            dependent_task_ids=new_dependent_task_ids,
            instruction=new_instruction,
            assignee=new_assignee,
        )
        return self._replace_task(new_task)
2023-11-23 21:59:25 +08:00
2023-12-19 14:22:52 +08:00
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue.

        :return: The oldest queued item, or None if the queue is empty.
        """
        try:
            item = self._queue.get_nowait()
        except QueueEmpty:
            return None
        # Balance every successful get with task_done (even for a falsy item) so `Queue.join()` cannot hang.
        self._queue.task_done()
        return item

    def pop_all(self) -> List[Message]:
        """Pop all messages from the queue."""
        ret = []
        while True:
            msg = self.pop()
            if not msg:
                break
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self):
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a json string.

        The queue is drained to read its messages and then refilled in the same order,
        so its content is unchanged afterwards.
        """
        if self.empty():
            return "[]"

        lst = []
        msgs = []
        try:
            while True:
                # Non-blocking drain: awaiting `get()` with a timeout would stall for the
                # full timeout on every call once the queue is empty.
                item = self._queue.get_nowait()
                if item is None:
                    break
                msgs.append(item)
                lst.append(item.dump())
                self._queue.task_done()
        except QueueEmpty:
            logger.debug("Queue is empty, exiting...")
        finally:
            # Put the drained messages back so dumping is non-destructive.
            for m in msgs:
                self._queue.put_nowait(m)
        return json.dumps(lst, ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the json string to the `MessageQueue` object.

        Entries that `Message.load` cannot parse are skipped.
        """
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                if msg is None:
                    # A None entry would make `pop_all` stop early and hide later messages.
                    continue
                queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")

        return queue
2023-11-23 17:49:38 +08:00
2023-12-19 16:16:52 +08:00
# 定义一个泛型类型变量
T = TypeVar ( " T " , bound = " BaseModel " )
2023-12-19 23:53:04 +08:00
class BaseContext(BaseModel, ABC):
    """Abstract base for context models that can be rebuilt from a JSON string."""

    @classmethod
    @handle_exception
    def loads(cls: Type[T], val: str) -> Optional[T]:
        """Construct an instance of the calling class from the JSON object text `val`.

        Failures are handled by `handle_exception`; presumably it returns None in that case
        (hence `Optional[T]`) — confirm against `metagpt.utils.exceptions`.
        """
        return cls(**json.loads(val))
2023-11-23 17:49:38 +08:00
2023-12-19 16:16:52 +08:00
class CodingContext(BaseContext):
    """Context for coding a single file: the target filename plus optional related documents.

    Field meanings below are inferred from their names — confirm against the callers.
    """

    filename: str  # required; presumably the code file being written
    design_doc: Optional[Document] = None  # presumably the system design document
    task_doc: Optional[Document] = None  # presumably the task document
    code_doc: Optional[Document] = None  # presumably the current code document for `filename`
    code_plan_and_change_doc: Optional[Document] = None  # presumably the code plan-and-change document
2023-11-23 22:41:44 +08:00
2023-12-19 16:16:52 +08:00
class TestingContext(BaseContext):
    """Context for a testing step: a ``filename``, the code document under test, and an optional test document."""

    filename: str
    code_doc: Document  # required
    test_doc: Optional[Document] = Field(default=None)
2023-11-23 22:41:44 +08:00
2023-12-19 16:16:52 +08:00
class RunCodeContext(BaseContext):
    """Inputs for a run-code step (code/test sources, command, working directory, extra Python paths) plus output fields."""

    mode: str = Field(default="script")
    # Source under execution and its file name.
    code: Optional[str] = Field(default=None)
    code_filename: str = Field(default="")
    # Test source and its file name.
    test_code: Optional[str] = Field(default=None)
    test_filename: str = Field(default="")
    # Command-line tokens and execution environment; lists get a fresh instance per model.
    command: List[str] = Field(default_factory=list)
    working_directory: str = Field(default="")
    additional_python_paths: List[str] = Field(default_factory=list)
    # Output destination and captured output; None until set.
    output_filename: Optional[str] = Field(default=None)
    output: Optional[str] = Field(default=None)
2023-11-24 13:30:00 +08:00
2023-11-24 19:56:27 +08:00
2023-12-19 16:16:52 +08:00
class RunCodeResult(BaseContext):
    """Result of a run-code step; all three text fields are required."""

    summary: str
    stdout: str  # captured standard output
    stderr: str  # captured standard error
2023-11-28 18:16:50 +08:00
class CodeSummarizeContext(BaseModel):
    """Inputs for a code-summary step: the design and task document paths plus related code files."""

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context by classifying each path in ``filenames``.

        A path under ``SYSTEM_DESIGN_FILE_REPO`` becomes ``design_filename``; a path under
        ``TASK_FILE_REPO`` becomes ``task_filename``; any other path is ignored. When several
        paths fall under the same repo, the last one wins. ``codes_filenames`` and ``reason``
        are left at their defaults.
        """
        ctx = CodeSummarizeContext()
        for name in filenames:
            path = Path(name)
            if path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(name)
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(name)
        return ctx

    def __hash__(self):
        """Hash on the (design, task) filename pair only, so contexts for the same documents collide."""
        return hash((self.design_filename, self.task_filename))
2023-12-12 21:32:03 +08:00
2024-01-23 19:11:58 +08:00
class CodePlanAndChangeContext(BaseModel):
    """Inputs for a code plan-and-change step; every field is a string defaulting to empty."""

    requirement: str = Field(default="")
    issue: str = Field(default="")
    # Paths of the PRD, system-design and task documents.
    prd_filename: str = Field(default="")
    design_filename: str = Field(default="")
    task_filename: str = Field(default="")
2024-01-19 19:53:17 +08:00
2024-01-02 23:09:09 +08:00
# mermaid class view
class UMLClassMeta(BaseModel):
    """Common fields for elements of a Mermaid class diagram: a name and a visibility marker."""

    name: str = ""
    # Mermaid visibility prefix ("+", "-", "#", ...); defaults to empty, i.e. no marker.
    visibility: str = ""

    @staticmethod
    def name_to_visibility(name: str) -> str:
        """Map a Python identifier to a Mermaid visibility marker using Python naming conventions.

        * special "dunder" names (``__init__``, ``__str__``, ...) -> ``"+"``: they are public and
          Python does not name-mangle identifiers that end in two underscores;
        * other names starting with ``__`` (name-mangled) -> ``"-"`` (private);
        * names starting with a single ``_`` -> ``"#"`` (protected);
        * everything else -> ``"+"`` (public).
        """
        if name.startswith("__") and name.endswith("__"):
            return "+"
        if name.startswith("__"):
            return "-"
        if name.startswith("_"):
            return "#"
        return "+"
2024-01-02 23:09:09 +08:00
2024-01-22 22:49:46 +08:00
class UMLClassAttribute(UMLClassMeta):
    """An attribute row (or a method argument) in a Mermaid class diagram."""

    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render as ``<tabs><visibility><type> <name>[=<default>]``.

        Args:
            align: Number of leading tab characters.

        Returns:
            The rendered line without a trailing newline. Spaces are removed from the type;
            for string-like types the default is wrapped in double quotes after stripping any
            embedded double quotes.
        """
        parts = ["\t" * align, self.visibility]
        if self.value_type:
            parts.append(self.value_type.replace(" ", "") + " ")
        # A "prefix:" qualifier may precede the name; only the text after the first colon is shown.
        parts.append(self.name.split(":", 1)[-1])
        if self.default_value:
            if self.value_type in ("str", "string", "String"):
                shown_default = '"' + self.default_value.replace('"', "") + '"'
            else:
                shown_default = self.default_value
            parts.append("=" + shown_default)
        return "".join(parts)
2024-01-22 22:49:46 +08:00
class UMLClassMethod(UMLClassMeta):
    """A method row in a Mermaid class diagram: arguments plus an optional return type."""

    args: List[UMLClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render as ``<tabs><visibility><name>(<arg>,<arg>...)[ <return type>]``.

        Args:
            align: Number of leading tab characters; arguments are rendered with no indentation.

        Returns:
            The rendered line without a trailing newline; spaces are removed from the return type.
        """
        indent = "\t" * align
        # A "prefix:" qualifier may precede the name; only the text after the first colon is shown.
        shown_name = self.name.split(":", 1)[-1]
        arg_list = ",".join(arg.get_mermaid(align=0) for arg in self.args)
        line = indent + self.visibility + shown_name + "(" + arg_list + ")"
        if self.return_type:
            line += " " + self.return_type.replace(" ", "")
        return line
2024-01-22 22:49:46 +08:00
class UMLClassView(UMLClassMeta):
    """A class node in a Mermaid class diagram, holding its attribute and method rows."""

    attributes: List[UMLClassAttribute] = Field(default_factory=list)
    methods: List[UMLClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Render this class as a Mermaid ``class <name>{ ... }`` block.

        Args:
            align: Tab depth of the ``class`` and closing-brace lines; member rows go one tab deeper.

        Returns:
            The block, with attributes listed before methods and a trailing newline after ``}``.
        """
        indent = "\t" * align
        members = [*self.attributes, *self.methods]
        body = "".join(member.get_mermaid(align=align + 1) + "\n" for member in members)
        return indent + "class " + self.name + "{\n" + body + indent + "}\n"

    @classmethod
    def load_dot_class_info(cls, dot_class_info: DotClassInfo) -> UMLClassView:
        """Build a view from a ``DotClassInfo``.

        Visibility of the class, its attributes and its methods is derived from their names via
        ``name_to_visibility``. Method arguments are created without a visibility marker.
        """
        class_view = cls(
            name=dot_class_info.name,
            visibility=UMLClassView.name_to_visibility(dot_class_info.name),
        )
        for attr_info in dot_class_info.attributes.values():
            class_view.attributes.append(
                UMLClassAttribute(
                    name=attr_info.name,
                    visibility=UMLClassAttribute.name_to_visibility(attr_info.name),
                    value_type=attr_info.type_,
                    default_value=attr_info.default_,
                )
            )
        for method_info in dot_class_info.methods.values():
            method = UMLClassMethod(
                name=method_info.name,
                visibility=UMLClassMethod.name_to_visibility(method_info.name),
                return_type=method_info.return_args.type_,
            )
            method.args.extend(
                UMLClassAttribute(name=arg.name, value_type=arg.type_, default_value=arg.default_)
                for arg in method_info.args
            )
            class_view.methods.append(method)
        return class_view
2024-03-31 10:53:08 +08:00
class BaseEnum(Enum):
    """Enum base whose members carry an optional ``desc`` attribute next to their value.

    Enum unpacks a tuple member definition into constructor arguments, so a member may be
    declared as ``NAME = value`` (``desc`` is None) or ``NAME = (value, desc)``.
    """

    def __new__(cls, value, desc=None):
        """Create one member, honouring a ``str`` or ``int`` mixin if the subclass uses one.

        Args:
            cls: The concrete enum class being populated.
            value: The member's value (also the lookup key for ``cls(value)``).
            desc: Optional description stored on the member; defaults to None.
        """
        # str is checked before int, mirroring the data types this base supports.
        for data_type in (str, int):
            if issubclass(cls, data_type):
                member = data_type.__new__(cls, value)
                break
        else:
            member = object.__new__(cls)
        member._value_ = value
        member.desc = desc
        return member