mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
refactor: Message
This commit is contained in:
parent
d8adba99d4
commit
5e8ada5cff
2 changed files with 92 additions and 36 deletions
|
|
@ -4,13 +4,15 @@
|
|||
@Time : 2023/5/8 22:12
|
||||
@Author : alexanderwu
|
||||
@File : schema.py
|
||||
@Modified By: mashenquan, 2023-10-31, optimize class members.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type, TypedDict
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Dict, List, TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
|
@ -20,16 +22,44 @@ class RawMessage(TypedDict):
|
|||
role: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
class Message(BaseModel):
|
||||
"""list[<role>: <content>]"""
|
||||
|
||||
content: str
|
||||
instruct_content: BaseModel = field(default=None)
|
||||
role: str = field(default='user') # system / user / assistant
|
||||
cause_by: Type["Action"] = field(default="")
|
||||
sent_from: str = field(default="")
|
||||
send_to: str = field(default="")
|
||||
restricted_to: str = field(default="")
|
||||
instruct_content: BaseModel = None
|
||||
meta_info: Dict = Field(default_factory=dict)
|
||||
route: List[Dict] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, content, **kwargs):
|
||||
super(Message, self).__init__(
|
||||
content=content or kwargs.get("content"),
|
||||
instruct_content=kwargs.get("instruct_content"),
|
||||
meta_info=kwargs.get("meta_info", {}),
|
||||
route=kwargs.get("route", []),
|
||||
)
|
||||
|
||||
attribute_names = Message.__annotations__.keys()
|
||||
for k, v in kwargs.items():
|
||||
if k in attribute_names:
|
||||
continue
|
||||
self.meta_info[k] = v
|
||||
|
||||
def get_meta(self, key):
|
||||
return self.meta_info.get(key)
|
||||
|
||||
def set_meta(self, key, value):
|
||||
self.meta_info[key] = value
|
||||
|
||||
@property
|
||||
def role(self):
|
||||
return self.get_meta("role")
|
||||
|
||||
@property
|
||||
def cause_by(self):
|
||||
return self.get_meta("cause_by")
|
||||
|
||||
def set_role(self, v):
|
||||
self.set_meta("role", v)
|
||||
|
||||
def __str__(self):
|
||||
# prefix = '-'.join([self.role, str(self.cause_by)])
|
||||
|
|
@ -39,45 +69,67 @@ class Message:
|
|||
return self.__str__()
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"role": self.role,
|
||||
"content": self.content
|
||||
}
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
def save(self) -> str:
|
||||
return self.json(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def load(v):
|
||||
try:
|
||||
d = json.loads(v)
|
||||
return Message(**d)
|
||||
except JSONDecodeError as err:
|
||||
logger.error(f"parse json failed: {v}, error:{err}")
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMessage(Message):
|
||||
"""便于支持OpenAI的消息
|
||||
Facilitate support for OpenAI messages
|
||||
Facilitate support for OpenAI messages
|
||||
"""
|
||||
|
||||
def __init__(self, content: str):
|
||||
super().__init__(content, 'user')
|
||||
super(Message, self).__init__(content=content, meta_info={"role": "user"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessage(Message):
|
||||
"""便于支持OpenAI的消息
|
||||
Facilitate support for OpenAI messages
|
||||
Facilitate support for OpenAI messages
|
||||
"""
|
||||
|
||||
def __init__(self, content: str):
|
||||
super().__init__(content, 'system')
|
||||
super().__init__(content=content, meta_info={"role": "system"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIMessage(Message):
|
||||
"""便于支持OpenAI的消息
|
||||
Facilitate support for OpenAI messages
|
||||
Facilitate support for OpenAI messages
|
||||
"""
|
||||
|
||||
def __init__(self, content: str):
|
||||
super().__init__(content, 'assistant')
|
||||
super().__init__(content=content, meta_info={"role": "assistant"})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_content = 'test_message'
|
||||
if __name__ == "__main__":
|
||||
m = Message("a", role="v1")
|
||||
m.set_role("v2")
|
||||
v = m.save()
|
||||
m = Message.load(v)
|
||||
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role='QA')
|
||||
Message(test_content, role="QA"),
|
||||
]
|
||||
logger.info(msgs)
|
||||
|
||||
jsons = [
|
||||
UserMessage(test_content).save(),
|
||||
SystemMessage(test_content).save(),
|
||||
AIMessage(test_content).save(),
|
||||
Message(test_content, role="QA").save(),
|
||||
]
|
||||
logger.info(jsons)
|
||||
|
|
|
|||
|
|
@ -11,26 +11,30 @@ from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMe
|
|||
|
||||
|
||||
def test_message():
|
||||
msg = Message(role='User', content='WTF')
|
||||
assert msg.to_dict()['role'] == 'User'
|
||||
assert 'User' in str(msg)
|
||||
msg = Message(role="User", content="WTF")
|
||||
assert msg.to_dict()["role"] == "User"
|
||||
assert "User" in str(msg)
|
||||
|
||||
|
||||
def test_all_messages():
|
||||
test_content = 'test_message'
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role='QA')
|
||||
Message(test_content, role="QA"),
|
||||
]
|
||||
for msg in msgs:
|
||||
assert msg.content == test_content
|
||||
|
||||
|
||||
def test_raw_message():
|
||||
msg = RawMessage(role='user', content='raw')
|
||||
assert msg['role'] == 'user'
|
||||
assert msg['content'] == 'raw'
|
||||
msg = RawMessage(role="user", content="raw")
|
||||
assert msg["role"] == "user"
|
||||
assert msg["content"] == "raw"
|
||||
with pytest.raises(KeyError):
|
||||
assert msg['1'] == 1, "KeyError: '1'"
|
||||
assert msg["1"] == 1, "KeyError: '1'"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue