mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 16:56:26 +02:00
343 lines
9.5 KiB
Python
343 lines
9.5 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
@Time : 2023/5/8 22:12
|
|
@Author : alexanderwu
|
|
@File : schema.py
|
|
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
|
|
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
|
|
@Modified By: mashenquan, 2023/11/22.
|
|
1. Add `Document` and `Documents` for `FileRepository` in Section 2.2.3.4 of RFC 135.
|
|
2. Encapsulate the common key-values set to pydantic structures to standardize and unify parameter passing
|
|
between actions.
|
|
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import os.path
|
|
import uuid
|
|
from asyncio import Queue, QueueEmpty, wait_for
|
|
from json import JSONDecodeError
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Set, TypedDict
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from metagpt.config import CONFIG
|
|
from metagpt.const import (
|
|
MESSAGE_ROUTE_CAUSE_BY,
|
|
MESSAGE_ROUTE_FROM,
|
|
MESSAGE_ROUTE_TO,
|
|
MESSAGE_ROUTE_TO_ALL,
|
|
SYSTEM_DESIGN_FILE_REPO,
|
|
TASK_FILE_REPO,
|
|
)
|
|
from metagpt.logs import logger
|
|
from metagpt.utils.common import any_to_str, any_to_str_set
|
|
|
|
|
|
class RawMessage(TypedDict):
    """Typed dictionary holding a message's `content` and `role`, both as strings."""

    content: str
    role: str
|
|
|
|
|
|
class Document(BaseModel):
    """A document described by a root path, a filename and its text content."""

    root_path: str = ""
    filename: str = ""
    content: str = ""

    def get_meta(self) -> Document:
        """Build a content-free copy of this document.

        :return: A new Document carrying only this document's root path and filename.
        """
        return Document(filename=self.filename, root_path=self.root_path)

    @property
    def root_relative_path(self):
        """Join `root_path` and `filename` into a single path.

        :return: The joined path, intended to be relative to the root of the git repository.
        """
        return os.path.join(self.root_path, self.filename)

    @property
    def full_path(self):
        """Locate the document under the working directory of the configured git repository.

        :return: The path as a string, or None when `CONFIG.git_repo` is not set.
        """
        if CONFIG.git_repo:
            return str(CONFIG.git_repo.workdir / self.root_path / self.filename)
        return None
|
|
|
|
|
|
class Documents(BaseModel):
    """A collection of documents keyed by name.

    Attributes:
        docs (Dict[str, Document]): Mapping from document name to its Document instance;
            every instance starts with its own empty dict.
    """

    docs: Dict[str, Document] = Field(default_factory=dict)
|
|
|
|
|
|
class Message(BaseModel):
    """A routed message; its string form is ``<role>: <content>``."""

    id: str  # Unique message id, according to Section 2.2.3.1.1 of RFC 135
    content: str
    # Optional structured form of the content; None when the message is plain text.
    instruct_content: Optional[BaseModel] = Field(default=None)
    role: str = "user"  # system / user / assistant
    cause_by: str = ""
    sent_from: str = ""
    # `default_factory` must be a callable; a bare set literal here would fail when invoked.
    send_to: Set = Field(default_factory=lambda: {MESSAGE_ROUTE_TO_ALL})

    def __init__(
        self,
        content,
        instruct_content=None,
        role="user",
        cause_by="",
        sent_from="",
        send_to=MESSAGE_ROUTE_TO_ALL,
        **kwargs,
    ):
        """
        Extra keyword arguments are forwarded to the pydantic `BaseModel` initializer.
        :param content: Message content.
        :param instruct_content: Message content struct.
        :param role: Message meta info tells who sent this message.
        :param cause_by: Message producer; converted to a string with `any_to_str`.
        :param sent_from: Message route info tells who sent this message; converted with `any_to_str`.
        :param send_to: Specifies the target recipient or consumer for message delivery in the environment;
            converted to a set of strings with `any_to_str_set`.
        :param id: Optional keyword. An existing message id to keep (for example when restoring a dumped
            message); a fresh uuid4 hex string is generated when it is missing or empty.
        """
        # Take `id` out of kwargs before forwarding them: passing it both explicitly and through
        # `**kwargs` raises "got multiple values for keyword argument 'id'", which broke
        # `Message(**json.loads(msg.dump()))`.
        msg_id = kwargs.pop("id", None) or uuid.uuid4().hex
        super().__init__(
            id=msg_id,
            content=content,
            instruct_content=instruct_content,
            role=role,
            cause_by=any_to_str(cause_by),
            sent_from=any_to_str(sent_from),
            send_to=any_to_str_set(send_to),
            **kwargs,
        )

    def __setattr__(self, key, val):
        """Normalize routing attributes on assignment.

        `cause_by` and `sent_from` values are converted with `any_to_str`, `send_to` values with
        `any_to_str_set`; every other attribute is stored unchanged.
        """
        if key in (MESSAGE_ROUTE_CAUSE_BY, MESSAGE_ROUTE_FROM):
            val = any_to_str(val)
        elif key == MESSAGE_ROUTE_TO:
            val = any_to_str_set(val)
        super().__setattr__(key, val)

    def __str__(self):
        """Render the message as ``<role>: <content>``."""
        return f"{self.role}: {self.content}"

    def __repr__(self):
        """Use the same rendering as `__str__`."""
        return self.__str__()

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

    def dump(self) -> str:
        """Convert the object to a json string, leaving out fields whose value is None."""
        return self.json(exclude_none=True)

    @staticmethod
    def load(val):
        """Convert the json string to object.

        :param val: JSON text as produced by `dump()`.
        :return: The rebuilt `Message`, or None (after logging an error) when `val` is not valid JSON.
        """
        try:
            d = json.loads(val)
            return Message(**d)
        except JSONDecodeError as err:
            logger.error(f"parse json failed: {val}, error:{err}")
            return None
|
|
|
|
|
|
class UserMessage(Message):
    """Message whose role is fixed to "user".

    Facilitate support for OpenAI messages
    """

    def __init__(self, content: str):
        """Create the message from its text content."""
        super().__init__(role="user", content=content)
|
|
|
|
|
|
class SystemMessage(Message):
    """Message whose role is fixed to "system".

    Facilitate support for OpenAI messages
    """

    def __init__(self, content: str):
        """Create the message from its text content."""
        super().__init__(role="system", content=content)
|
|
|
|
|
|
class AIMessage(Message):
    """Message whose role is fixed to "assistant".

    Facilitate support for OpenAI messages
    """

    def __init__(self, content: str):
        """Create the message from its text content."""
        super().__init__(role="assistant", content=content)
|
|
|
|
|
|
class MessageQueue:
|
|
"""Message queue which supports asynchronous updates."""
|
|
|
|
def __init__(self):
|
|
self._queue = Queue()
|
|
|
|
def pop(self) -> Message | None:
|
|
"""Pop one message from the queue."""
|
|
try:
|
|
item = self._queue.get_nowait()
|
|
if item:
|
|
self._queue.task_done()
|
|
return item
|
|
except QueueEmpty:
|
|
return None
|
|
|
|
def pop_all(self) -> List[Message]:
|
|
"""Pop all messages from the queue."""
|
|
ret = []
|
|
while True:
|
|
msg = self.pop()
|
|
if not msg:
|
|
break
|
|
ret.append(msg)
|
|
return ret
|
|
|
|
def push(self, msg: Message):
|
|
"""Push a message into the queue."""
|
|
self._queue.put_nowait(msg)
|
|
|
|
def empty(self):
|
|
"""Return true if the queue is empty."""
|
|
return self._queue.empty()
|
|
|
|
async def dump(self) -> str:
|
|
"""Convert the `MessageQueue` object to a json string."""
|
|
if self.empty():
|
|
return "[]"
|
|
|
|
lst = []
|
|
try:
|
|
while True:
|
|
item = await wait_for(self._queue.get(), timeout=1.0)
|
|
if item is None:
|
|
break
|
|
lst.append(item.dict(exclude_none=True))
|
|
self._queue.task_done()
|
|
except asyncio.TimeoutError:
|
|
logger.debug("Queue is empty, exiting...")
|
|
return json.dumps(lst)
|
|
|
|
@staticmethod
|
|
def load(self, v) -> "MessageQueue":
|
|
"""Convert the json string to the `MessageQueue` object."""
|
|
q = MessageQueue()
|
|
try:
|
|
lst = json.loads(v)
|
|
for i in lst:
|
|
msg = Message(**i)
|
|
q.push(msg)
|
|
except JSONDecodeError as e:
|
|
logger.warning(f"JSON load failed: {v}, error:{e}")
|
|
|
|
return q
|
|
|
|
|
|
class CodingContext(BaseModel):
    """Groups a filename with its design document and optional task and code documents."""

    filename: str
    design_doc: Document
    task_doc: Optional[Document]
    code_doc: Optional[Document]

    @staticmethod
    def loads(val: str) -> CodingContext | None:
        """Build a `CodingContext` from a JSON string.

        :param val: JSON text whose decoded value is unpacked as the model's keyword arguments.
        :return: The context, or None if decoding or construction raises any `Exception`.
        """
        try:
            return CodingContext(**json.loads(val))
        except Exception:
            # Best-effort parse: every failure is reported to the caller as None.
            return None
|
|
|
|
|
|
class TestingContext(BaseModel):
    """Groups a filename with its code document and an optional test document."""

    filename: str
    code_doc: Document
    test_doc: Optional[Document]

    @staticmethod
    def loads(val: str) -> TestingContext | None:
        """Build a `TestingContext` from a JSON string.

        :param val: JSON text whose decoded value is unpacked as the model's keyword arguments.
        :return: The context, or None if decoding or construction raises any `Exception`.
        """
        try:
            return TestingContext(**json.loads(val))
        except Exception:
            # Best-effort parse: every failure is reported to the caller as None.
            return None
|
|
|
|
|
|
class RunCodeContext(BaseModel):
    """Parameters and results of a code run: mode, code and test sources, command, paths and output."""

    mode: str = "script"
    code: Optional[str]
    code_filename: str = ""
    test_code: Optional[str]
    test_filename: str = ""
    command: List[str] = Field(default_factory=list)
    working_directory: str = ""
    additional_python_paths: List[str] = Field(default_factory=list)
    output_filename: Optional[str]
    output: Optional[str]

    @staticmethod
    def loads(val: str) -> RunCodeContext | None:
        """Build a `RunCodeContext` from a JSON string.

        :param val: JSON text whose decoded value is unpacked as the model's keyword arguments.
        :return: The context, or None if decoding or construction raises any `Exception`.
        """
        try:
            return RunCodeContext(**json.loads(val))
        except Exception:
            # Best-effort parse: every failure is reported to the caller as None.
            return None
|
|
|
|
|
|
class RunCodeResult(BaseModel):
    """Result of a code run: a summary plus the captured stdout and stderr text."""

    summary: str
    stdout: str
    stderr: str

    @staticmethod
    def loads(val: str) -> RunCodeResult | None:
        """Build a `RunCodeResult` from a JSON string.

        :param val: JSON text whose decoded value is unpacked as the model's keyword arguments.
        :return: The result, or None if decoding or construction raises any `Exception`.
        """
        try:
            return RunCodeResult(**json.loads(val))
        except Exception:
            # Best-effort parse: every failure is reported to the caller as None.
            return None
|
|
|
|
|
|
class CodeSummarizeContext(BaseModel):
    """Filenames of the design and task documents, a list of code filenames, and a reason."""

    design_filename: str = ""
    task_filename: str = ""
    codes_filenames: List[str] = Field(default_factory=list)
    reason: str = ""

    @staticmethod
    def loads(filenames: List) -> CodeSummarizeContext:
        """Build a context by classifying filenames by the repository directory they live under.

        A filename under `SYSTEM_DESIGN_FILE_REPO` becomes `design_filename`; otherwise one under
        `TASK_FILE_REPO` becomes `task_filename`. When several filenames match the same
        directory, the last one wins. Other filenames are ignored.

        :param filenames: Paths (strings or path-like objects) to classify.
        :return: A new context with the matched filenames stored as strings.
        """
        ctx = CodeSummarizeContext()
        for filename in filenames:
            path = Path(filename)
            if path.is_relative_to(SYSTEM_DESIGN_FILE_REPO):
                ctx.design_filename = str(filename)
            elif path.is_relative_to(TASK_FILE_REPO):
                ctx.task_filename = str(filename)
        return ctx

    def __hash__(self):
        """Hash on the design and task filenames only."""
        return hash((self.design_filename, self.task_filename))
|