#!/usr/bin/env python # -*- coding: utf-8 -*- """ @Time : 2023/5/8 22:12 @Author : alexanderwu @File : schema.py @Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116: Replanned the distribution of responsibilities and functional positioning of `Message` class attributes. @Modified By: mashenquan, 2023/11/22. 1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135. 2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing between actions. 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. """ from __future__ import annotations import asyncio import json import os.path import uuid from abc import ABC from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from pydantic import ( BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator, ) from pydantic_core import core_schema from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set, import_class from metagpt.utils.exceptions import handle_exception from metagpt.utils.serialize import ( actionoutout_schema_to_mapping, actionoutput_mapping_to_str, actionoutput_str_to_mapping, ) class SerializationMixin(BaseModel): """ PolyMorphic subclasses Serialization / Deserialization Mixin - First of all, we need to know that pydantic is not designed for polymorphism. - If Engineer is subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we need to add `class name` to Engineer. So we need Engineer inherit SerializationMixin. More details: - https://docs.pydantic.dev/latest/concepts/serialization/ - https://github.com/pydantic/pydantic/discussions/7008 discuss about avoid `__get_pydantic_core_schema__` """ __is_polymorphic_base = False __subclasses_map__ = {} @classmethod def __get_pydantic_core_schema__( cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema] ) -> core_schema.CoreSchema: schema = handler(source) og_schema_ref = schema["ref"] schema["ref"] += ":mixin" return core_schema.no_info_before_validator_function( cls.__deserialize_with_real_type__, schema=schema, ref=og_schema_ref, serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__), ) @classmethod def __serialize_add_class_type__( cls, value, handler: core_schema.SerializerFunctionWrapHandler, ) -> Any: ret = handler(value) if not len(cls.__subclasses__()): # only subclass add `__module_class_name` ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}" return ret @classmethod def __deserialize_with_real_type__(cls, value: Any): if not isinstance(value, dict): return value if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value): # add right condition to init BaseClass like Action() return value module_class_name = value.get("__module_class_name", None) if module_class_name is None: raise ValueError("Missing field: __module_class_name") class_type = cls.__subclasses_map__.get(module_class_name, None) if class_type is None: raise TypeError("Trying to instantiate {module_class_name} which not defined yet.") return class_type(**value) def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs): cls.__is_polymorphic_base = is_polymorphic_base cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls super().__init_subclass__(**kwargs) class SimpleMessage(BaseModel): content: str role: str class Document(BaseModel): """ Represents a document. """ root_path: str = "" filename: str = "" content: str = "" def get_meta(self) -> Document: """Get metadata of the document. :return: A new Document instance with the same root path and filename. """ return Document(root_path=self.root_path, filename=self.filename) @property def root_relative_path(self): """Get relative path from root of git repository. :return: relative path from root of git repository. """ return os.path.join(self.root_path, self.filename) @property def full_path(self): if not CONFIG.git_repo: return None return str(CONFIG.git_repo.workdir / self.root_path / self.filename) def __str__(self): return self.content def __repr__(self): return self.content class Documents(BaseModel): """A class representing a collection of documents. Attributes: docs (Dict[str, Document]): A dictionary mapping document names to Document instances. """ docs: Dict[str, Document] = Field(default_factory=dict) class Message(BaseModel): """list[: ]""" id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135 content: str instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True) role: str = "user" # system / user / assistant cause_by: str = Field(default="", validate_default=True) sent_from: str = Field(default="", validate_default=True) send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) @field_validator("id", mode="before") @classmethod def check_id(cls, id: str) -> str: return id if id else uuid.uuid4().hex @field_validator("instruct_content", mode="before") @classmethod def check_instruct_content(cls, ic: Any) -> BaseModel: if ic and not isinstance(ic, BaseModel) and "class" in ic: # compatible with custom-defined ActionOutput mapping = actionoutput_str_to_mapping(ic["mapping"]) actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping) ic = ic_obj(**ic["value"]) return ic @field_validator("cause_by", mode="before") @classmethod def check_cause_by(cls, cause_by: Any) -> str: return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement")) @field_validator("sent_from", mode="before") @classmethod def check_sent_from(cls, sent_from: Any) -> str: return any_to_str(sent_from if sent_from else "") @field_validator("send_to", mode="before") @classmethod def check_send_to(cls, send_to: Any) -> set: return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL}) @field_serializer("instruct_content", mode="plain") def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]: ic_dict = None if ic: # compatible with custom-defined ActionOutput schema = ic.model_json_schema() # `Documents` contain definitions if "definitions" not in schema: # TODO refine with nested BaseModel mapping = actionoutout_schema_to_mapping(schema) mapping = actionoutput_mapping_to_str(mapping) ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()} return ic_dict def __init__(self, content: str = "", **data: Any): data["content"] = data.get("content", content) super().__init__(**data) def __setattr__(self, key, val): """Override `@property.setter`, convert non-string parameters into string parameters.""" if key == MESSAGE_ROUTE_CAUSE_BY: new_val = any_to_str(val) elif key == MESSAGE_ROUTE_FROM: new_val = any_to_str(val) elif key == MESSAGE_ROUTE_TO: new_val = any_to_str_set(val) else: new_val = val super().__setattr__(key, new_val) def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) if self.instruct_content: return f"{self.role}: {self.instruct_content.model_dump()}" return f"{self.role}: {self.content}" def __repr__(self): return self.__str__() def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} def dump(self) -> str: """Convert the object to json string""" return self.model_dump_json(exclude_none=True, warnings=False) @staticmethod @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" try: m = json.loads(val) id = m.get("id") if "id" in m: del m["id"] msg = Message(**m) if id: msg.id = id return msg except JSONDecodeError as err: logger.error(f"parse json failed: {val}, error:{err}") return None class UserMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="user") class SystemMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="system") class AIMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="assistant") class MessageQueue(BaseModel): """Message queue which supports asynchronous updates.""" model_config = ConfigDict(arbitrary_types_allowed=True) _queue: Queue = PrivateAttr(default_factory=Queue) def pop(self) -> Message | None: """Pop one message from the queue.""" try: item = self._queue.get_nowait() if item: self._queue.task_done() return item except QueueEmpty: return None def pop_all(self) -> List[Message]: """Pop all messages from the queue.""" ret = [] while True: msg = self.pop() if not msg: break ret.append(msg) return ret def push(self, msg: Message): """Push a message into the queue.""" self._queue.put_nowait(msg) def empty(self): """Return true if the queue is empty.""" return self._queue.empty() async def dump(self) -> str: """Convert the `MessageQueue` object to a json string.""" if self.empty(): return "[]" lst = [] msgs = [] try: while True: item = await wait_for(self._queue.get(), timeout=1.0) if item is None: break msgs.append(item) lst.append(item.dump()) self._queue.task_done() except asyncio.TimeoutError: logger.debug("Queue is empty, exiting...") finally: for m in msgs: self._queue.put_nowait(m) return json.dumps(lst, ensure_ascii=False) @staticmethod def load(data) -> "MessageQueue": """Convert the json string to the `MessageQueue` object.""" queue = MessageQueue() try: lst = json.loads(data) for i in lst: msg = Message.load(i) queue.push(msg) except JSONDecodeError as e: logger.warning(f"JSON load failed: {data}, error:{e}") return queue # 定义一个泛型类型变量 T = TypeVar("T", bound="BaseModel") class BaseContext(BaseModel, ABC): @classmethod @handle_exception def loads(cls: Type[T], val: str) -> Optional[T]: i = json.loads(val) return cls(**i) class CodingContext(BaseContext): filename: str design_doc: Optional[Document] = None task_doc: Optional[Document] = None code_doc: Optional[Document] = None class TestingContext(BaseContext): filename: str code_doc: Document test_doc: Optional[Document] = None class RunCodeContext(BaseContext): mode: str = "script" code: Optional[str] = None code_filename: str = "" test_code: Optional[str] = None test_filename: str = "" command: List[str] = Field(default_factory=list) working_directory: str = "" additional_python_paths: List[str] = Field(default_factory=list) output_filename: Optional[str] = None output: Optional[str] = None class RunCodeResult(BaseContext): summary: str stdout: str stderr: str class CodeSummarizeContext(BaseModel): design_filename: str = "" task_filename: str = "" codes_filenames: List[str] = Field(default_factory=list) reason: str = "" @staticmethod def loads(filenames: List) -> CodeSummarizeContext: ctx = CodeSummarizeContext() for filename in filenames: if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO): ctx.design_filename = str(filename) continue if Path(filename).is_relative_to(TASK_FILE_REPO): ctx.task_filename = str(filename) continue return ctx def __hash__(self): return hash((self.design_filename, self.task_filename)) class BugFixContext(BaseContext): filename: str = "" # mermaid class view class ClassMeta(BaseModel): name: str = "" abstraction: bool = False static: bool = False visibility: str = "" class ClassAttribute(ClassMeta): value_type: str = "" default_value: str = "" def get_mermaid(self, align=1) -> str: content = "".join(["\t" for i in range(align)]) + self.visibility if self.value_type: content += self.value_type + " " content += self.name if self.default_value: content += "=" if self.value_type not in ["str", "string", "String"]: content += self.default_value else: content += '"' + self.default_value.replace('"', "") + '"' if self.abstraction: content += "*" if self.static: content += "$" return content class ClassMethod(ClassMeta): args: List[ClassAttribute] = Field(default_factory=list) return_type: str = "" def get_mermaid(self, align=1) -> str: content = "".join(["\t" for i in range(align)]) + self.visibility content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")" if self.return_type: content += ":" + self.return_type if self.abstraction: content += "*" if self.static: content += "$" return content class ClassView(ClassMeta): attributes: List[ClassAttribute] = Field(default_factory=list) methods: List[ClassMethod] = Field(default_factory=list) def get_mermaid(self, align=1) -> str: content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n" for v in self.attributes: content += v.get_mermaid(align=align + 1) + "\n" for v in self.methods: content += v.get_mermaid(align=align + 1) + "\n" content += "".join(["\t" for i in range(align)]) + "}\n" return content