MetaGPT/metagpt/schema.py

377 lines
11 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/8 22:12
@Author : alexanderwu
@File : schema.py
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
@Modified By: mashenquan, 2023/11/22.
1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing
between actions.
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
"""
import asyncio
import json
import os.path
import uuid
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Dict, List, Optional, Set, TypedDict
from pydantic import BaseModel, Field
from dataclasses import dataclass, field
from typing import Type, TypedDict, Union, Optional
from pydantic import BaseModel, Field
from pydantic.main import ModelMetaclass
from metagpt.config import CONFIG
from metagpt.const import (
MESSAGE_ROUTE_CAUSE_BY,
MESSAGE_ROUTE_FROM,
MESSAGE_ROUTE_TO,
MESSAGE_ROUTE_TO_ALL,
SYSTEM_DESIGN_FILE_REPO,
TASK_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.utils.common import any_to_str, any_to_str_set
# from metagpt.utils.serialize import actionoutout_schema_to_mapping
# from metagpt.actions.action_output import ActionOutput
# from metagpt.actions.action import Action
from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \
actionoutput_str_to_mapping
from metagpt.utils.utils import import_class
class RawMessage(TypedDict):
content: str
role: str
class Document(BaseModel):
"""
Represents a document.
"""
root_path: str = ""
filename: str = ""
content: str = ""
def get_meta(self) -> "Document"":
"""Get metadata of the document.
:return: A new Document instance with the same root path and filename.
"""
return Document(root_path=self.root_path, filename=self.filename)
@property
def root_relative_path(self):
"""Get relative path from root of git repository.
:return: relative path from root of git repository.
"""
return os.path.join(self.root_path, self.filename)
@property
def full_path(self):
if not CONFIG.git_repo:
return None
return str(CONFIG.git_repo.workdir / self.root_path / self.filename)
def __str__(self):
return self.content
def __repr__(self):
return self.content
class Documents(BaseModel):
"""A class representing a collection of documents.
Attributes:
docs (Dict[str, Document]): A dictionary mapping document names to Document instances.
"""
docs: Dict[str, Document] = Field(default_factory=dict)
class Message(BaseModel):
"""list[<role>: <content>]"""
id: str # According to Section 2.2.3.1.1 of RFC 135
content: str
instruct_content: BaseModel = Field(default=None)
role: str = "user" # system / user / assistant
cause_by: str = ""
sent_from: str = ""
send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL})
def __init__(self, **kwargs):
instruct_content = kwargs.get("instruct_content", None)
cause_by = kwargs.get("cause_by", None)
if instruct_content and not isinstance(instruct_content, BaseModel):
ic = instruct_content
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output")
ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=mapping)
ic_new = ic_obj(**ic["value"])
kwargs["instruct_content"] = ic_new
if cause_by and not isinstance(cause_by, ModelMetaclass):
action_class = import_class("Action", "metagpt.actions.action")
kwargs["cause_by"] = action_class.deser_class(cause_by)
super(Message, self).__init__(**kwargs)
def __setattr__(self, key, val):
"""Override `@property.setter`, convert non-string parameters into string parameters."""
if key == MESSAGE_ROUTE_CAUSE_BY:
new_val = any_to_str(val)
elif key == MESSAGE_ROUTE_FROM:
new_val = any_to_str(val)
elif key == MESSAGE_ROUTE_TO:
new_val = any_to_str_set(val)
else:
new_val = val
super().__setattr__(key, new_val)
def dict(self, *args, **kwargs) -> "DictStrAny":
""" overwrite the `dict` to dump dynamic pydantic model"""
obj_dict = super(Message, self).dict(*args, **kwargs)
ic = self.instruct_content # deal custom-defined action
if ic:
schema = ic.schema()
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
cb = self.cause_by
if cb:
obj_dict["cause_by"] = cb.ser_class()
return obj_dict
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
return f"{self.role}: {self.content}"
def __repr__(self):
return self.__str__()
# def dict(self):
# return {
# "content": self.content,
# "instruct_content": self.instruct_content,
# "role": self.role,
# "cause_by": self.cause_by,
# "sent_from": self.sent_from,
# "send_to": self.send_to,
# "restricted_to": self.restricted_to
# }
def to_dict(self) -> dict:
"""Return a dict containing `role` and `content` for the LLM call.l"""
return {"role": self.role, "content": self.content}
def dump(self) -> str:
"""Convert the object to json string"""
return self.json(exclude_none=True)
@staticmethod
def load(val):
"""Convert the json string to object."""
try:
d = json.loads(val)
return Message(**d)
except JSONDecodeError as err:
logger.error(f"parse json failed: {val}, error:{err}")
return None
class UserMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content=content, role="user")
class SystemMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content=content, role="system")
class AIMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content=content, role="assistant")
class MessageQueue:
"""Message queue which supports asynchronous updates."""
def __init__(self):
self._queue = Queue()
def pop(self) -> Message | None:
"""Pop one message from the queue."""
try:
item = self._queue.get_nowait()
if item:
self._queue.task_done()
return item
except QueueEmpty:
return None
def pop_all(self) -> List[Message]:
"""Pop all messages from the queue."""
ret = []
while True:
msg = self.pop()
if not msg:
break
ret.append(msg)
return ret
def push(self, msg: Message):
"""Push a message into the queue."""
self._queue.put_nowait(msg)
def empty(self):
"""Return true if the queue is empty."""
return self._queue.empty()
async def dump(self) -> str:
"""Convert the `MessageQueue` object to a json string."""
if self.empty():
return "[]"
lst = []
try:
while True:
item = await wait_for(self._queue.get(), timeout=1.0)
if item is None:
break
lst.append(item.dict(exclude_none=True))
self._queue.task_done()
except asyncio.TimeoutError:
logger.debug("Queue is empty, exiting...")
return json.dumps(lst)
@staticmethod
def load(self, v) -> "MessageQueue":
"""Convert the json string to the `MessageQueue` object."""
q = MessageQueue()
try:
lst = json.loads(v)
for i in lst:
msg = Message(**i)
q.push(msg)
except JSONDecodeError as e:
logger.warning(f"JSON load failed: {v}, error:{e}")
return q
class CodingContext(BaseModel):
filename: str
design_doc: Optional[Document]
task_doc: Optional[Document]
code_doc: Optional[Document]
@staticmethod
def loads(val: str) -> "CodingContext" | None:
try:
m = json.loads(val)
return CodingContext(**m)
except Exception:
return None
class TestingContext(BaseModel):
filename: str
code_doc: Document
test_doc: Optional[Document]
@staticmethod
def loads(val: str) -> "TestingContext" | None:
try:
m = json.loads(val)
return TestingContext(**m)
except Exception:
return None
class RunCodeContext(BaseModel):
mode: str = "script"
code: Optional[str]
code_filename: str = ""
test_code: Optional[str]
test_filename: str = ""
command: List[str] = Field(default_factory=list)
working_directory: str = ""
additional_python_paths: List[str] = Field(default_factory=list)
output_filename: Optional[str]
output: Optional[str]
@staticmethod
def loads(val: str) -> "RunCodeContext" | None:
try:
m = json.loads(val)
return RunCodeContext(**m)
except Exception:
return None
class RunCodeResult(BaseModel):
summary: str
stdout: str
stderr: str
@staticmethod
def loads(val: str) -> "RunCodeResult" | None:
try:
m = json.loads(val)
return RunCodeResult(**m)
except Exception:
return None
class CodeSummarizeContext(BaseModel):
design_filename: str = ""
task_filename: str = ""
codes_filenames: List[str] = Field(default_factory=list)
reason: str = ""
@staticmethod
def loads(filenames: List) -> "CodeSummarizeContext":
ctx = CodeSummarizeContext()
for filename in filenames:
if Path(filename).is_relative_to(SYSTEM_DESIGN_FILE_REPO):
ctx.design_filename = str(filename)
continue
if Path(filename).is_relative_to(TASK_FILE_REPO):
ctx.task_filename = str(filename)
continue
return ctx
def __hash__(self):
return hash((self.design_filename, self.task_filename))
class BugFixContext(BaseModel):
filename: str = ""