mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 12:22:39 +02:00
Merge branch 'main' into dev_updated
This commit is contained in:
commit
853086924a
429 changed files with 24237 additions and 5835 deletions
|
|
@ -4,71 +4,309 @@
|
|||
@Time : 2023/5/8 22:12
|
||||
@Author : alexanderwu
|
||||
@File : schema.py
|
||||
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
|
||||
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
|
||||
@Modified By: mashenquan, 2023/11/22.
|
||||
1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
|
||||
2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing
|
||||
between actions.
|
||||
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type, TypedDict
|
||||
import asyncio
|
||||
import json
|
||||
import os.path
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
MESSAGE_ROUTE_CAUSE_BY,
|
||||
MESSAGE_ROUTE_FROM,
|
||||
MESSAGE_ROUTE_TO,
|
||||
MESSAGE_ROUTE_TO_ALL,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
actionoutput_mapping_to_str,
|
||||
actionoutput_str_to_mapping,
|
||||
)
|
||||
|
||||
|
||||
class RawMessage(TypedDict):
|
||||
class SerializationMixin(BaseModel):
    """
    Polymorphic subclasses Serialization / Deserialization Mixin
    - First of all, we need to know that pydantic is not designed for polymorphism.
    - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need
        to add `class name` to Engineer. So we need Engineer inherit SerializationMixin.

    More details:
    - https://docs.pydantic.dev/latest/concepts/serialization/
    - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__`
    """

    # Set per subclass by `__init_subclass__`. The double-underscore name is
    # mangled by Python, so it is only reachable from code inside this class body.
    __is_polymorphic_base = False
    # Registry shared by every subclass: "<module>.<qualname>" -> class object.
    __subclasses_map__ = {}

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        """Wrap the default core schema with the polymorphic (de)serialization hooks.

        The schema produced by `handler` is kept as the inner schema under a
        renamed ref (`<ref>:mixin`); the wrapper takes over the original ref,
        runs `__deserialize_with_real_type__` before validation and
        `__serialize_add_class_type__` around serialization.
        """
        schema = handler(source)
        og_schema_ref = schema["ref"]
        schema["ref"] += ":mixin"

        return core_schema.no_info_before_validator_function(
            cls.__deserialize_with_real_type__,
            schema=schema,
            ref=og_schema_ref,
            serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
        )

    @classmethod
    def __serialize_add_class_type__(
        cls,
        value,
        handler: core_schema.SerializerFunctionWrapHandler,
    ) -> Any:
        """Serialize `value` with `handler` and tag leaf classes with their dotted class name.

        :return: the serialized mapping; `__module_class_name` is added only
            when `cls` currently has no subclasses.
        """
        ret = handler(value)
        if not cls.__subclasses__():
            # only subclass add `__module_class_name`
            ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
        return ret

    @classmethod
    def __deserialize_with_real_type__(cls, value: Any):
        """Before-validator that instantiates the concrete class named in the payload.

        :param value: raw input handed to validation.
        :return: `value` unchanged when it is not a dict, when `cls` is not a
            polymorphic base, or when `cls` has subclasses and the payload
            carries no `__module_class_name`; otherwise an instance of the
            registered class.
        :raises ValueError: if the class name is required but missing.
        :raises TypeError: if the named class has not been registered.
        """
        if not isinstance(value, dict):
            return value

        if not cls.__is_polymorphic_base or (cls.__subclasses__() and "__module_class_name" not in value):
            # add right condition to init BaseClass like Action()
            return value

        module_class_name = value.get("__module_class_name", None)
        if module_class_name is None:
            raise ValueError("Missing field: __module_class_name")

        class_type = cls.__subclasses_map__.get(module_class_name, None)
        if class_type is None:
            # The original message lacked the `f` prefix, so the placeholder was printed literally.
            raise TypeError(f"Trying to instantiate {module_class_name} which not defined yet.")

        return class_type(**value)

    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
        """Record the polymorphic-base flag and register the subclass by dotted name."""
        cls.__is_polymorphic_base = is_polymorphic_base
        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
        super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
class SimpleMessage(BaseModel):
    """Minimal message carrying only a text body and the role that produced it."""

    # Text body of the message.
    content: str
    # Name of the role associated with the message.
    role: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
class Document(BaseModel):
    """
    Represents a document.
    """

    # Directory of the document; joined with `filename` to build paths.
    root_path: str = ""
    # File name of the document.
    filename: str = ""
    # Text of the document; also what `str()` / `repr()` return.
    content: str = ""

    def get_meta(self) -> Document:
        """Return a content-less copy of this document.

        :return: A new Document with the same root path and filename and the default (empty) content.
        """
        return Document(root_path=self.root_path, filename=self.filename)

    @property
    def root_relative_path(self):
        """Path of the document formed by joining `root_path` and `filename`.

        :return: relative path from root of git repository.
        """
        return os.path.join(self.root_path, self.filename)

    @property
    def full_path(self):
        """Path of the document under the configured git working directory.

        :return: the path as a string, or None when `CONFIG.git_repo` is not set.
        """
        if CONFIG.git_repo:
            return str(CONFIG.git_repo.workdir / self.root_path / self.filename)
        return None

    def __str__(self):
        """Return the document content."""
        return self.content

    def __repr__(self):
        """Return the document content (same text as `__str__`)."""
        return self.content
|
||||
|
||||
|
||||
class Documents(BaseModel):
    """A collection of documents.

    Attributes:
        docs (Dict[str, Document]): Document instances keyed by name; starts as an empty dict.
    """

    docs: Dict[str, Document] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""list[<role>: <content>]"""
|
||||
|
||||
id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135
|
||||
content: str
|
||||
instruct_content: BaseModel = field(default=None)
|
||||
role: str = field(default='user') # system / user / assistant
|
||||
cause_by: Type["Action"] = field(default="")
|
||||
sent_from: str = field(default="")
|
||||
send_to: str = field(default="")
|
||||
restricted_to: str = field(default="")
|
||||
state: str = None # None, done, todo, doing, error
|
||||
instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
|
||||
role: str = "user" # system / user / assistant
|
||||
cause_by: str = Field(default="", validate_default=True)
|
||||
sent_from: str = Field(default="", validate_default=True)
|
||||
send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
|
||||
|
||||
@field_validator("id", mode="before")
@classmethod
def check_id(cls, id: str) -> str:
    """Keep a non-empty id as-is; otherwise generate a fresh uuid4 hex string."""
    return id or uuid.uuid4().hex
|
||||
|
||||
@field_validator("instruct_content", mode="before")
@classmethod
def check_instruct_content(cls, ic: Any) -> BaseModel:
    """Rebuild `instruct_content` from its serialized form when needed.

    Falsy values, BaseModel instances, and inputs without a "class" key are
    returned untouched. Otherwise the input is expected to provide "class",
    "mapping" and "value" entries, from which a model class is created and
    instantiated.
    """
    if not ic or isinstance(ic, BaseModel) or "class" not in ic:
        return ic

    # compatible with custom-defined ActionOutput
    field_mapping = actionoutput_str_to_mapping(ic["mapping"])
    # avoid circular import
    action_node_cls = import_class("ActionNode", "metagpt.actions.action_node")
    model_cls = action_node_cls.create_model_class(class_name=ic["class"], mapping=field_mapping)
    return model_cls(**ic["value"])
|
||||
|
||||
@field_validator("cause_by", mode="before")
@classmethod
def check_cause_by(cls, cause_by: Any) -> str:
    """Convert `cause_by` to its string form, defaulting to `UserRequirement` when it is falsy."""
    if not cause_by:
        # Resolved by name at call time rather than through a module-level import.
        cause_by = import_class("UserRequirement", "metagpt.actions.add_requirement")
    return any_to_str(cause_by)
|
||||
|
||||
@field_validator("sent_from", mode="before")
@classmethod
def check_sent_from(cls, sent_from: Any) -> str:
    """Convert `sent_from` to its string form; falsy input is treated as an empty string."""
    return any_to_str(sent_from or "")
|
||||
|
||||
@field_validator("send_to", mode="before")
@classmethod
def check_send_to(cls, send_to: Any) -> set:
    """Convert `send_to` to a set of strings; falsy input becomes `{MESSAGE_ROUTE_TO_ALL}`."""
    return any_to_str_set(send_to or {MESSAGE_ROUTE_TO_ALL})
|
||||
|
||||
@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]:
    """Serialize `instruct_content` into a dict with "class", "mapping" and "value" entries.

    Returns None when `ic` is falsy or when its JSON schema contains a
    "definitions" key.
    """
    if not ic:
        return None

    # compatible with custom-defined ActionOutput
    schema = ic.model_json_schema()
    # `Documents` contain definitions
    # NOTE(review): pydantic v2 publishes nested models under "$defs"; confirm
    # that "definitions" is still the key this check should look for.
    if "definitions" in schema:
        # TODO refine with nested BaseModel
        return None

    mapping = actionoutput_mapping_to_str(actionoutout_schema_to_mapping(schema))
    return {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
|
||||
|
||||
def __init__(self, content: str = "", **data: Any):
    """Create a message; `content` may be given positionally or as a keyword.

    A "content" entry already present in `data` takes precedence over the
    positional default.
    """
    data.setdefault("content", content)
    super().__init__(**data)
|
||||
|
||||
def __setattr__(self, key, val):
    """Override `@property.setter`, convert non-string parameters into string parameters."""
    if key in (MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM):
        # These two routing attributes are stored as plain strings.
        val = any_to_str(val)
    elif key == MESSAGE_ROUTE_TO:
        # The recipient attribute is stored as a set of strings.
        val = any_to_str_set(val)
    super().__setattr__(key, val)
|
||||
|
||||
def __str__(self):
    """Render as "<role>: <body>", where the body is the dumped `instruct_content` when set, else `content`."""
    body = self.instruct_content.model_dump() if self.instruct_content else self.content
    return f"{self.role}: {body}"
|
||||
|
||||
def __repr__(self):
    """Use the same text as `__str__`."""
    return str(self)
|
||||
|
||||
def to_dict(self) -> dict:
    """Return a dict containing `role` and `content` for the LLM call."""
    # The body used to be duplicated: a first `return` made the docstring and
    # the second `return` unreachable. Only one copy is kept.
    return {"role": self.role, "content": self.content}
|
||||
|
||||
def dump(self) -> str:
    """Convert the object to json string"""
    # None-valued fields are left out and serializer warnings are silenced.
    payload = self.model_dump_json(exclude_none=True, warnings=False)
    return payload
|
||||
|
||||
@staticmethod
@handle_exception(exception_type=JSONDecodeError, default_return=None)
def load(val):
    """Convert the json string to object.

    :param val: JSON text of a message.
    :return: the decoded `Message`, or None when the text is not valid JSON.
    """
    try:
        fields = json.loads(val)
        # The id is taken out before construction and assigned afterwards,
        # and only when it is non-empty.
        saved_id = fields.pop("id", None)
        message = Message(**fields)
        if saved_id:
            message.id = saved_id
        return message
    except JSONDecodeError as err:
        logger.error(f"parse json failed: {val}, error:{err}")
    return None
|
||||
|
||||
|
||||
class UserMessage(Message):
    """Facilitate support for OpenAI messages.

    A `Message` whose role is fixed to "user".
    """

    # The stale `@dataclass` decorator was dropped: it would generate its own
    # `__init__` in place of the one below. The stale positional call
    # `super().__init__(content, 'user')` was dropped as well, because
    # `Message.__init__` accepts only `content` positionally.
    def __init__(self, content: str):
        super().__init__(content=content, role="user")
|
||||
|
||||
|
||||
class SystemMessage(Message):
    """Facilitate support for OpenAI messages.

    A `Message` whose role is fixed to "system".
    """

    # The stale `@dataclass` decorator was dropped: it would generate its own
    # `__init__` in place of the one below. The stale positional call
    # `super().__init__(content, 'system')` was dropped as well, because
    # `Message.__init__` accepts only `content` positionally.
    def __init__(self, content: str):
        super().__init__(content=content, role="system")
|
||||
|
||||
|
||||
class AIMessage(Message):
    """Facilitate support for OpenAI messages.

    A `Message` whose role is fixed to "assistant".
    """

    # The stale `@dataclass` decorator was dropped: it would generate its own
    # `__init__` in place of the one below. The role is now passed by keyword,
    # because `Message.__init__` accepts only `content` positionally.
    def __init__(self, content: str):
        super().__init__(content=content, role="assistant")
|
||||
|
||||
|
|
@ -266,12 +504,203 @@ class Plan(BaseModel):
|
|||
return [task for task in self.tasks if task.is_finished]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke check: build one message of each kind and log the list.
    sample_text = "test_message"
    demo_messages = [
        UserMessage(sample_text),
        SystemMessage(sample_text),
        AIMessage(sample_text),
        Message(sample_text, role="QA"),
    ]
    logger.info(demo_messages)
|
||||
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Underlying asyncio queue; private, so it is not a model field.
    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue.

        :return: the next item, or None when the queue is empty.
        """
        try:
            item = self._queue.get_nowait()
            if item:
                self._queue.task_done()
            return item
        except QueueEmpty:
            return None

    def pop_all(self) -> List[Message]:
        """Pop all messages from the queue.

        Draining stops at the first falsy item returned by `pop`.
        """
        ret = []
        while True:
            msg = self.pop()
            if not msg:
                break
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self) -> bool:
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a json string.

        The queue is drained to serialize each message and then refilled, so
        its content is the same afterwards. The drain loop ends when a
        one-second wait for the next item times out.

        :return: a JSON array whose elements are the per-message JSON strings.
        """
        if self.empty():
            return "[]"

        lst = []
        msgs = []
        try:
            while True:
                item = await wait_for(self._queue.get(), timeout=1.0)
                if item is None:
                    break
                msgs.append(item)
                lst.append(item.dump())
                self._queue.task_done()
        except asyncio.TimeoutError:
            logger.debug("Queue is empty, exiting...")
        finally:
            # Put back everything that was taken out, in the original order.
            for m in msgs:
                self._queue.put_nowait(m)
        return json.dumps(lst, ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the json string to the `MessageQueue` object.

        Entries that `Message.load` cannot decode are skipped. Previously the
        resulting None was pushed into the queue, where it made `pop_all` stop
        early and hide every message queued after it.

        :param data: JSON array of per-message JSON strings, as produced by `dump`.
        :return: a new queue; empty when `data` is not valid JSON.
        """
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                if msg is None:
                    logger.warning(f"Skip message that failed to load: {i}")
                    continue
                queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")

        return queue
|
||||
|
||||
|
||||
# Define a generic type variable
|
||||
T = TypeVar("T", bound="BaseModel")
|
||||
|
||||
|
||||
class BaseContext(BaseModel, ABC):
    """Base class for context models that can be rebuilt from a JSON string."""

    @classmethod
    @handle_exception
    def loads(cls: Type[T], val: str) -> Optional[T]:
        """Parse `val` as JSON and pass the decoded mapping to the class constructor as keyword arguments.

        NOTE(review): the Optional return presumably comes from `handle_exception`
        substituting None on failure -- confirm against that decorator.
        """
        return cls(**json.loads(val))
|
||||
|
||||
|
||||
class CodingContext(BaseContext):
    """Context with a required filename and optional design, task and code documents."""

    filename: str
    # Each document is optional and defaults to None.
    design_doc: Optional[Document] = None
    task_doc: Optional[Document] = None
    code_doc: Optional[Document] = None
|
||||
|
||||
|
||||
class TestingContext(BaseContext):
    """Context with a required filename and code document, plus an optional test document."""

    filename: str
    code_doc: Document
    # Defaults to None.
    test_doc: Optional[Document] = None
|
||||
|
||||
|
||||
class RunCodeContext(BaseContext):
    """Parameters and captured output of a code run; every field has a default."""

    mode: str = "script"
    # Code and test code as text, with their file names.
    code: Optional[str] = None
    code_filename: str = ""
    test_code: Optional[str] = None
    test_filename: str = ""
    # Command as a list of strings; a fresh list per instance.
    command: List[str] = Field(default_factory=list)
    working_directory: str = ""
    additional_python_paths: List[str] = Field(default_factory=list)
    # Output file name and output text; both default to None.
    output_filename: Optional[str] = None
    output: Optional[str] = None
|
||||
|
||||
|
||||
class RunCodeResult(BaseContext):
    """Result of a code run: a summary plus stdout and stderr text, all required."""

    summary: str
    stdout: str
    stderr: str
|
||||
|
||||
|
||||
class CodeSummarizeContext(BaseModel):
    """File names used by code summarization, hashable by design and task file names."""

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context from a list of file names.

        A name located under `SYSTEM_DESIGN_FILE_REPO` becomes `design_filename`;
        otherwise a name under `TASK_FILE_REPO` becomes `task_filename`. When
        several names match the same repository the last one wins. Other
        fields keep their defaults.
        """
        ctx = CodeSummarizeContext()
        for name in filenames:
            candidate = Path(name)
            if candidate.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(name)
            elif candidate.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(name)
        return ctx

    def __hash__(self):
        """Hash on the (design_filename, task_filename) pair only."""
        key = (self.design_filename, self.task_filename)
        return hash(key)
|
||||
|
||||
|
||||
class BugFixContext(BaseContext):
    """Context holding a single file name, empty by default."""

    filename: str = ""
|
||||
|
||||
|
||||
# mermaid class view
|
||||
class ClassMeta(BaseModel):
    """Fields shared by the mermaid class-view elements."""

    name: str = ""
    # Subclasses append "*" to their mermaid text when this is set.
    abstraction: bool = False
    # Subclasses append "$" to their mermaid text when this is set.
    static: bool = False
    # Text that attribute and method subclasses emit in front of the element.
    visibility: str = ""
|
||||
|
||||
|
||||
class ClassAttribute(ClassMeta):
    """An attribute (or method argument) of a class in the mermaid class view."""

    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render the attribute as mermaid text.

        The pieces are: `align` tab characters, visibility, "<type> " when a
        type is set, the name, "=<default>" when a default is set, then "*"
        if abstract and "$" if static.

        :param align: number of leading tab characters.
        """
        pieces = ["\t" * align, self.visibility]
        if self.value_type:
            pieces.append(f"{self.value_type} ")
        pieces.append(self.name)
        if self.default_value:
            if self.value_type in ("str", "string", "String"):
                # String defaults are wrapped in double quotes after removing any embedded ones.
                rendered_default = '"' + self.default_value.replace('"', "") + '"'
            else:
                rendered_default = self.default_value
            pieces.append("=" + rendered_default)
        if self.abstraction:
            pieces.append("*")
        if self.static:
            pieces.append("$")
        return "".join(pieces)
|
||||
|
||||
|
||||
class ClassMethod(ClassMeta):
    """A method of a class in the mermaid class view."""

    args: List[ClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render the method as mermaid text.

        The pieces are: `align` tab characters, visibility, the name, the
        comma-separated arguments in parentheses, ":<return type>" when one is
        set, then "*" if abstract and "$" if static.

        :param align: number of leading tab characters.
        """
        # Arguments are rendered without leading tabs.
        rendered_args = ",".join(arg.get_mermaid(align=0) for arg in self.args)
        pieces = ["\t" * align, self.visibility, f"{self.name}({rendered_args})"]
        if self.return_type:
            pieces.append(":" + self.return_type)
        if self.abstraction:
            pieces.append("*")
        if self.static:
            pieces.append("$")
        return "".join(pieces)
|
||||
|
||||
|
||||
class ClassView(ClassMeta):
    """A class in the mermaid class view, with its attributes and methods."""

    attributes: List[ClassAttribute] = Field(default_factory=list)
    methods: List[ClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Render the class as a mermaid `class Name{ ... }` block.

        Attributes come first, then methods, each on its own line with one
        more tab than the class header. The text ends with a newline.

        :param align: number of leading tab characters for the header and the closing brace.
        """
        indent = "\t" * align
        lines = [f"{indent}class {self.name}{{"]
        lines.extend(attr.get_mermaid(align=align + 1) for attr in self.attributes)
        lines.extend(method.get_mermaid(align=align + 1) for method in self.methods)
        lines.append(f"{indent}}}")
        return "\n".join(lines) + "\n"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue