fixbug: brain memory serialize

This commit is contained in:
莘权 马 2023-08-27 15:11:28 +08:00
parent 903e89cec3
commit 3e9151e52e
4 changed files with 90 additions and 16 deletions

View file

@ -1,10 +1,11 @@
from enum import Enum
from typing import List
from typing import List, Dict
import pydantic
from metagpt import Message
class MessageType(Enum):
Talk = "TALK"
Solution = "SOLUTION"
@ -14,29 +15,28 @@ class MessageType(Enum):
class BrainMemory(pydantic.BaseModel):
history: List[Message] = []
stack: List[Message] = []
solution: List[Message] = []
knowledge: List[Message] = []
history: List[Dict] = []
stack: List[Dict] = []
solution: List[Dict] = []
knowledge: List[Dict] = []
def add_talk(self, msg: Message):
msg.add_tag(MessageType.Talk.value)
self.history.append(msg)
self.history.append(msg.dict())
def add_answer(self, msg: Message):
msg.add_tag(MessageType.Answer.value)
self.history.append(msg)
self.history.append(msg.dict())
def get_knowledge(self) -> str:
texts = [k.content for k in self.knowledge]
texts = [Message(**m).content for m in self.knowledge]
return "\n".join(texts)
@property
def history_text(self):
if len(self.history) == 0:
return ""
texts = [m.content for m in self.history[:-1]]
texts = [Message(**m).content for m in self.history[:-1]]
return "\n".join(texts)
def move_to_solution(self):
@ -44,7 +44,7 @@ class BrainMemory(pydantic.BaseModel):
return
msgs = self.history[:-1]
self.solution.extend(msgs)
if not self.history[-1].is_contain(MessageType.Talk.value):
if not Message(**self.history[-1]).is_contain(MessageType.Talk.value):
self.solution.append(self.history[-1])
self.history = []
else:
@ -52,7 +52,9 @@ class BrainMemory(pydantic.BaseModel):
@property
def last_talk(self):
if len(self.history) == 0 or not self.history[-1].is_contain_tags([MessageType.Talk.value]):
if len(self.history) == 0:
return None
return self.history[-1].content
last_msg = Message(**self.history[-1])
if not last_msg.is_contain(MessageType.Talk.value):
return None
return last_msg.content

View file

@ -77,7 +77,7 @@ class Assistant(Role):
return output
async def talk(self, text):
self.memory.add_talk(Message(content=text, tags=set([MessageType.Talk.value])))
self.memory.add_talk(Message(content=text))
async def _plan(self, rsp: str, **kwargs) -> bool:
skill, text = Assistant.extract_info(input_string=rsp)

View file

@ -70,6 +70,22 @@ class Message:
def is_contain(self, tag):
return self.is_contain_tags([tag])
def dict(self):
"""pydantic-like `dict` function"""
full = {
"instruct_content": self.instruct_content,
"cause_by": self.cause_by,
"sent_from": self.sent_from,
"send_to": self.send_to,
"tags": self.tags
}
m = {"content": self.content}
for k, v in full.items():
if v:
m[k] = v
return m
@dataclass
class UserMessage(Message):
@ -101,7 +117,6 @@ class AIMessage(Message):
super().__init__(content, 'assistant')
if __name__ == '__main__':
test_content = 'test_message'
msgs = [

View file

@ -0,0 +1,57 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/27
@Author : mashenquan
@File : test_brain_memory.py
"""
import json
from typing import List
import pydantic
from metagpt.memory.brain_memory import BrainMemory
from metagpt.schema import Message
def test_json():
class Input(pydantic.BaseModel):
history: List[str]
solution: List[str]
knowledge: List[str]
stack: List[str]
inputs = [
{
"history": ["a", "b"],
"solution": ["c"],
"knowledge": ["d", "e"],
"stack": ["f"]
}
]
for i in inputs:
v = Input(**i)
bm = BrainMemory()
for h in v.history:
msg = Message(content=h)
bm.history.append(msg.dict())
for h in v.solution:
msg = Message(content=h)
bm.solution.append(msg.dict())
for h in v.knowledge:
msg = Message(content=h)
bm.knowledge.append(msg.dict())
for h in v.stack:
msg = Message(content=h)
bm.stack.append(msg.dict())
s = bm.json()
m = json.loads(s)
bm = BrainMemory(**m)
assert bm
for v in bm.history:
msg = Message(**v)
assert msg
if __name__ == '__main__':
test_json()