2023-06-30 17:10:48 +08:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
"""
|
|
|
|
|
@Time : 2023/5/8 22:12
|
|
|
|
|
@Author : alexanderwu
|
|
|
|
|
@File : schema.py
|
2023-11-03 11:53:47 +08:00
|
|
|
@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116:
|
|
|
|
|
Replanned the distribution of responsibilities and functional positioning of `Message` class attributes.
|
2023-06-30 17:10:48 +08:00
|
|
|
"""
|
|
|
|
|
from __future__ import annotations
|
2023-07-22 11:28:22 +08:00
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
import asyncio
|
2023-10-31 15:23:37 +08:00
|
|
|
import json
|
2023-11-01 20:08:58 +08:00
|
|
|
from asyncio import Queue, QueueEmpty, wait_for
|
2023-10-31 15:23:37 +08:00
|
|
|
from json import JSONDecodeError
|
2023-11-01 20:08:58 +08:00
|
|
|
from typing import Dict, List, Set, TypedDict
|
2023-06-30 17:10:48 +08:00
|
|
|
|
2023-10-31 15:23:37 +08:00
|
|
|
from pydantic import BaseModel, Field
|
2023-06-30 17:10:48 +08:00
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
from metagpt.const import (
|
|
|
|
|
MESSAGE_META_ROLE,
|
|
|
|
|
MESSAGE_ROUTE_CAUSE_BY,
|
|
|
|
|
MESSAGE_ROUTE_FROM,
|
|
|
|
|
MESSAGE_ROUTE_TO,
|
|
|
|
|
)
|
2023-07-22 11:28:22 +08:00
|
|
|
from metagpt.logs import logger
|
2023-11-04 16:46:32 +08:00
|
|
|
from metagpt.utils.common import any_to_str
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class RawMessage(TypedDict):
    """Plain-dict form of a chat message: a ``role`` and its text ``content``."""

    # Text body of the message.
    content: str
    # Speaker of the message (e.g. "user", "system", "assistant").
    role: str
|
|
|
|
|
|
|
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
class Routes(BaseModel):
    """Responsible for managing routing information for the Message class.

    All routing data lives in ``routes[0]``, a dict that may hold:

    * ``MESSAGE_ROUTE_FROM`` -- the sender label.
    * ``MESSAGE_ROUTE_TO`` -- the recipient labels. When absent or empty the
      message is accepted by every consumer (see :meth:`is_recipient`).

    Recipient labels are normalised to a fresh ``set`` before being mutated.
    After a JSON round trip (``Message.dump`` / ``Message.load``) they come
    back as a ``list``, which has neither ``add`` nor ``discard``.
    """

    routes: List[Dict] = Field(default_factory=list)

    @staticmethod
    def _to_tag_set(tags) -> Set:
        """Return ``tags`` as a new set: falsy -> empty set, a single str -> one-element set."""
        if not tags:
            return set()
        if isinstance(tags, str):
            # set("abc") would split the label into characters.
            return {tags}
        return set(tags)

    def set_from(self, value):
        """Set the label of the message sender."""
        route = self._get_route()
        route[MESSAGE_ROUTE_FROM] = value

    def set_to(self, tags: Set):
        """Replace the labels of the message recipients; falsy ``tags`` removes them."""
        route = self._get_route()
        if tags:
            # Store a private copy so later add_to()/replace() calls work and
            # never mutate the caller's collection.
            route[MESSAGE_ROUTE_TO] = self._to_tag_set(tags)
            return
        route.pop(MESSAGE_ROUTE_TO, None)

    def add_to(self, tag: str):
        """Add a label of the message recipient."""
        route = self._get_route()
        tags = self._to_tag_set(route.get(MESSAGE_ROUTE_TO))
        tags.add(tag)
        route[MESSAGE_ROUTE_TO] = tags

    def _get_route(self) -> Dict:
        """Return the primary route dict, creating it on first access."""
        if not self.routes:
            self.routes.append({})
        return self.routes[0]

    def is_recipient(self, tags: Set) -> bool:
        """Check if it is the message recipient.

        Returns True when the message has no recipient labels, or when any of
        ``tags`` is among them.
        """
        to_tags = self._get_route().get(MESSAGE_ROUTE_TO)
        if not to_tags:
            return True
        return any(k in to_tags for k in tags)

    @property
    def tx_from(self):
        """Message route info tells who sent this message."""
        return self._get_route().get(MESSAGE_ROUTE_FROM)

    @property
    def tx_to(self):
        """Labels for the consumer to filter its subscribed messages."""
        return self._get_route().get(MESSAGE_ROUTE_TO)

    def replace(self, old_val, new_val):
        """Replace ``old_val`` with ``new_val`` among the recipient labels.

        ``new_val`` is added even if ``old_val`` was not present.
        """
        route = self._get_route()
        tags = self._to_tag_set(route.get(MESSAGE_ROUTE_TO))
        tags.discard(old_val)
        tags.add(new_val)
        route[MESSAGE_ROUTE_TO] = tags
|
|
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
|
2023-10-31 15:23:37 +08:00
|
|
|
class Message(BaseModel):
    """A message exchanged between roles, laid out per RFC 116 (chapter 2.2.1).

    * ``content`` / ``instruct_content`` -- the payload (text and optional
      structured form).
    * ``meta_info`` -- metadata; ``role``, ``cause_by`` and any unrecognised
      constructor keyword arguments are stored here.
    * ``route`` -- routing labels (sender and recipient tags) that consumers
      use to filter the messages they subscribe to.
    """

    content: str
    instruct_content: BaseModel = None
    meta_info: Dict = Field(default_factory=dict)
    route: Routes = Field(default_factory=Routes)

    def __init__(self, content, **kwargs):
        """
        Parameters not listed below will be stored as meta info, including custom parameters.

        :param content: Message content.
        :param instruct_content: Message content struct.
        :param meta_info: Message meta info. It is copied, so the caller's dict is never mutated.
        :param route: Message route configuration.
        :param tx_from: Message route info tells who sent this message.
        :param tx_to: Labels for the consumer to filter its subscribed messages. Accepts a single
            label or a set/frozenset/list/tuple of labels; each label is converted with `any_to_str`.
        :param cause_by: Labels for the consumer to filter its subscribed messages, also serving as meta info.
        :param role: Message meta info tells who sent this message.
        """
        super(Message, self).__init__(
            # `content` is an explicit parameter, so it can never appear in `kwargs`.
            # The former `content or kwargs.get("content")` turned "" into None,
            # which then failed validation.
            content=content,
            instruct_content=kwargs.get("instruct_content"),
            # Copy: the loop below writes extra keyword arguments into meta_info.
            meta_info=dict(kwargs.get("meta_info") or {}),
            route=kwargs.get("route", Routes()),
        )

        attribute_names = Message.__annotations__.keys()
        for k, v in kwargs.items():
            if k in attribute_names:
                continue  # already handed to pydantic above
            if k == MESSAGE_ROUTE_FROM:
                self.set_from(any_to_str(v))
            elif k == MESSAGE_ROUTE_CAUSE_BY:
                self.set_cause_by(v)
            elif k == MESSAGE_ROUTE_TO:
                for tag in self._to_labels(v):
                    self.add_to(tag)
            else:
                self.meta_info[k] = v

    @staticmethod
    def _to_labels(value) -> List[str]:
        """Convert one label, or a set/frozenset/list/tuple of labels, into a list of strings."""
        if isinstance(value, (set, frozenset, list, tuple)):
            return [any_to_str(i) for i in value]
        return [any_to_str(value)]

    def get_meta(self, key):
        """Get meta info"""
        return self.meta_info.get(key)

    def set_meta(self, key, value):
        """Set meta info"""
        self.meta_info[key] = value

    @property
    def role(self):
        """Message meta info tells who sent this message."""
        return self.get_meta(MESSAGE_META_ROLE)

    @property
    def cause_by(self):
        """Labels for the consumer to filter its subscribed messages, also serving as meta info."""
        return self.get_meta(MESSAGE_ROUTE_CAUSE_BY)

    def __setattr__(self, key, val):
        """Override `@property.setter` for the read-only properties.

        Assignments to the `cause_by`, `tx_from` and `role` keys are routed to
        their setters. Those names are not pydantic fields, so pydantic's
        `__setattr__` would reject them. Every other assignment goes to pydantic.
        """
        if key == MESSAGE_ROUTE_CAUSE_BY:
            self.set_cause_by(val)
            return
        if key == MESSAGE_ROUTE_FROM:
            self.set_from(any_to_str(val))
            # Previously this fell through to pydantic, which rejected the unknown field.
            return
        if key == MESSAGE_META_ROLE:
            self.set_role(val)
            return
        super().__setattr__(key, val)

    def set_cause_by(self, val):
        """Update the value of `cause_by` in the `meta_info` and `routes` attributes."""
        old_value = self.get_meta(MESSAGE_ROUTE_CAUSE_BY)
        new_value = any_to_str(val)
        self.set_meta(MESSAGE_ROUTE_CAUSE_BY, new_value)
        # Keep the recipient labels in sync: drop the previous cause_by label, add the new one.
        self.route.replace(old_value, new_value)

    @property
    def tx_from(self):
        """Message route info tells who sent this message."""
        return self.route.tx_from

    @property
    def tx_to(self):
        """Labels for the consumer to filter its subscribed messages."""
        return self.route.tx_to

    def set_role(self, v):
        """Set the message's meta info indicating the sender."""
        self.set_meta(MESSAGE_META_ROLE, v)

    def set_from(self, v):
        """Set the message's route info indicating the sender."""
        self.route.set_from(v)

    def set_to(self, tags: Set):
        """Replace the recipient labels in the message's route info; falsy `tags` clears them."""
        self.route.set_to(tags)

    def add_to(self, tag: str):
        """Add a subscription label for the recipients."""
        self.route.add_to(tag)

    def is_recipient(self, tags: Set):
        """Return true if any input label exists in the message's subscription labels."""
        return self.route.is_recipient(tags)

    def __str__(self):
        return f"{self.role}: {self.content}"

    def __repr__(self):
        return self.__str__()

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

    def dump(self) -> str:
        """Convert the object to json string"""
        return self.json(exclude_none=True)

    @staticmethod
    def load(v):
        """Convert the json string to object; return None if `v` is not valid JSON."""
        try:
            d = json.loads(v)
            return Message(**d)
        except JSONDecodeError as err:
            logger.error(f"parse json failed: {v}, error:{err}")
        return None
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserMessage(Message):
    """Message carrying the OpenAI ``user`` role.

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        # Delegate to Message with the fixed OpenAI role label.
        super().__init__(content, role="user")
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class SystemMessage(Message):
    """Message carrying the OpenAI ``system`` role.

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        # Delegate to Message with the fixed OpenAI role label.
        super().__init__(content, role="system")
|
2023-06-30 17:10:48 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class AIMessage(Message):
    """Message carrying the OpenAI ``assistant`` role.

    Facilitate support for OpenAI messages.
    """

    def __init__(self, content: str):
        # Delegate to Message with the fixed OpenAI role label.
        super().__init__(content, role="assistant")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MessageQueue:
|
2023-11-01 20:33:34 +08:00
|
|
|
"""Message queue which supports asynchronous updates."""
|
|
|
|
|
|
2023-11-01 20:08:58 +08:00
|
|
|
def __init__(self):
|
|
|
|
|
self._queue = Queue()
|
|
|
|
|
|
|
|
|
|
def pop(self) -> Message | None:
|
2023-11-01 20:35:37 +08:00
|
|
|
"""Pop one message from the queue."""
|
2023-11-01 20:08:58 +08:00
|
|
|
try:
|
|
|
|
|
item = self._queue.get_nowait()
|
|
|
|
|
if item:
|
|
|
|
|
self._queue.task_done()
|
|
|
|
|
return item
|
|
|
|
|
except QueueEmpty:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def pop_all(self) -> List[Message]:
|
2023-11-01 20:35:37 +08:00
|
|
|
"""Pop all messages from the queue."""
|
2023-11-01 20:08:58 +08:00
|
|
|
ret = []
|
|
|
|
|
while True:
|
|
|
|
|
msg = self.pop()
|
|
|
|
|
if not msg:
|
|
|
|
|
break
|
|
|
|
|
ret.append(msg)
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
def push(self, msg: Message):
|
2023-11-01 20:33:34 +08:00
|
|
|
"""Push a message into the queue."""
|
2023-11-01 20:08:58 +08:00
|
|
|
self._queue.put_nowait(msg)
|
|
|
|
|
|
|
|
|
|
def empty(self):
|
2023-11-01 20:33:34 +08:00
|
|
|
"""Return true if the queue is empty."""
|
2023-11-01 20:08:58 +08:00
|
|
|
return self._queue.empty()
|
|
|
|
|
|
2023-11-04 14:26:48 +08:00
|
|
|
async def dump(self) -> str:
|
2023-11-01 20:33:34 +08:00
|
|
|
"""Convert the `MessageQueue` object to a json string."""
|
2023-11-01 20:08:58 +08:00
|
|
|
if self.empty():
|
|
|
|
|
return "[]"
|
|
|
|
|
|
|
|
|
|
lst = []
|
|
|
|
|
try:
|
|
|
|
|
while True:
|
|
|
|
|
item = await wait_for(self._queue.get(), timeout=1.0)
|
|
|
|
|
if item is None:
|
|
|
|
|
break
|
|
|
|
|
lst.append(item.dict(exclude_none=True))
|
|
|
|
|
self._queue.task_done()
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
logger.debug("Queue is empty, exiting...")
|
|
|
|
|
return json.dumps(lst)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def load(self, v) -> "MessageQueue":
|
2023-11-01 20:33:34 +08:00
|
|
|
"""Convert the json string to the `MessageQueue` object."""
|
2023-11-01 20:08:58 +08:00
|
|
|
q = MessageQueue()
|
|
|
|
|
try:
|
|
|
|
|
lst = json.loads(v)
|
|
|
|
|
for i in lst:
|
|
|
|
|
msg = Message(**i)
|
|
|
|
|
q.push(msg)
|
|
|
|
|
except JSONDecodeError as e:
|
|
|
|
|
logger.warning(f"JSON load failed: {v}, error:{e}")
|
|
|
|
|
|
|
|
|
|
return q
|
2023-10-31 15:23:37 +08:00
|
|
|
|
2023-06-30 17:10:48 +08:00
|
|
|
|
2023-10-31 15:23:37 +08:00
|
|
|
if __name__ == "__main__":
    # Round-trip a message through dump()/load(), then exercise the cause_by
    # setter with a string, a class and an instance.
    msg = Message("a", role="v1")
    msg.set_role("v2")
    serialized = msg.dump()
    msg = Message.load(serialized)
    for cause in ("Message", Routes, Routes()):
        msg.cause_by = cause
    msg.content = "b"

    test_content = "test_message"
    msgs = [
        UserMessage(test_content),
        SystemMessage(test_content),
        AIMessage(test_content),
        Message(test_content, role="QA"),
    ]
    logger.info(msgs)

    jsons = [msg_cls(test_content).dump() for msg_cls in (UserMessage, SystemMessage, AIMessage)]
    jsons.append(Message(test_content, role="QA").dump())
    logger.info(jsons)
|