diff --git a/metagpt/schema.py b/metagpt/schema.py index bdca093c2..1124fb28e 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -4,13 +4,15 @@ @Time : 2023/5/8 22:12 @Author : alexanderwu @File : schema.py +@Modified By: mashenquan, 2023-10-31, optimize class members. """ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Type, TypedDict +import json +from json import JSONDecodeError +from typing import Dict, List, TypedDict -from pydantic import BaseModel +from pydantic import BaseModel, Field from metagpt.logs import logger @@ -20,16 +22,44 @@ class RawMessage(TypedDict): role: str -@dataclass -class Message: +class Message(BaseModel): """list[: ]""" + content: str - instruct_content: BaseModel = field(default=None) - role: str = field(default='user') # system / user / assistant - cause_by: Type["Action"] = field(default="") - sent_from: str = field(default="") - send_to: str = field(default="") - restricted_to: str = field(default="") + instruct_content: BaseModel = None + meta_info: Dict = Field(default_factory=dict) + route: List[Dict] = Field(default_factory=list) + + def __init__(self, content, **kwargs): + super(Message, self).__init__( + content=content or kwargs.get("content"), + instruct_content=kwargs.get("instruct_content"), + meta_info=kwargs.get("meta_info", {}), + route=kwargs.get("route", []), + ) + + attribute_names = Message.__annotations__.keys() + for k, v in kwargs.items(): + if k in attribute_names: + continue + self.meta_info[k] = v + + def get_meta(self, key): + return self.meta_info.get(key) + + def set_meta(self, key, value): + self.meta_info[key] = value + + @property + def role(self): + return self.get_meta("role") + + @property + def cause_by(self): + return self.get_meta("cause_by") + + def set_role(self, v): + self.set_meta("role", v) def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -39,45 +69,67 @@ class Message: return self.__str__() def to_dict(self) -> dict: - return { - "role": self.role, - "content": self.content - } + return {"role": self.role, "content": self.content} + + def save(self) -> str: + return self.json(exclude_none=True) + + @staticmethod + def load(v): + try: + d = json.loads(v) + return Message(**d) + except JSONDecodeError as err: + logger.error(f"parse json failed: {v}, error:{err}") + return None -@dataclass class UserMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ + def __init__(self, content: str): - super().__init__(content, 'user') + super(Message, self).__init__(content=content, meta_info={"role": "user"}) -@dataclass class SystemMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ + def __init__(self, content: str): - super().__init__(content, 'system') + super().__init__(content=content, meta_info={"role": "system"}) -@dataclass class AIMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ + def __init__(self, content: str): - super().__init__(content, 'assistant') + super().__init__(content=content, meta_info={"role": "assistant"}) -if __name__ == '__main__': - test_content = 'test_message' +if __name__ == "__main__": + m = Message("a", role="v1") + m.set_role("v2") + v = m.save() + m = Message.load(v) + + test_content = "test_message" msgs = [ UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - Message(test_content, role='QA') + Message(test_content, role="QA"), ] logger.info(msgs) + + jsons = [ + UserMessage(test_content).save(), + SystemMessage(test_content).save(), + AIMessage(test_content).save(), + Message(test_content, role="QA").save(), + ] + logger.info(jsons) diff --git a/tests/metagpt/test_message.py b/tests/metagpt/test_message.py index e26f38381..4f46311ce 100644 --- a/tests/metagpt/test_message.py +++ b/tests/metagpt/test_message.py @@ -11,26 +11,30 @@ from metagpt.schema import AIMessage, Message, RawMessage, SystemMessage, UserMe def test_message(): - msg = Message(role='User', content='WTF') - assert msg.to_dict()['role'] == 'User' - assert 'User' in str(msg) + msg = Message(role="User", content="WTF") + assert msg.to_dict()["role"] == "User" + assert "User" in str(msg) def test_all_messages(): - test_content = 'test_message' + test_content = "test_message" msgs = [ UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - Message(test_content, role='QA') + Message(test_content, role="QA"), ] for msg in msgs: assert msg.content == test_content def test_raw_message(): - msg = RawMessage(role='user', content='raw') - assert msg['role'] == 'user' - assert msg['content'] == 'raw' + msg = RawMessage(role="user", content="raw") + assert msg["role"] == "user" + assert msg["content"] == "raw" with pytest.raises(KeyError): - assert msg['1'] == 1, "KeyError: '1'" + assert msg["1"] == 1, "KeyError: '1'" + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])