MetaGPT/metagpt/schema.py

511 lines
16 KiB
Python
Raw Normal View History

2023-06-30 17:10:48 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/8 22:12
@Author : alexanderwu
@File : schema.py
2023-11-03 11:53:47 +08:00
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
2023-11-27 16:15:55 +08:00
@Modified By: mashenquan, 2023/11/22.
1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing
between actions.
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
2023-06-30 17:10:48 +08:00
"""
2023-07-22 11:28:22 +08:00
2023-06-30 17:10:48 +08:00
from __future__ import annotations
2023-07-22 11:28:22 +08:00
import asyncio
2023-10-31 15:23:37 +08:00
import json
2023-11-22 17:08:00 +08:00
import os.path
import uuid
from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
2023-10-31 15:23:37 +08:00
from json import JSONDecodeError
2023-11-28 18:16:50 +08:00
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_serializer,
field_validator,
model_serializer,
model_validator,
)
2023-06-30 17:10:48 +08:00
from metagpt.const import (
MESSAGE_ROUTE_CAUSE_BY,
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
MESSAGE_ROUTE_TO_ALL,
2023-11-28 18:16:50 +08:00
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
2023-07-22 11:28:22 +08:00
from metagpt.logs import logger
2023-12-20 10:54:49 +08:00
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.serialize import (
actionoutout_schema_to_mapping,
actionoutput_mapping_to_str,
actionoutput_str_to_mapping,
)
2023-06-30 17:10:48 +08:00
class SerializationMixin(BaseModel, extra="forbid"):
    """
    Polymorphic subclass serialization / deserialization mixin.

    - pydantic is not designed for polymorphism: if Engineer is a subclass of Role, it would be serialized as
      Role. To serialize it as Engineer, the concrete class name has to be recorded, so Engineer inherits
      SerializationMixin, which adds a `__module_class_name` entry to every dump.
    - Every subclass registers itself by `module.qualname` in the shared `__subclasses_map__` when it is defined.
      A class declared with `is_polymorphic_base=True` uses that registry to construct the concrete subclass
      named in the serialized data.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__`
    """

    __is_polymorphic_base = False
    __subclasses_map__ = {}

    @model_serializer(mode="wrap")
    def __serialize_with_class_type__(self, default_serializer) -> Any:
        """Serialize with the default serializer, then tag the result with the concrete class name."""
        ret = default_serializer(self)
        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        return ret

    @model_validator(mode="wrap")
    @classmethod
    def __convert_to_real_type__(cls, value: Any, handler):
        """Validate `value`, constructing the concrete subclass recorded in `__module_class_name` if needed.

        :raises ValueError: if a polymorphic base receives a dict without `__module_class_name`.
        :raises TypeError: if the recorded class name is not registered in `__subclasses_map__`.
        """
        if not isinstance(value, dict):
            return handler(value)

        # The tag must be removed before validation because extra keys are forbidden
        # (e.g. so that Cat.model_validate(cat.model_dump()) works). Work on a shallow copy so the
        # caller's dict keeps its tag and can be validated again.
        value = dict(value)
        class_full_name = value.pop("__module_class_name", None)

        # Not the polymorphic base: validate directly when no tag is present or the tag names this class.
        if not cls.__is_polymorphic_base:
            if class_full_name is None or str(cls) == f"<class '{class_full_name}'>":
                return handler(value)
            # A different class was recorded: fall through to the registry lookup below.

        if class_full_name is None:
            raise ValueError("Missing __module_class_name field")

        class_type = cls.__subclasses_map__.get(class_full_name, None)
        if class_type is None:
            # TODO could try dynamic import
            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        """Record whether `cls` is a polymorphic base and register it by `module.qualname`."""
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
class SimpleMessage(BaseModel):
    """A minimal message made of a content string and the role that produced it."""

    content: str  # message text
    role: str  # role label of the sender
2023-11-22 17:08:00 +08:00
class Document(BaseModel):
    """
    A document identified by a root path and a filename, together with its text content.
    """

    root_path: str = ""
    filename: str = ""
    content: str = ""

    def get_meta(self) -> Document:
        """Return a new Document that keeps only the location (root path and filename).

        :return: A fresh Document whose content is left at its default (empty string).
        """
        return Document(filename=self.filename, root_path=self.root_path)

    @property
    def root_relative_path(self):
        """Join `root_path` and `filename` (documented upstream as the path relative to the git repository root).

        :return: The joined path string.
        """
        return os.path.join(self.root_path, self.filename)

    def __str__(self):
        return self.content

    # repr() deliberately shows the raw content as well.
    __repr__ = __str__
2023-11-22 17:08:00 +08:00
class Documents(BaseModel):
    """A collection of documents keyed by name.

    Attributes:
        docs (Dict[str, Document]): Maps each document name to its Document; starts empty.
    """

    docs: Dict[str, Document] = Field(default_factory=dict)
2023-10-31 15:23:37 +08:00
class Message(BaseModel):
    """A message exchanged between roles, rendered as ``<role>: <content>``.

    Routing metadata (`cause_by`, `sent_from`, `send_to`) is normalized to strings (or a set of strings)
    by the field validators on construction, and by `__setattr__` on assignment.
    """

    id: str = Field(default="", validate_default=True)  # According to Section 2.2.3.1.1 of RFC 135
    content: str
    # Optional structured payload; the dict form produced by `ser_instruct_content` is rebuilt on load.
    instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
    role: str = "user"  # system / user / assistant
    cause_by: str = Field(default="", validate_default=True)
    sent_from: str = Field(default="", validate_default=True)
    send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)

    @field_validator("id", mode="before")
    @classmethod
    def check_id(cls, value: str) -> str:
        """Keep a given id; generate a random hex UUID when it is empty."""
        return value if value else uuid.uuid4().hex

    @field_validator("instruct_content", mode="before")
    @classmethod
    def check_instruct_content(cls, ic: Any) -> Any:
        """Rebuild `instruct_content` from the dict produced by `ser_instruct_content`.

        A dict with "class" and "mapping" is rebuilt through `ActionNode.create_model_class`; a dict with
        "class" and "module" is rebuilt by importing that class. Any other value is returned unchanged.

        :raises KeyError: if the dict has "class" but neither "mapping" nor "module".
        """
        if ic and isinstance(ic, dict) and "class" in ic:
            if "mapping" in ic:
                # compatible with custom-defined ActionOutput
                mapping = actionoutput_str_to_mapping(ic["mapping"])
                actionnode_class = import_class("ActionNode", "metagpt.actions.action_node")  # avoid circular import
                ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
            elif "module" in ic:
                # subclasses of BaseModel
                ic_obj = import_class(ic["class"], ic["module"])
            else:
                raise KeyError("missing required key to init Message.instruct_content from dict")
            ic = ic_obj(**ic["value"])
        return ic

    @field_validator("cause_by", mode="before")
    @classmethod
    def check_cause_by(cls, cause_by: Any) -> str:
        """Stringify `cause_by`, defaulting to the `UserRequirement` action when empty."""
        return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))

    @field_validator("sent_from", mode="before")
    @classmethod
    def check_sent_from(cls, sent_from: Any) -> str:
        """Stringify `sent_from` (empty values become "")."""
        return any_to_str(sent_from if sent_from else "")

    @field_validator("send_to", mode="before")
    @classmethod
    def check_send_to(cls, send_to: Any) -> set:
        """Stringify every recipient, defaulting to {MESSAGE_ROUTE_TO_ALL} when empty."""
        return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})

    @field_serializer("instruct_content", mode="plain")
    def ser_instruct_content(self, ic: Optional[BaseModel]) -> Union[dict, None]:
        """Serialize `instruct_content` into a dict that `check_instruct_content` can rebuild.

        The return annotation is used by pydantic as the serialization schema. It must describe the
        dict actually returned (the old `Union[str, None]` caused serialization warnings).
        """
        ic_dict = None
        if ic:
            # compatible with custom-defined ActionOutput
            schema = ic.model_json_schema()
            ic_type = str(type(ic))
            if "<class 'metagpt.actions.action_node" in ic_type:
                # instruct_content from ActionNode.create_model_class, for now, it's single level structure.
                mapping = actionoutout_schema_to_mapping(schema)
                mapping = actionoutput_mapping_to_str(mapping)
                ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
            else:
                # due to instruct_content can be assigned by subclasses of BaseModel
                ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
        return ic_dict

    def __init__(self, content: str = "", **data: Any):
        """Allow `content` to be passed positionally, e.g. ``Message("hello")``."""
        data["content"] = data.get("content", content)
        super().__init__(**data)

    def __setattr__(self, key, val):
        """Override `@property.setter`, convert non-string parameters into string parameters."""
        if key in (MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM):
            new_val = any_to_str(val)
        elif key == MESSAGE_ROUTE_TO:
            new_val = any_to_str_set(val)
        else:
            new_val = val
        super().__setattr__(key, new_val)

    def __str__(self):
        """Render as ``<role>: <payload>``, preferring the structured `instruct_content` when set."""
        if self.instruct_content:
            return f"{self.role}: {self.instruct_content.model_dump()}"
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

    def dump(self) -> str:
        """Convert the object to a json string (None fields are omitted)."""
        # warnings=False is kept for backward compatibility; it suppresses pydantic serialization warnings.
        return self.model_dump_json(exclude_none=True, warnings=False)

    @staticmethod
    @handle_exception(exception_type=JSONDecodeError, default_return=None)
    def load(val):
        """Convert the json string produced by `dump` back to a Message.

        :return: The rebuilt Message, or None if `val` is not valid JSON (the error is logged).
        """
        try:
            m = json.loads(val)
            # Construct without the id, then restore it verbatim afterwards.
            msg_id = m.pop("id", None)
            msg = Message(**m)
            if msg_id:
                msg.id = msg_id
            return msg
        except JSONDecodeError as err:
            logger.error(f"parse json failed: {val}, error:{err}")
            return None
2023-06-30 17:10:48 +08:00
class UserMessage(Message):
    """A Message whose role is fixed to "user".

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        super().__init__(role="user", content=content)
2023-06-30 17:10:48 +08:00
class SystemMessage(Message):
    """A Message whose role is fixed to "system".

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        super().__init__(role="system", content=content)
2023-06-30 17:10:48 +08:00
class AIMessage(Message):
    """A Message whose role is fixed to "assistant".

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        super().__init__(role="assistant", content=content)
2023-12-19 14:22:52 +08:00
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue without blocking.

        :return: The oldest item, or None if the queue is empty.
        """
        try:
            item = self._queue.get_nowait()
            if item:
                self._queue.task_done()
            return item
        except QueueEmpty:
            return None

    def pop_all(self) -> List[Message]:
        """Pop messages until the queue is empty (stops early at a falsy item)."""
        ret = []
        while True:
            msg = self.pop()
            if not msg:
                break
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self):
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a json string without changing its contents.

        The result is a JSON list of per-message JSON strings (`Message.dump`). `None` entries are
        skipped. The method stays `async` for backward compatibility, but it no longer awaits: the
        old version waited out a 1-second `get()` timeout on every call.
        """
        if self.empty():
            return "[]"

        # Snapshot the queue with non-blocking gets; no await point, so nothing can interleave.
        items = []
        while True:
            try:
                item = self._queue.get_nowait()
            except QueueEmpty:
                break
            items.append(item)
            self._queue.task_done()

        # Put every item back in its original order so dumping is non-destructive.
        for item in items:
            self._queue.put_nowait(item)

        return json.dumps([item.dump() for item in items if item is not None], ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the json string produced by `dump` to a `MessageQueue` object.

        Entries that fail to load as a Message are skipped, so a `None` never enters the queue
        (it would otherwise make `pop_all` stop early). Invalid outer JSON yields an empty queue.
        """
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                if msg is not None:
                    queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")
        return queue
# Define a generic type variable bound to pydantic's BaseModel (used to type `BaseContext.loads`).
T = TypeVar("T", bound="BaseModel")
class BaseContext(BaseModel, ABC):
    """Abstract base for context models that can be rebuilt from a JSON object string."""

    @classmethod
    @handle_exception
    def loads(cls: Type[T], val: str) -> Optional[T]:
        """Build an instance of `cls` from a JSON string whose keys are the model's fields.

        Exceptions are handled by `handle_exception` (the Optional return suggests None on failure;
        confirm against that decorator).
        """
        return cls(**json.loads(val))
class CodingContext(BaseContext):
    """Context for one source file: its filename plus optional design, task and code documents."""

    filename: str
    design_doc: Optional[Document] = None  # design document, if available
    task_doc: Optional[Document] = None  # task document, if available
    code_doc: Optional[Document] = None  # code document, if available
class TestingContext(BaseContext):
    """Context pairing a file's code document with an optional test document."""

    filename: str
    code_doc: Document  # required
    test_doc: Optional[Document] = None  # absent until a test document exists
class RunCodeContext(BaseContext):
    """Inputs and outputs of a code run: code and test sources, command, paths and captured output."""

    mode: str = "script"  # run mode; defaults to "script"
    # Source under run, inline and/or by filename.
    code: Optional[str] = None
    code_filename: str = ""
    # Test source, inline and/or by filename.
    test_code: Optional[str] = None
    test_filename: str = ""
    # Command line and environment of the run.
    command: List[str] = Field(default_factory=list)
    working_directory: str = ""
    additional_python_paths: List[str] = Field(default_factory=list)
    # Result location and content.
    output_filename: Optional[str] = None
    output: Optional[str] = None
2023-11-24 13:30:00 +08:00
2023-11-24 19:56:27 +08:00
class RunCodeResult(BaseContext):
    """Result of a code run: a summary plus the captured standard output and error streams."""

    summary: str
    stdout: str  # captured standard output
    stderr: str  # captured standard error
2023-11-28 18:16:50 +08:00
class CodeSummarizeContext(BaseModel):
    """Context for summarizing code: the related design and task files plus the code files involved."""

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context by classifying `filenames` into the design or task slot.

        A path under SYSTEM_DESIGN_FILE_REPO sets `design_filename`; otherwise a path under
        TASK_FILE_REPO sets `task_filename`. Other paths are ignored, and a later match overwrites an
        earlier one.
        """
        ctx = CodeSummarizeContext()
        for name in filenames:
            path = Path(name)
            if path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(name)
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(name)
        return ctx

    def __hash__(self):
        # Identity for hashing is the (design, task) file pair only.
        return hash((self.design_filename, self.task_filename))
class BugFixContext(BaseContext):
    """Context for a bug fix, identified only by a filename (empty by default)."""

    filename: str = ""
2024-01-02 23:09:09 +08:00
# mermaid class view
class ClassMeta(BaseModel):
    """Fields shared by the Mermaid class-diagram elements (class, attribute and method views)."""

    name: str = ""
    abstraction: bool = False  # rendered as a trailing "*" by the attribute/method views
    static: bool = False  # rendered as a trailing "$" by the attribute/method views
    visibility: str = ""  # prefix emitted before the member text
class ClassAttribute(ClassMeta):
    """An attribute line of a Mermaid class diagram."""

    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render as ``<tabs><visibility>[<type> ]<name>[=<default>][*][$]``.

        String-typed defaults are wrapped in double quotes, with any embedded double quotes removed.
        """
        pieces = ["\t" * align, self.visibility]
        if self.value_type:
            pieces.append(f"{self.value_type} ")
        pieces.append(self.name)
        if self.default_value:
            pieces.append("=")
            if self.value_type in ["str", "string", "String"]:
                pieces.append('"' + self.default_value.replace('"', "") + '"')
            else:
                pieces.append(self.default_value)
        if self.abstraction:
            pieces.append("*")
        if self.static:
            pieces.append("$")
        return "".join(pieces)
class ClassMethod(ClassMeta):
    """A method line of a Mermaid class diagram."""

    args: List[ClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render as ``<tabs><visibility><name>(<args>)[:<return_type>][*][$]``; args are comma-joined."""
        indent = "\t" * align
        arg_list = ",".join(arg.get_mermaid(align=0) for arg in self.args)
        rendered = f"{indent}{self.visibility}{self.name}({arg_list})"
        if self.return_type:
            rendered += f":{self.return_type}"
        if self.abstraction:
            rendered += "*"
        if self.static:
            rendered += "$"
        return rendered
class ClassView(ClassMeta):
    """A whole class block of a Mermaid class diagram: attributes first, then methods."""

    attributes: List[ClassAttribute] = Field(default_factory=list)
    methods: List[ClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Render ``class <name>{ ... }`` with each member on its own line, indented one level deeper."""
        indent = "\t" * align
        members = [*self.attributes, *self.methods]
        body = "".join(member.get_mermaid(align=align + 1) + "\n" for member in members)
        return f"{indent}class {self.name}{{\n{body}{indent}}}\n"