refactor: Message

This commit is contained in:
莘权 马 2023-10-31 15:23:37 +08:00
parent d8adba99d4
commit 5e8ada5cff
2 changed files with 92 additions and 36 deletions

View file

@ -4,13 +4,15 @@
@Time : 2023/5/8 22:12
@Author : alexanderwu
@File : schema.py
@Modified By: mashenquan, 2023-10-31, optimize class members.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Type, TypedDict
import json
from json import JSONDecodeError
from typing import Dict, List, TypedDict
from pydantic import BaseModel
from pydantic import BaseModel, Field
from metagpt.logs import logger
@ -20,16 +22,44 @@ class RawMessage(TypedDict):
role: str
@dataclass
class Message:
class Message(BaseModel):
"""list[<role>: <content>]"""
content: str
instruct_content: BaseModel = field(default=None)
role: str = field(default='user') # system / user / assistant
cause_by: Type["Action"] = field(default="")
sent_from: str = field(default="")
send_to: str = field(default="")
restricted_to: str = field(default="")
instruct_content: BaseModel = None
meta_info: Dict = Field(default_factory=dict)
route: List[Dict] = Field(default_factory=list)
def __init__(self, content, **kwargs):
super(Message, self).__init__(
content=content or kwargs.get("content"),
instruct_content=kwargs.get("instruct_content"),
meta_info=kwargs.get("meta_info", {}),
route=kwargs.get("route", []),
)
attribute_names = Message.__annotations__.keys()
for k, v in kwargs.items():
if k in attribute_names:
continue
self.meta_info[k] = v
def get_meta(self, key):
return self.meta_info.get(key)
def set_meta(self, key, value):
self.meta_info[key] = value
@property
def role(self):
return self.get_meta("role")
@property
def cause_by(self):
return self.get_meta("cause_by")
def set_role(self, v):
self.set_meta("role", v)
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
@ -39,45 +69,67 @@ class Message:
return self.__str__()
def to_dict(self) -> dict:
return {
"role": self.role,
"content": self.content
}
return {"role": self.role, "content": self.content}
def save(self) -> str:
return self.json(exclude_none=True)
@staticmethod
def load(v):
try:
d = json.loads(v)
return Message(**d)
except JSONDecodeError as err:
logger.error(f"parse json failed: {v}, error:{err}")
return None
@dataclass
class UserMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'user')
super(Message, self).__init__(content=content, meta_info={"role": "user"})
@dataclass
class SystemMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'system')
super().__init__(content=content, meta_info={"role": "system"})
@dataclass
class AIMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'assistant')
super().__init__(content=content, meta_info={"role": "assistant"})
if __name__ == '__main__':
test_content = 'test_message'
if __name__ == "__main__":
m = Message("a", role="v1")
m.set_role("v2")
v = m.save()
m = Message.load(v)
test_content = "test_message"
msgs = [
UserMessage(test_content),
SystemMessage(test_content),
AIMessage(test_content),
Message(test_content, role='QA')
Message(test_content, role="QA"),
]
logger.info(msgs)
jsons = [
UserMessage(test_content).save(),
SystemMessage(test_content).save(),
AIMessage(test_content).save(),
Message(test_content, role="QA").save(),
]
logger.info(jsons)

View file

@ -11,26 +11,30 @@ from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMe
def test_message():
msg = Message(role='User', content='WTF')
assert msg.to_dict()['role'] == 'User'
assert 'User' in str(msg)
msg = Message(role="User", content="WTF")
assert msg.to_dict()["role"] == "User"
assert "User" in str(msg)
def test_all_messages():
test_content = 'test_message'
test_content = "test_message"
msgs = [
UserMessage(test_content),
SystemMessage(test_content),
AIMessage(test_content),
Message(test_content, role='QA')
Message(test_content, role="QA"),
]
for msg in msgs:
assert msg.content == test_content
def test_raw_message():
msg = RawMessage(role='user', content='raw')
assert msg['role'] == 'user'
assert msg['content'] == 'raw'
msg = RawMessage(role="user", content="raw")
assert msg["role"] == "user"
assert msg["content"] == "raw"
with pytest.raises(KeyError):
assert msg['1'] == 1, "KeyError: '1'"
assert msg["1"] == 1, "KeyError: '1'"
if __name__ == "__main__":
pytest.main([__file__, "-s"])