#!/usr/bin/env python # -*- coding: utf-8 -*- """ @Time : 2023/5/8 22:12 @Author : alexanderwu @File : schema.py @Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116: Replanned the distribution of responsibilities and functional positioning of `Message` class attributes. @Modified By: mashenquan, 2023/11/22. 1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135. 2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing between actions. 3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135. """ from __future__ import annotations import asyncio import json import os.path import uuid from asyncio import Queue, QueueEmpty, wait_for from json import JSONDecodeError from pathlib import Path from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar from pydantic import BaseModel, Field from metagpt.config import CONFIG from metagpt.const import ( MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM, MESSAGE_ROUTE_TO, MESSAGE_ROUTE_TO_ALL, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO, ) from metagpt.logs import logger from metagpt.utils.common import any_to_str, any_to_str_set from metagpt.utils.exceptions import handle_exception class RawMessage(TypedDict): content: str role: str class Document(BaseModel): """ Represents a document. """ root_path: str = "" filename: str = "" content: str = "" def get_meta(self) -> Document: """Get metadata of the document. :return: A new Document instance with the same root path and filename. """ return Document(root_path=self.root_path, filename=self.filename) @property def root_relative_path(self): """Get relative path from root of git repository. :return: relative path from root of git repository. """ return os.path.join(self.root_path, self.filename) @property def full_path(self): if not CONFIG.git_repo: return None return str(CONFIG.git_repo.workdir / self.root_path / self.filename) def __str__(self): return self.content def __repr__(self): return self.content class Documents(BaseModel): """A class representing a collection of documents. Attributes: docs (Dict[str, Document]): A dictionary mapping document names to Document instances. """ docs: Dict[str, Document] = Field(default_factory=dict) class Message(BaseModel): """list[: ]""" id: str # According to Section 2.2.3.1.1 of RFC 135 content: str instruct_content: BaseModel = Field(default=None) role: str = "user" # system / user / assistant cause_by: str = "" sent_from: str = "" send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) def __init__( self, content, instruct_content=None, role="user", cause_by="", sent_from="", send_to=MESSAGE_ROUTE_TO_ALL, **kwargs, ): """ Parameters not listed below will be stored as meta info, including custom parameters. :param content: Message content. :param instruct_content: Message content struct. :param cause_by: Message producer :param sent_from: Message route info tells who sent this message. :param send_to: Specifies the target recipient or consumer for message delivery in the environment. :param role: Message meta info tells who sent this message. """ super().__init__( id=uuid.uuid4().hex, content=content, instruct_content=instruct_content, role=role, cause_by=any_to_str(cause_by), sent_from=any_to_str(sent_from), send_to=any_to_str_set(send_to), **kwargs, ) def __setattr__(self, key, val): """Override `@property.setter`, convert non-string parameters into string parameters.""" if key == MESSAGE_ROUTE_CAUSE_BY: new_val = any_to_str(val) elif key == MESSAGE_ROUTE_FROM: new_val = any_to_str(val) elif key == MESSAGE_ROUTE_TO: new_val = any_to_str_set(val) else: new_val = val super().__setattr__(key, new_val) def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) return f"{self.role}: {self.content}" def __repr__(self): return self.__str__() def to_dict(self) -> dict: """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} def dump(self) -> str: """Convert the object to json string""" return self.json(exclude_none=True) @staticmethod @handle_exception(exception_type=JSONDecodeError, default_return=None) def load(val): """Convert the json string to object.""" d = json.loads(val) return Message(**d) class UserMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="user") class SystemMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="system") class AIMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): super().__init__(content=content, role="assistant") class MessageQueue: """Message queue which supports asynchronous updates.""" def __init__(self): self._queue = Queue() def pop(self) -> Message | None: """Pop one message from the queue.""" try: item = self._queue.get_nowait() if item: self._queue.task_done() return item except QueueEmpty: return None def pop_all(self) -> List[Message]: """Pop all messages from the queue.""" ret = [] while True: msg = self.pop() if not msg: break ret.append(msg) return ret def push(self, msg: Message): """Push a message into the queue.""" self._queue.put_nowait(msg) def empty(self): """Return true if the queue is empty.""" return self._queue.empty() async def dump(self) -> str: """Convert the `MessageQueue` object to a json string.""" if self.empty(): return "[]" lst = [] try: while True: item = await wait_for(self._queue.get(), timeout=1.0) if item is None: break lst.append(item.dict(exclude_none=True)) self._queue.task_done() except asyncio.TimeoutError: logger.debug("Queue is empty, exiting...") return json.dumps(lst) @staticmethod def load(i) -> "MessageQueue": """Convert the json string to the `MessageQueue` object.""" queue = MessageQueue() try: lst = json.loads(i) for i in lst: msg = Message(**i) queue.push(msg) except JSONDecodeError as e: logger.warning(f"JSON load failed: {i}, error:{e}") return queue # 定义一个泛型类型变量 T = TypeVar("T", bound="BaseModel") class BaseContext(BaseModel): @classmethod @handle_exception def loads(cls: Type[T], val: str) -> Optional[T]: i = json.loads(val) return cls(**i) class CodingContext(BaseContext): filename: str design_doc: Optional[Document] task_doc: Optional[Document] code_doc: Optional[Document] class TestingContext(BaseContext): filename: str code_doc: Document test_doc: Optional[Document] class RunCodeContext(BaseContext): mode: str = "script" code: Optional[str] code_filename: str = "" test_code: Optional[str] test_filename: str = "" command: List[str] = Field(default_factory=list) working_directory: str = "" additional_python_paths: List[str] = Field(default_factory=list) output_filename: Optional[str] output: Optional[str] class RunCodeResult(BaseContext): summary: str stdout: str stderr: str class CodeSummarizeContext(BaseModel): design_filename: str = "" task_filename: str = "" codes_filenames: List[str] = Field(default_factory=list) reason: str = "" @staticmethod def loads(filenames: List) -> CodeSummarizeContext: ctx = CodeSummarizeContext() for filename in filenames: if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO): ctx.design_filename = str(filename) continue if Path(filename).is_relative_to(TASK_FILE_REPO): ctx.task_filename = str(filename) continue return ctx def __hash__(self): return hash((self.design_filename, self.task_filename)) class BugFixContext(BaseContext): filename: str = ""