mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
refactor: brain memory
This commit is contained in:
parent
4c82298e88
commit
530d2f5b30
3 changed files with 123 additions and 120 deletions
|
|
@ -8,12 +8,16 @@
|
|||
@Modified By: mashenquan, 2023/9/4. + redis memory cache.
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Dict, List
|
||||
|
||||
import openai
|
||||
import pydantic
|
||||
|
||||
from metagpt import Message
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RawMessage
|
||||
from metagpt.utils.redis import Redis
|
||||
|
|
@ -36,6 +40,7 @@ class BrainMemory(pydantic.BaseModel):
|
|||
last_history_id: str = ""
|
||||
is_dirty: bool = False
|
||||
last_talk: str = None
|
||||
llm_type: str
|
||||
|
||||
def add_talk(self, msg: Message):
|
||||
msg.add_tag(MessageType.Talk.value)
|
||||
|
|
@ -172,3 +177,111 @@ class BrainMemory(pydantic.BaseModel):
|
|||
self.history = []
|
||||
self.is_dirty = True
|
||||
return self.historical_summary
|
||||
|
||||
async def get_summary(self, text: str, llm, max_words=200, keep_language: bool = False, **kwargs):
|
||||
max_token_count = DEFAULT_MAX_TOKENS
|
||||
max_count = 100
|
||||
text_length = len(text)
|
||||
while max_count > 0:
|
||||
if text_length < max_token_count:
|
||||
return await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language)
|
||||
|
||||
padding_size = 20 if max_token_count > 20 else 0
|
||||
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
|
||||
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
|
||||
summaries = []
|
||||
for ws in text_windows:
|
||||
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
|
||||
summaries.append(response)
|
||||
if len(summaries) == 1:
|
||||
return summaries[0]
|
||||
|
||||
# Merged and retry
|
||||
text = "\n".join(summaries)
|
||||
text_length = len(text)
|
||||
|
||||
max_count -= 1 # safeguard
|
||||
raise openai.error.InvalidRequestError("text too long")
|
||||
|
||||
async def _get_summary(self, text: str, llm, max_words=20, keep_language: bool = False):
|
||||
"""Generate text summary"""
|
||||
if len(text) < max_words:
|
||||
return text
|
||||
if keep_language:
|
||||
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
|
||||
else:
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
msg = text + "\n\n" + command
|
||||
logger.debug(f"summary ask:{msg}")
|
||||
response = await llm.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"summary rsp: {response}")
|
||||
return response
|
||||
|
||||
async def get_title(self, text: str, llm, max_words=5, **kwargs) -> str:
|
||||
"""Generate text title"""
|
||||
summary = await self.get_summary(text, max_words=500)
|
||||
|
||||
language = CONFIG.language or DEFAULT_LANGUAGE
|
||||
command = f"Translate the above summary into a {language} title of less than {max_words} words."
|
||||
summaries = [summary, command]
|
||||
msg = "\n".join(summaries)
|
||||
logger.debug(f"title ask:{msg}")
|
||||
response = await llm.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"title rsp: {response}")
|
||||
return response
|
||||
|
||||
async def is_related(self, text1, text2, llm):
|
||||
# command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]."
|
||||
command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
|
||||
rsp = await llm.aask(msg=command, system_msgs=[])
|
||||
result = True if "TRUE" in rsp else False
|
||||
p2 = text2.replace("\n", "")
|
||||
p1 = text1.replace("\n", "")
|
||||
logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
|
||||
return result
|
||||
|
||||
async def rewrite(self, sentence: str, context: str, llm):
|
||||
# command = (
|
||||
# f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}"
|
||||
# )
|
||||
command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}"
|
||||
rsp = await llm.aask(msg=command, system_msgs=[])
|
||||
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
|
||||
return rsp
|
||||
|
||||
@staticmethod
|
||||
def split_texts(text: str, window_size) -> List[str]:
|
||||
"""Splitting long text into sliding windows text"""
|
||||
if window_size <= 0:
|
||||
window_size = BrainMemory.DEFAULT_TOKEN_SIZE
|
||||
total_len = len(text)
|
||||
if total_len <= window_size:
|
||||
return [text]
|
||||
|
||||
padding_size = 20 if window_size > 20 else 0
|
||||
windows = []
|
||||
idx = 0
|
||||
data_len = window_size - padding_size
|
||||
while idx < total_len:
|
||||
if window_size + idx > total_len: # 不足一个滑窗
|
||||
windows.append(text[idx:])
|
||||
break
|
||||
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
|
||||
# window_size=3, padding_size=1:
|
||||
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
|
||||
# idx=2, | idx=5 | idx=8 | ...
|
||||
w = text[idx : idx + window_size]
|
||||
windows.append(w)
|
||||
idx += data_len
|
||||
|
||||
return windows
|
||||
|
||||
@staticmethod
|
||||
def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
|
||||
match = re.match(pattern, input_string)
|
||||
if match:
|
||||
return match.group(1), match.group(2)
|
||||
else:
|
||||
return None, input_string
|
||||
|
||||
DEFAULT_TOKEN_SIZE = 500
|
||||
|
|
|
|||
|
|
@ -8,10 +8,8 @@
|
|||
"""
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
from openai.error import APIConnectionError
|
||||
|
|
@ -24,7 +22,6 @@ from tenacity import (
|
|||
)
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
|
|
@ -223,112 +220,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
||||
async def get_summary(self, text: str, max_words=200, keep_language: bool = False, **kwargs):
|
||||
max_token_count = DEFAULT_MAX_TOKENS
|
||||
max_count = 100
|
||||
text_length = len(text)
|
||||
while max_count > 0:
|
||||
if text_length < max_token_count:
|
||||
return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
|
||||
|
||||
padding_size = 20 if max_token_count > 20 else 0
|
||||
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
|
||||
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
|
||||
summaries = []
|
||||
for ws in text_windows:
|
||||
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
|
||||
summaries.append(response)
|
||||
if len(summaries) == 1:
|
||||
return summaries[0]
|
||||
|
||||
# Merged and retry
|
||||
text = "\n".join(summaries)
|
||||
text_length = len(text)
|
||||
|
||||
max_count -= 1 # safeguard
|
||||
raise openai.error.InvalidRequestError("text too long")
|
||||
|
||||
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
|
||||
"""Generate text summary"""
|
||||
if len(text) < max_words:
|
||||
return text
|
||||
if keep_language:
|
||||
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
|
||||
else:
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
msg = text + "\n\n" + command
|
||||
logger.debug(f"summary ask:{msg}")
|
||||
response = await self.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"summary rsp: {response}")
|
||||
return response
|
||||
|
||||
async def get_context_title(self, text: str, max_words=5, **kwargs) -> str:
|
||||
"""Generate text title"""
|
||||
summary = await self.get_summary(text, max_words=500)
|
||||
|
||||
language = CONFIG.language or DEFAULT_LANGUAGE
|
||||
command = f"Translate the above summary into a {language} title of less than {max_words} words."
|
||||
summaries = [summary, command]
|
||||
msg = "\n".join(summaries)
|
||||
logger.debug(f"title ask:{msg}")
|
||||
response = await self.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"title rsp: {response}")
|
||||
return response
|
||||
|
||||
async def is_related(self, text1, text2):
|
||||
# command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]."
|
||||
command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear."
|
||||
rsp = await self.aask(msg=command, system_msgs=[])
|
||||
result = True if "TRUE" in rsp else False
|
||||
p2 = text2.replace("\n", "")
|
||||
p1 = text1.replace("\n", "")
|
||||
logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n")
|
||||
return result
|
||||
|
||||
async def rewrite(self, sentence: str, context: str):
|
||||
# command = (
|
||||
# f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}"
|
||||
# )
|
||||
command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}"
|
||||
rsp = await self.aask(msg=command, system_msgs=[])
|
||||
logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n")
|
||||
return rsp
|
||||
|
||||
@staticmethod
|
||||
def split_texts(text: str, window_size) -> List[str]:
|
||||
"""Splitting long text into sliding windows text"""
|
||||
if window_size <= 0:
|
||||
window_size = OpenAIGPTAPI.DEFAULT_TOKEN_SIZE
|
||||
total_len = len(text)
|
||||
if total_len <= window_size:
|
||||
return [text]
|
||||
|
||||
padding_size = 20 if window_size > 20 else 0
|
||||
windows = []
|
||||
idx = 0
|
||||
data_len = window_size - padding_size
|
||||
while idx < total_len:
|
||||
if window_size + idx > total_len: # 不足一个滑窗
|
||||
windows.append(text[idx:])
|
||||
break
|
||||
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
|
||||
# window_size=3, padding_size=1:
|
||||
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
|
||||
# idx=2, | idx=5 | idx=8 | ...
|
||||
w = text[idx : idx + window_size]
|
||||
windows.append(w)
|
||||
idx += data_len
|
||||
|
||||
return windows
|
||||
|
||||
@staticmethod
|
||||
def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"):
|
||||
match = re.match(pattern, input_string)
|
||||
if match:
|
||||
return match.group(1), match.group(2)
|
||||
else:
|
||||
return None, input_string
|
||||
|
||||
@staticmethod
|
||||
async def async_retry_call(func, *args, **kwargs):
|
||||
for i in range(OpenAIGPTAPI.MAX_TRY):
|
||||
|
|
@ -371,7 +262,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
raise openai.error.OpenAIError("Exceeds the maximum retries")
|
||||
|
||||
MAX_TRY = 5
|
||||
DEFAULT_TOKEN_SIZE = 500
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -121,23 +121,23 @@ class Assistant(Role):
|
|||
return None
|
||||
if history_text == "":
|
||||
return last_talk
|
||||
history_summary = await self._llm.get_summary(
|
||||
text=history_text, max_words=800, keep_language=True, memory=self.memory
|
||||
history_summary = await self.memory.get_summary(
|
||||
text=history_text, max_words=800, keep_language=True, llm=self._llm
|
||||
)
|
||||
await self.memory.set_history_summary(
|
||||
history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
|
||||
)
|
||||
if last_talk and await self._llm.is_related(last_talk, history_summary): # Merge relevant content.
|
||||
last_talk = await self._llm.rewrite(sentence=last_talk, context=history_text)
|
||||
# await self.memory.set_history_summary(
|
||||
# history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
|
||||
# )
|
||||
if last_talk and await self.memory.is_related(
|
||||
text1=last_talk, text2=history_summary, llm=self._llm
|
||||
): # Merge relevant content.
|
||||
last_talk = await self.memory.rewrite(sentence=last_talk, context=history_text, llm=self._llm)
|
||||
return last_talk
|
||||
|
||||
return last_talk
|
||||
|
||||
@staticmethod
|
||||
def extract_info(input_string):
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
|
||||
return OpenAIGPTAPI.extract_info(input_string)
|
||||
return BrainMemory.extract_info(input_string)
|
||||
|
||||
def get_memory(self) -> str:
|
||||
return self.memory.json()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue