Merge branch 'dev' into code_intepreter

This commit is contained in:
yzlin 2024-01-31 00:08:09 +08:00
commit 2fcb2a1cfe
282 changed files with 6993 additions and 3210 deletions

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -32,15 +32,17 @@ from pydantic import (
PrivateAttr,
field_serializer,
field_validator,
model_serializer,
model_validator,
)
from pydantic_core import core_schema
from metagpt.config import CONFIG
from metagpt.const import (
CODE_PLAN_AND_CHANGE_FILENAME,
MESSAGE_ROUTE_CAUSE_BY,
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
MESSAGE_ROUTE_TO_ALL,
PRDS_FILE_REPO,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
@ -54,7 +56,7 @@ from metagpt.utils.serialize import (
)
class SerializationMixin(BaseModel):
class SerializationMixin(BaseModel, extra="forbid"):
"""
PolyMorphic subclasses Serialization / Deserialization Mixin
- First of all, we need to know that pydantic is not designed for polymorphism.
@ -69,49 +71,44 @@ class SerializationMixin(BaseModel):
__is_polymorphic_base = False
__subclasses_map__ = {}
@classmethod
def __get_pydantic_core_schema__(
    cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
    """Wrap the handler-generated schema with the mixin's validate/serialize hooks.

    The schema returned by ``handler(source)`` gets ``":mixin"`` appended to its
    ``ref``; the wrapper schema returned here takes over the original ref, runs
    ``__deserialize_with_real_type__`` before validation and serializes through
    ``__serialize_add_class_type__``.

    :param source: The class the schema is generated for.
    :param handler: Callable producing the default core schema for ``source``.
    :return: The wrapping before-validator core schema.
    """
    inner_schema = handler(source)
    original_ref = inner_schema["ref"]
    # Rename the inner schema's ref so the wrapper below can carry the original one.
    inner_schema["ref"] = original_ref + ":mixin"
    return core_schema.no_info_before_validator_function(
        cls.__deserialize_with_real_type__,
        schema=inner_schema,
        ref=original_ref,
        serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
    )
@classmethod
def __serialize_add_class_type__(
    cls,
    value,
    handler: core_schema.SerializerFunctionWrapHandler,
) -> Any:
    """Serialize ``value`` through ``handler`` and tag leaf classes with their type.

    :param value: The object being serialized.
    :param handler: Wrap-serializer handler that produces the default output.
    :return: The handler's output; when ``cls`` has no subclasses it additionally
        carries a ``__module_class_name`` entry of the form ``<module>.<qualname>``.
    """
    ret = handler(value)
    if not cls.__subclasses__():
        # only classes without subclasses add `__module_class_name`
        ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
    # Hand the serialized payload back; without this the serializer yields None.
    return ret
@model_serializer(mode="wrap")
def __serialize_with_class_type__(self, default_serializer) -> Any:
    """Run the default serializer, then append the ``__module_class_name`` field.

    :param default_serializer: The wrapped default serializer; called with ``self``.
    :return: The default serializer's output with an added ``__module_class_name``
        entry of the form ``<module>.<qualname>`` for the instance's class.
    """
    payload = default_serializer(self)
    model_cls = self.__class__
    payload["__module_class_name"] = f"{model_cls.__module__}.{model_cls.__qualname__}"
    return payload
@classmethod
def __deserialize_with_real_type__(cls, value: Any):
    """Before-validator referenced by ``__get_pydantic_core_schema__``.

    Non-dict input is returned untouched. A dict is also returned untouched when
    ``cls`` is not the polymorphic base, or when ``cls`` has subclasses and the
    dict carries no ``__module_class_name``. Otherwise the class named by
    ``__module_class_name`` is looked up in ``__subclasses_map__`` and
    instantiated from the dict.

    :param value: Raw input handed to validation.
    :return: ``value`` itself, or an instance of the registered subclass.
    :raises ValueError: If the dict lacks ``__module_class_name``.
    :raises TypeError: If the named class is not in ``__subclasses_map__``.
    """
    if not isinstance(value, dict):
        return value

    if not cls.__is_polymorphic_base or (cls.__subclasses__() and "__module_class_name" not in value):
        # add right condition to init BaseClass like Action()
        return value

    module_class_name = value.get("__module_class_name", None)
    if module_class_name is None:
        raise ValueError("Missing field: __module_class_name")

    class_type = cls.__subclasses_map__.get(module_class_name, None)
    if class_type is None:
        # f-string so the offending class name actually appears in the message
        raise TypeError(f"Trying to instantiate {module_class_name} which not defined yet.")

    return class_type(**value)


@model_validator(mode="wrap")
@classmethod
def __convert_to_real_type__(cls, value: Any, handler):
    """Wrap-validator that builds the concrete class named by ``__module_class_name``.

    Non-dict input goes straight to ``handler``. For a dict, the
    ``__module_class_name`` entry is stripped (on a copy) and then:

    - if ``cls`` is not the polymorphic base and the entry is absent or names
      ``cls`` itself, the default ``handler`` builds the instance;
    - otherwise the named class is looked up in ``__subclasses_map__`` and
      instantiated from the remaining keys.

    :param value: Raw input handed to validation.
    :param handler: The default validation handler for ``cls``.
    :return: The validated instance.
    :raises ValueError: If a lookup is required but ``__module_class_name`` is missing.
    :raises TypeError: If the named class is not in ``__subclasses_map__``.
    """
    if not isinstance(value, dict):
        return handler(value)

    # Work on a shallow copy: popping from the caller's dict would strip the
    # type tag and make a second validation of the same payload fail.
    value = dict(value)

    # remove the __module_class_name because extra keywords are not allowed,
    # while e.g. Cat.model_validate(cat.model_dump()) must still work
    class_full_name = value.pop("__module_class_name", None)

    # if it's not the polymorphic base we construct via default handler
    if not cls.__is_polymorphic_base:
        if class_full_name is None or str(cls) == f"<class '{class_full_name}'>":
            return handler(value)
        # A different class was named although this is not the polymorphic
        # base: fall through to the lookup below.

    # otherwise we lookup the correct polymorphic type and construct that instead
    if class_full_name is None:
        raise ValueError("Missing __module_class_name field")

    class_type = cls.__subclasses_map__.get(class_full_name, None)
    if class_type is None:
        # TODO could try dynamic import
        raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

    return class_type(**value)
@ -151,12 +148,6 @@ class Document(BaseModel):
"""
return os.path.join(self.root_path, self.filename)
@property
def full_path(self):
    """Path of this document below the configured git repository's workdir.

    :return: ``str(CONFIG.git_repo.workdir / self.root_path / self.filename)``,
        or ``None`` when ``CONFIG.git_repo`` is falsy.
    """
    if CONFIG.git_repo:
        return str(CONFIG.git_repo.workdir / self.root_path / self.filename)
    return None
def __str__(self):
    """Return the ``content`` attribute as the string form of this object."""
    text = self.content
    return text
@ -173,6 +164,26 @@ class Documents(BaseModel):
docs: Dict[str, Document] = Field(default_factory=dict)
@classmethod
def from_iterable(cls, documents: Iterable[Document]) -> "Documents":
    """Create a Documents instance from an iterable of Document instances.

    Each document is stored under its ``filename``; if several documents share
    a filename, the last one wins.

    :param documents: Any iterable of Document instances (list, generator, ...).
    :return: An instance of ``cls`` whose ``docs`` maps filename to document.
    """
    docs = {doc.filename: doc for doc in documents}
    # Construct through `cls` so subclasses receive an instance of their own type.
    return cls(docs=docs)
def to_action_output(self) -> "ActionOutput":
    """Wrap this object in an ``ActionOutput``.

    :return: An ``ActionOutput`` whose ``content`` is ``self.model_dump_json()``
        and whose ``instruct_content`` is this object.
    """
    # Imported at call time rather than at module level
    # (presumably to avoid a circular import - confirm).
    from metagpt.actions.action_output import ActionOutput

    serialized = self.model_dump_json()
    return ActionOutput(content=serialized, instruct_content=self)
class Message(BaseModel):
"""list[<role>: <content>]"""
@ -193,12 +204,17 @@ class Message(BaseModel):
@field_validator("instruct_content", mode="before")
@classmethod
def check_instruct_content(cls, ic: Any) -> BaseModel:
    """Rebuild ``instruct_content`` from its serialized dict form.

    A non-empty dict containing a ``"class"`` key is converted back into a model
    instance built from ``ic["value"]``:

    - with a ``"mapping"`` key, the class is recreated through
      ``ActionNode.create_model_class`` (custom-defined ActionOutput);
    - with a ``"module"`` key, the class is imported from that module.

    Any other input is returned unchanged.

    :param ic: The raw ``instruct_content`` value.
    :return: The rebuilt model instance, or ``ic`` itself when no conversion applies.
    :raises KeyError: If the dict has ``"class"`` but neither ``"mapping"`` nor ``"module"``.
    """
    if not (ic and isinstance(ic, dict) and "class" in ic):
        return ic

    if "mapping" in ic:
        # compatible with custom-defined ActionOutput
        mapping = actionoutput_str_to_mapping(ic["mapping"])
        actionnode_class = import_class("ActionNode", "metagpt.actions.action_node")  # avoid circular import
        ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
    elif "module" in ic:
        # subclasses of BaseModel
        ic_obj = import_class(ic["class"], ic["module"])
    else:
        raise KeyError("missing required key to init Message.instruct_content from dict")

    return ic_obj(**ic["value"])
@ -218,18 +234,21 @@ class Message(BaseModel):
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
    """Serialize ``instruct_content`` into a dict that can be rebuilt later.

    :param ic: The model stored in ``instruct_content``; may be falsy.
    :return: ``None`` when ``ic`` is falsy. Otherwise a dict with ``"class"``
        (the schema title) and ``"value"`` (``ic.model_dump()``), plus either
        ``"mapping"`` for classes whose type string starts with
        ``<class 'metagpt.actions.action_node`` or ``"module"`` for every other
        model class.
    """
    if not ic:
        return None

    schema = ic.model_json_schema()
    ic_type = str(type(ic))
    if "<class 'metagpt.actions.action_node" in ic_type:
        # instruct_content from AutoNode.create_model_class, for now, it's single level structure.
        mapping = actionoutout_schema_to_mapping(schema)
        mapping = actionoutput_mapping_to_str(mapping)
        return {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}

    # instruct_content can also be assigned any other subclass of BaseModel
    return {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
def __init__(self, content: str = "", **data: Any):
@ -647,6 +666,30 @@ class BugFixContext(BaseContext):
filename: str = ""
class CodePlanAndChangeContext(BaseModel):
    # Context model holding a requirement string plus the PRD, system-design
    # and task filenames it refers to.
    filename: str = CODE_PLAN_AND_CHANGE_FILENAME
    requirement: str = ""
    prd_filename: str = ""
    design_filename: str = ""
    task_filename: str = ""

    @staticmethod
    def loads(filenames: List, **kwargs) -> CodePlanAndChangeContext:
        """Build a context from a list of file paths.

        Each path is matched, in order, against the PRD, system-design and task
        repositories; the first repository it is relative to receives the path's
        final component. Paths matching none of them are ignored, and a later
        path overrides an earlier one for the same repository.

        :param filenames: Paths (strings or path objects) to classify.
        :param kwargs: Only ``requirement`` is read; it defaults to ``""``.
        :return: The populated context.
        """
        ctx = CodePlanAndChangeContext(requirement=kwargs.get("requirement", ""))
        repo_to_field = (
            (PRDS_FILE_REPO, "prd_filename"),
            (SYSTEM_DESIGN_FILE_REPO, "design_filename"),
            (TASK_FILE_REPO, "task_filename"),
        )
        for raw_name in filenames:
            path = Path(raw_name)
            for repo, field_name in repo_to_field:
                if path.is_relative_to(repo):
                    setattr(ctx, field_name, path.name)
                    break  # first matching repository wins for this path
        return ctx
# mermaid class view
class ClassMeta(BaseModel):
name: str = ""