#!/usr/bin/env python # -*- coding: utf-8 -*- """ @Time : 2023/5/8 22:12 @Author : alexanderwu @File : schema.py @Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116: Replanned the distribution of responsibilities and functional positioning of `Message` class attributes. @Modified By: mashenquan, 2023/11/22. 1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135. 2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing between actions. 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. """ from __future__ import annotations import asyncio import json import os.path import uuid from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path from typing import Any, Dict, List, Optional, Type, TypeVar, Union from pydantic import ( BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator, model_serializer, model_validator, ) from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set, import_class from metagpt.utils.exceptions import handle_exception from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, actionoutput_mapping_to_str, actionoutput_str_to_mapping, ) class SerializationMixin(BaseModel, extra="forbid"): """ PolyMorphic subclasses Serialization / Deserialization Mixin - First of all, we need to know that pydantic is not designed for polymorphism. - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need to add `class name` to Engineer. So we need Engineer inherit SerializationMixin. 
More details: - https://docs.pydantic.dev/latest/concepts/serialization/ - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__` """ __is_polymorphic_base = False __subclasses_map__ = {} @model_serializer(mode="wrap") def __serialize_with_class_type__(self, default_serializer) -> Any: # default serializer, then append the `__module_class_name` field and return ret = default_serializer(self) ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}" return ret @model_validator(mode="wrap") @classmethod def __convert_to_real_type__(cls, value: Any, handler): if isinstance(value, dict) is False: return handler(value) # it is a dict so make sure to remove the __module_class_name # because we don't allow extra keywords but want to ensure # e.g Cat.model_validate(cat.model_dump()) works class_full_name = value.pop("__module_class_name", None) # if it's not the polymorphic base we construct via default handler if not cls.__is_polymorphic_base: if class_full_name is None: return handler(value) elif str(cls) == f"": return handler(value) else: # f"Trying to instantiate {class_full_name} but this is not the polymorphic base class") pass # otherwise we lookup the correct polymorphic type and construct that # instead if class_full_name is None: raise ValueError("Missing __module_class_name field") class_type = cls.__subclasses_map__.get(class_full_name, None) if class_type is None: # TODO could try dynamic import raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!") return class_type(**value) def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): cls.__is_polymorphic_base = is_polymorphic_base cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls super().__init_subclass__(**kwargs) class SimpleMessage(BaseModel): content: str role: str class Document(BaseModel): """ Represents a document. 
""" root_path: str = "" filename: str = "" content: str = "" def get_meta(self) -> Document: """Get metadata of the document. :return: A new Document instance with the same root path and filename. """ return Document(root_path=self.root_path, filename=self.filename) @property def root_relative_path(self): """Get relative path from root of git repository. :return: relative path from root of git repository. """ return os.path.join(self.root_path, self.filename) def __str__(self): return self.content def __repr__(self): return self.content class Documents(BaseModel): """A class representing a collection of documents. Attributes: docs (Dict[str, Document]): A dictionary mapping document names to Document instances. """ docs: Dict[str, Document] = Field(default_factory=dict) class Message(BaseModel): """list[: ]""" id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135 content: str instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True) role: str = "user" # system / user / assistant cause_by: str = Field(default="", validate_default=True) sent_from: str = Field(default="", validate_default=True) send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) @field_validator("id", mode="before") @classmethod def check_id(cls, id: str) -> str: return id if id else uuid.uuid4().hex @field_validator("instruct_content", mode="before") @classmethod def check_instruct_content(cls, ic: Any) -> BaseModel: if ic and isinstance(ic, dict) and "class" in ic: if "mapping" in ic: # compatible with custom-defined ActionOutput mapping = actionoutput_str_to_mapping(ic["mapping"]) actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) elif "module" in ic: # subclasses of BaseModel ic_obj = import_class(ic["class"], ic["module"]) else: raise KeyError("missing required key to 
init Message.instruct_content from dict") ic = ic_obj(**ic["value"]) return ic @field_validator("cause_by", mode="before") @classmethod def check_cause_by(cls, cause_by: Any) -> str: return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement")) @field_validator("sent_from", mode="before") @classmethod def check_sent_from(cls, sent_from: Any) -> str: return any_to_str(sent_from if sent_from else "") @field_validator("send_to", mode="before") @classmethod def check_send_to(cls, send_to: Any) -> set: return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL}) @field_serializer("instruct_content", mode="plain") def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]: ic_dict = None if ic: # compatible with custom-defined ActionOutput schema = ic.model_json_schema() ic_type = str(type(ic)) if " dict: """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} def dump(self) -> str: """Convert the object to json string""" return self.model_dump_json(exclude_none=True, warnings=False) @staticmethod @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" try: m = json.loads(val) id = m.get("id") if "id" in m: del m["id"] msg = Message(**m) if id: msg.id = id return msg except JSONDecodeError as err: logger.error(f"parse json failed: {val}, error:{err}") return None class UserMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="user") class SystemMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="system") class AIMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="assistant") 
class MessageQueue(BaseModel):
    """Message queue which supports asynchronous updates."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    _queue: Queue = PrivateAttr(default_factory=Queue)

    def pop(self) -> Message | None:
        """Pop one message from the queue without waiting.

        :return: the next item, or None if the queue is empty.
        """
        try:
            item = self._queue.get_nowait()
        except QueueEmpty:
            return None
        # Every successful get is paired with task_done so `Queue.join()` accounting stays correct,
        # including for falsy items.
        self._queue.task_done()
        return item

    def pop_all(self) -> List[Message]:
        """Pop all messages from the queue.

        Stops at the first falsy item returned by `pop` (an empty queue yields None).
        """
        ret = []
        while msg := self.pop():
            ret.append(msg)
        return ret

    def push(self, msg: Message):
        """Push a message into the queue."""
        self._queue.put_nowait(msg)

    def empty(self):
        """Return true if the queue is empty."""
        return self._queue.empty()

    async def dump(self) -> str:
        """Convert the `MessageQueue` object to a JSON string without consuming it.

        The queue is drained with non-blocking gets (no await in between), so this returns immediately
        instead of waiting for a timeout once the queue runs dry. Every drained item is put back in its
        original order, even if serialization fails.

        :return: a JSON array whose elements are the `Message.dump()` strings of the non-None items.
        """
        if self.empty():
            return "[]"

        msgs = []
        while True:
            try:
                item = self._queue.get_nowait()
            except QueueEmpty:
                logger.debug("Queue is empty, exiting...")
                break
            self._queue.task_done()
            msgs.append(item)

        try:
            lst = [m.dump() for m in msgs if m is not None]
        finally:
            # Restore the queue contents exactly as they were (unbounded queue, so put_nowait cannot fail).
            for m in msgs:
                self._queue.put_nowait(m)

        return json.dumps(lst, ensure_ascii=False)

    @staticmethod
    def load(data) -> "MessageQueue":
        """Convert the JSON string produced by `dump` back into a `MessageQueue` object.

        Entries that `Message.load` cannot parse (it returns None) are skipped rather than queued.
        """
        queue = MessageQueue()
        try:
            lst = json.loads(data)
            for i in lst:
                msg = Message.load(i)
                if msg is not None:
                    queue.push(msg)
        except JSONDecodeError as e:
            logger.warning(f"JSON load failed: {data}, error:{e}")

        return queue


# Define a generic type variable bound to pydantic models (used by `BaseContext.loads`).
T = TypeVar("T", bound="BaseModel")


class BaseContext(BaseModel, ABC):
    @classmethod
    @handle_exception
    def loads(cls: Type[T], val: str) -> Optional[T]:
        """Build an instance of `cls` from a JSON object string.

        Exceptions are routed through `handle_exception` (defined elsewhere; presumably logs and returns a
        default such as None — confirm).
        """
        i = json.loads(val)
        return cls(**i)


class CodingContext(BaseContext):
    filename: str
    design_doc: Optional[Document] = None
    task_doc: Optional[Document] = None
    code_doc: Optional[Document] = None


class TestingContext(BaseContext):
    filename: str
    code_doc: Document
    test_doc: Optional[Document] = None


class RunCodeContext(BaseContext):
    mode: str = "script"
    code: Optional[str] = None
    code_filename: str = ""
    test_code: Optional[str] = None
    test_filename: str = ""
    command: List[str] = Field(default_factory=list)
    working_directory: str = ""
    additional_python_paths: List[str] = Field(default_factory=list)
    output_filename: Optional[str] = None
    output: Optional[str] = None


class RunCodeResult(BaseContext):
    summary: str
    stdout: str
    stderr: str


class CodeSummarizeContext(BaseModel):
    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context by sorting filenames into the design and task slots.

        A filename under SYSTEM_DESIGN_FILE_REPO sets `design_filename`; one under TASK_FILE_REPO sets
        `task_filename`; all others are ignored. A later match overwrites an earlier one.
        """
        ctx = CodeSummarizeContext()
        for filename in filenames:
            path = Path(filename)
            if path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(filename)
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(filename)
        return ctx

    def __hash__(self):
        # Hash only the design/task pair; equal contexts (all fields equal) still hash equally.
        return hash((self.design_filename, self.task_filename))


class BugFixContext(BaseContext):
    filename: str = ""


# mermaid class view
class ClassMeta(BaseModel):
    name: str = ""
    abstraction: bool = False
    static: bool = False
    visibility: str = ""


class ClassAttribute(ClassMeta):
    value_type: str = ""
    default_value: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this attribute as one mermaid class-diagram member line, indented by `align` tabs.

        String-typed defaults are re-quoted (inner double quotes removed); `*` marks abstract, `$` marks static.
        """
        content = "\t" * align + self.visibility
        if self.value_type:
            content += self.value_type + " "
        content += self.name
        if self.default_value:
            content += "="
            if self.value_type not in ["str", "string", "String"]:
                content += self.default_value
            else:
                content += '"' + self.default_value.replace('"', "") + '"'
        if self.abstraction:
            content += "*"
        if self.static:
            content += "$"
        return content


class ClassMethod(ClassMeta):
    args: List[ClassAttribute] = Field(default_factory=list)
    return_type: str = ""

    def get_mermaid(self, align=1) -> str:
        """Render this method as one mermaid member line: `name(arg,...)[:return_type]` plus `*` / `$` markers."""
        content = "\t" * align + self.visibility
        content += self.name + "(" + ",".join(v.get_mermaid(align=0) for v in self.args) + ")"
        if self.return_type:
            content += ":" + self.return_type
        if self.abstraction:
            content += "*"
        if self.static:
            content += "$"
        return content


class ClassView(ClassMeta):
    attributes: List[ClassAttribute] = Field(default_factory=list)
    methods: List[ClassMethod] = Field(default_factory=list)

    def get_mermaid(self, align=1) -> str:
        """Render the class as a mermaid `class Name{...}` block; members are indented one tab deeper."""
        indent = "\t" * align
        content = indent + "class " + self.name + "{\n"
        for v in self.attributes:
            content += v.get_mermaid(align=align + 1) + "\n"
        for v in self.methods:
            content += v.get_mermaid(align=align + 1) + "\n"
        content += indent + "}\n"
        return content