diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 9d1b038bb..cb67fea8e 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -1,10 +1,11 @@ from enum import Enum -from typing import List +from typing import List, Dict import pydantic from metagpt import Message + class MessageType(Enum): Talk = "TALK" Solution = "SOLUTION" @@ -14,29 +15,28 @@ class MessageType(Enum): class BrainMemory(pydantic.BaseModel): - history: List[Message] = [] - stack: List[Message] = [] - solution: List[Message] = [] - knowledge: List[Message] = [] - + history: List[Dict] = [] + stack: List[Dict] = [] + solution: List[Dict] = [] + knowledge: List[Dict] = [] def add_talk(self, msg: Message): msg.add_tag(MessageType.Talk.value) - self.history.append(msg) + self.history.append(msg.dict()) def add_answer(self, msg: Message): msg.add_tag(MessageType.Answer.value) - self.history.append(msg) + self.history.append(msg.dict()) def get_knowledge(self) -> str: - texts = [k.content for k in self.knowledge] + texts = [Message(**m).content for m in self.knowledge] return "\n".join(texts) @property def history_text(self): if len(self.history) == 0: return "" - texts = [m.content for m in self.history[:-1]] + texts = [Message(**m).content for m in self.history[:-1]] return "\n".join(texts) def move_to_solution(self): @@ -44,7 +44,7 @@ class BrainMemory(pydantic.BaseModel): return msgs = self.history[:-1] self.solution.extend(msgs) - if not self.history[-1].is_contain(MessageType.Talk.value): + if not Message(**self.history[-1]).is_contain(MessageType.Talk.value): self.solution.append(self.history[-1]) self.history = [] else: @@ -52,7 +52,9 @@ class BrainMemory(pydantic.BaseModel): @property def last_talk(self): - if len(self.history) == 0 or not self.history[-1].is_contain_tags([MessageType.Talk.value]): + if len(self.history) == 0: return None - return self.history[-1].content - + last_msg = Message(**self.history[-1]) + if not last_msg.is_contain(MessageType.Talk.value): + return None + return last_msg.content diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py index 199cdcafd..4519fcdb8 100644 --- a/metagpt/roles/assistant.py +++ b/metagpt/roles/assistant.py @@ -77,7 +77,7 @@ class Assistant(Role): return output async def talk(self, text): - self.memory.add_talk(Message(content=text, tags=set([MessageType.Talk.value]))) + self.memory.add_talk(Message(content=text)) async def _plan(self, rsp: str, **kwargs) -> bool: skill, text = Assistant.extract_info(input_string=rsp) diff --git a/metagpt/schema.py b/metagpt/schema.py index 909313886..4d7f0cc21 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -70,6 +70,22 @@ class Message: def is_contain(self, tag): return self.is_contain_tags([tag]) + def dict(self): + """pydantic-like `dict` function""" + full = { + "instruct_content": self.instruct_content, + "cause_by": self.cause_by, + "sent_from": self.sent_from, + "send_to": self.send_to, + "tags": self.tags + } + + m = {"content": self.content} + for k, v in full.items(): + if v: + m[k] = v + return m + @dataclass class UserMessage(Message): @@ -101,7 +117,6 @@ class AIMessage(Message): super().__init__(content, 'assistant') - if __name__ == '__main__': test_content = 'test_message' msgs = [ diff --git a/tests/metagpt/memory/test_brain_memory.py b/tests/metagpt/memory/test_brain_memory.py new file mode 100644 index 000000000..b5fc942ca --- /dev/null +++ b/tests/metagpt/memory/test_brain_memory.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/27 +@Author : mashenquan +@File : test_brain_memory.py +""" +import json +from typing import List + +import pydantic + +from metagpt.memory.brain_memory import BrainMemory +from metagpt.schema import Message + + +def test_json(): + class Input(pydantic.BaseModel): + history: List[str] + solution: List[str] + knowledge: List[str] + stack: List[str] + + inputs = [ + { + "history": ["a", "b"], + "solution": ["c"], + "knowledge": ["d", "e"], + "stack": ["f"] + } + ] + + for i in inputs: + v = Input(**i) + bm = BrainMemory() + for h in v.history: + msg = Message(content=h) + bm.history.append(msg.dict()) + for h in v.solution: + msg = Message(content=h) + bm.solution.append(msg.dict()) + for h in v.knowledge: + msg = Message(content=h) + bm.knowledge.append(msg.dict()) + for h in v.stack: + msg = Message(content=h) + bm.stack.append(msg.dict()) + s = bm.json() + m = json.loads(s) + bm = BrainMemory(**m) + assert bm + for v in bm.history: + msg = Message(**v) + assert msg + +if __name__ == '__main__': + test_json() \ No newline at end of file