refactor: brain memory

This commit is contained in:
莘权 马 2023-09-07 19:03:41 +08:00
parent 4c82298e88
commit 530d2f5b30
3 changed files with 123 additions and 120 deletions

View file

@ -8,12 +8,16 @@
@Modified By: mashenquan, 2023/9/4. + redis memory cache.
"""
import json
import re
from enum import Enum
from typing import Dict, List
import openai
import pydantic
from metagpt import Message
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
from metagpt.logs import logger
from metagpt.schema import RawMessage
from metagpt.utils.redis import Redis
@ -36,6 +40,7 @@ class BrainMemory(pydantic.BaseModel):
last_history_id: str = ""
is_dirty: bool = False
last_talk: str = None
llm_type: str
def add_talk(self, msg: Message):
msg.add_tag(MessageType.Talk.value)
@ -172,3 +177,111 @@ class BrainMemory(pydantic.BaseModel):
self.history = []
self.is_dirty = True
return self.historical_summary
async def get_summary(self, text: str, llm, max_words=200, keep_language: bool = False, **kwargs):
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
text_length = len(text)
while max_count > 0:
if text_length < max_token_count:
return await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language)
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
summaries = []
for ws in text_windows:
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
return summaries[0]
# Merged and retry
text = "\n".join(summaries)
text_length = len(text)
max_count -= 1 # safeguard
raise openai.error.InvalidRequestError("text too long")
async def _get_summary(self, text: str, llm, max_words=20, keep_language: bool = False):
"""Generate text summary"""
if len(text) < max_words:
return text
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.debug(f"summary ask:{msg}")
response = await llm.aask(msg=msg, system_msgs=[])
logger.debug(f"summary rsp: {response}")
return response
async def get_title(self, text: str, llm, max_words=5, **kwargs) -> str:
"""Generate text title"""
summary = await self.get_summary(text, max_words=500)
language = CONFIG.language or DEFAULT_LANGUAGE
command = f"Translate the above summary into a {language} title of less than {max_words} words."
summaries = [summary, command]
msg = "\n".join(summaries)
logger.debug(f"title ask:{msg}")
response = await llm.aask(msg=msg, system_msgs=[])
logger.debug(f"title rsp: {response}")
return response
async def is_related(self, text1, text2, llm):
# command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]."
command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
rsp = await llm.aask(msg=command, system_msgs=[])
result = True if "TRUE" in rsp else False
p2 = text2.replace("\n", "")
p1 = text1.replace("\n", "")
logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
return result
async def rewrite(self, sentence: str, context: str, llm):
# command = (
# f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}"
# )
command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}"
rsp = await llm.aask(msg=command, system_msgs=[])
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
return rsp
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
"""Splitting long text into sliding windows text"""
if window_size <= 0:
window_size = BrainMemory.DEFAULT_TOKEN_SIZE
total_len = len(text)
if total_len <= window_size:
return [text]
padding_size = 20 if window_size > 20 else 0
windows = []
idx = 0
data_len = window_size - padding_size
while idx < total_len:
if window_size + idx > total_len: # 不足一个滑窗
windows.append(text[idx:])
break
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
# window_size=3, padding_size=1
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
# idx=2, | idx=5 | idx=8 | ...
w = text[idx : idx + window_size]
windows.append(w)
idx += data_len
return windows
@staticmethod
def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
match = re.match(pattern, input_string)
if match:
return match.group(1), match.group(2)
else:
return None, input_string
DEFAULT_TOKEN_SIZE = 500

View file

@ -8,10 +8,8 @@
"""
import asyncio
import random
import re
import time
import traceback
from typing import List
import openai
from openai.error import APIConnectionError
@ -24,7 +22,6 @@ from tenacity import (
)
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.utils.cost_manager import Costs
@ -223,112 +220,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
async def get_summary(self, text: str, max_words=200, keep_language: bool = False, **kwargs):
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
text_length = len(text)
while max_count > 0:
if text_length < max_token_count:
return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
summaries = []
for ws in text_windows:
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
return summaries[0]
# Merged and retry
text = "\n".join(summaries)
text_length = len(text)
max_count -= 1 # safeguard
raise openai.error.InvalidRequestError("text too long")
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
"""Generate text summary"""
if len(text) < max_words:
return text
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.debug(f"summary ask:{msg}")
response = await self.aask(msg=msg, system_msgs=[])
logger.debug(f"summary rsp: {response}")
return response
async def get_context_title(self, text: str, max_words=5, **kwargs) -> str:
"""Generate text title"""
summary = await self.get_summary(text, max_words=500)
language = CONFIG.language or DEFAULT_LANGUAGE
command = f"Translate the above summary into a {language} title of less than {max_words} words."
summaries = [summary, command]
msg = "\n".join(summaries)
logger.debug(f"title ask:{msg}")
response = await self.aask(msg=msg, system_msgs=[])
logger.debug(f"title rsp: {response}")
return response
async def is_related(self, text1, text2):
# command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]."
command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
rsp = await self.aask(msg=command, system_msgs=[])
result = True if "TRUE" in rsp else False
p2 = text2.replace("\n", "")
p1 = text1.replace("\n", "")
logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
return result
async def rewrite(self, sentence: str, context: str):
# command = (
# f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}"
# )
command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}"
rsp = await self.aask(msg=command, system_msgs=[])
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
return rsp
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
"""Splitting long text into sliding windows text"""
if window_size <= 0:
window_size = OpenAIGPTAPI.DEFAULT_TOKEN_SIZE
total_len = len(text)
if total_len <= window_size:
return [text]
padding_size = 20 if window_size > 20 else 0
windows = []
idx = 0
data_len = window_size - padding_size
while idx < total_len:
if window_size + idx > total_len: # 不足一个滑窗
windows.append(text[idx:])
break
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
# window_size=3, padding_size=1
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
# idx=2, | idx=5 | idx=8 | ...
w = text[idx : idx + window_size]
windows.append(w)
idx += data_len
return windows
@staticmethod
def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
match = re.match(pattern, input_string)
if match:
return match.group(1), match.group(2)
else:
return None, input_string
@staticmethod
async def async_retry_call(func, *args, **kwargs):
for i in range(OpenAIGPTAPI.MAX_TRY):
@ -371,7 +262,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
raise openai.error.OpenAIError("Exceeds the maximum retries")
MAX_TRY = 5
DEFAULT_TOKEN_SIZE = 500
if __name__ == "__main__":

View file

@ -121,23 +121,23 @@ class Assistant(Role):
return None
if history_text == "":
return last_talk
history_summary = await self._llm.get_summary(
text=history_text, max_words=800, keep_language=True, memory=self.memory
history_summary = await self.memory.get_summary(
text=history_text, max_words=800, keep_language=True, llm=self._llm
)
await self.memory.set_history_summary(
history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
)
if last_talk and await self._llm.is_related(last_talk, history_summary): # Merge relevant content.
last_talk = await self._llm.rewrite(sentence=last_talk, context=history_text)
# await self.memory.set_history_summary(
# history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
# )
if last_talk and await self.memory.is_related(
text1=last_talk, text2=history_summary, llm=self._llm
): # Merge relevant content.
last_talk = await self.memory.rewrite(sentence=last_talk, context=history_text, llm=self._llm)
return last_talk
return last_talk
@staticmethod
def extract_info(input_string):
from metagpt.provider.openai_api import OpenAIGPTAPI
return OpenAIGPTAPI.extract_info(input_string)
return BrainMemory.extract_info(input_string)
def get_memory(self) -> str:
return self.memory.json()