mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-11 00:32:37 +02:00
fixbug: brain memory serialize
This commit is contained in:
parent
903e89cec3
commit
3e9151e52e
4 changed files with 90 additions and 16 deletions
|
|
@ -1,10 +1,11 @@
|
|||
from enum import Enum
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
import pydantic
|
||||
|
||||
from metagpt import Message
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
Talk = "TALK"
|
||||
Solution = "SOLUTION"
|
||||
|
|
@ -14,29 +15,28 @@ class MessageType(Enum):
|
|||
|
||||
|
||||
class BrainMemory(pydantic.BaseModel):
|
||||
history: List[Message] = []
|
||||
stack: List[Message] = []
|
||||
solution: List[Message] = []
|
||||
knowledge: List[Message] = []
|
||||
|
||||
history: List[Dict] = []
|
||||
stack: List[Dict] = []
|
||||
solution: List[Dict] = []
|
||||
knowledge: List[Dict] = []
|
||||
|
||||
def add_talk(self, msg: Message):
|
||||
msg.add_tag(MessageType.Talk.value)
|
||||
self.history.append(msg)
|
||||
self.history.append(msg.dict())
|
||||
|
||||
def add_answer(self, msg: Message):
|
||||
msg.add_tag(MessageType.Answer.value)
|
||||
self.history.append(msg)
|
||||
self.history.append(msg.dict())
|
||||
|
||||
def get_knowledge(self) -> str:
|
||||
texts = [k.content for k in self.knowledge]
|
||||
texts = [Message(**m).content for m in self.knowledge]
|
||||
return "\n".join(texts)
|
||||
|
||||
@property
|
||||
def history_text(self):
|
||||
if len(self.history) == 0:
|
||||
return ""
|
||||
texts = [m.content for m in self.history[:-1]]
|
||||
texts = [Message(**m).content for m in self.history[:-1]]
|
||||
return "\n".join(texts)
|
||||
|
||||
def move_to_solution(self):
|
||||
|
|
@ -44,7 +44,7 @@ class BrainMemory(pydantic.BaseModel):
|
|||
return
|
||||
msgs = self.history[:-1]
|
||||
self.solution.extend(msgs)
|
||||
if not self.history[-1].is_contain(MessageType.Talk.value):
|
||||
if not Message(**self.history[-1]).is_contain(MessageType.Talk.value):
|
||||
self.solution.append(self.history[-1])
|
||||
self.history = []
|
||||
else:
|
||||
|
|
@ -52,7 +52,9 @@ class BrainMemory(pydantic.BaseModel):
|
|||
|
||||
@property
|
||||
def last_talk(self):
|
||||
if len(self.history) == 0 or not self.history[-1].is_contain_tags([MessageType.Talk.value]):
|
||||
if len(self.history) == 0:
|
||||
return None
|
||||
return self.history[-1].content
|
||||
|
||||
last_msg = Message(**self.history[-1])
|
||||
if not last_msg.is_contain(MessageType.Talk.value):
|
||||
return None
|
||||
return last_msg.content
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class Assistant(Role):
|
|||
return output
|
||||
|
||||
async def talk(self, text):
|
||||
self.memory.add_talk(Message(content=text, tags=set([MessageType.Talk.value])))
|
||||
self.memory.add_talk(Message(content=text))
|
||||
|
||||
async def _plan(self, rsp: str, **kwargs) -> bool:
|
||||
skill, text = Assistant.extract_info(input_string=rsp)
|
||||
|
|
|
|||
|
|
@ -70,6 +70,22 @@ class Message:
|
|||
def is_contain(self, tag):
|
||||
return self.is_contain_tags([tag])
|
||||
|
||||
def dict(self):
|
||||
"""pydantic-like `dict` function"""
|
||||
full = {
|
||||
"instruct_content": self.instruct_content,
|
||||
"cause_by": self.cause_by,
|
||||
"sent_from": self.sent_from,
|
||||
"send_to": self.send_to,
|
||||
"tags": self.tags
|
||||
}
|
||||
|
||||
m = {"content": self.content}
|
||||
for k, v in full.items():
|
||||
if v:
|
||||
m[k] = v
|
||||
return m
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserMessage(Message):
|
||||
|
|
@ -101,7 +117,6 @@ class AIMessage(Message):
|
|||
super().__init__(content, 'assistant')
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_content = 'test_message'
|
||||
msgs = [
|
||||
|
|
|
|||
57
tests/metagpt/memory/test_brain_memory.py
Normal file
57
tests/metagpt/memory/test_brain_memory.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/27
|
||||
@Author : mashenquan
|
||||
@File : test_brain_memory.py
|
||||
"""
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import pydantic
|
||||
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_json():
|
||||
class Input(pydantic.BaseModel):
|
||||
history: List[str]
|
||||
solution: List[str]
|
||||
knowledge: List[str]
|
||||
stack: List[str]
|
||||
|
||||
inputs = [
|
||||
{
|
||||
"history": ["a", "b"],
|
||||
"solution": ["c"],
|
||||
"knowledge": ["d", "e"],
|
||||
"stack": ["f"]
|
||||
}
|
||||
]
|
||||
|
||||
for i in inputs:
|
||||
v = Input(**i)
|
||||
bm = BrainMemory()
|
||||
for h in v.history:
|
||||
msg = Message(content=h)
|
||||
bm.history.append(msg.dict())
|
||||
for h in v.solution:
|
||||
msg = Message(content=h)
|
||||
bm.solution.append(msg.dict())
|
||||
for h in v.knowledge:
|
||||
msg = Message(content=h)
|
||||
bm.knowledge.append(msg.dict())
|
||||
for h in v.stack:
|
||||
msg = Message(content=h)
|
||||
bm.stack.append(msg.dict())
|
||||
s = bm.json()
|
||||
m = json.loads(s)
|
||||
bm = BrainMemory(**m)
|
||||
assert bm
|
||||
for v in bm.history:
|
||||
msg = Message(**v)
|
||||
assert msg
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_json()
|
||||
Loading…
Add table
Add a link
Reference in a new issue