From 530d2f5b308a9c280853a20f51c2fac929c95134 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 7 Sep 2023 19:03:41 +0800 Subject: [PATCH] refactor: brain memory --- metagpt/memory/brain_memory.py | 113 +++++++++++++++++++++++++++++++++ metagpt/provider/openai_api.py | 110 -------------------------------- metagpt/roles/assistant.py | 20 +++--- 3 files changed, 123 insertions(+), 120 deletions(-) diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index e8a98c55b..7eda9c601 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -8,12 +8,16 @@ @Modified By: mashenquan, 2023/9/4. + redis memory cache. """ import json +import re from enum import Enum from typing import Dict, List +import openai import pydantic from metagpt import Message +from metagpt.config import CONFIG +from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS from metagpt.logs import logger from metagpt.schema import RawMessage from metagpt.utils.redis import Redis @@ -36,6 +40,7 @@ class BrainMemory(pydantic.BaseModel): last_history_id: str = "" is_dirty: bool = False last_talk: str = None + llm_type: str def add_talk(self, msg: Message): msg.add_tag(MessageType.Talk.value) @@ -172,3 +177,111 @@ class BrainMemory(pydantic.BaseModel): self.history = [] self.is_dirty = True return self.historical_summary + + async def get_summary(self, text: str, llm, max_words=200, keep_language: bool = False, **kwargs): + max_token_count = DEFAULT_MAX_TOKENS + max_count = 100 + text_length = len(text) + while max_count > 0: + if text_length < max_token_count: + return await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language) + + padding_size = 20 if max_token_count > 20 else 0 + text_windows = self.split_texts(text, window_size=max_token_count - padding_size) + part_max_words = min(int(max_words / len(text_windows)) + 1, 100) + summaries = [] + for ws in text_windows: + response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) + summaries.append(response) + if len(summaries) == 1: + return summaries[0] + + # Merged and retry + text = "\n".join(summaries) + text_length = len(text) + + max_count -= 1 # safeguard + raise openai.error.InvalidRequestError("text too long") + + async def _get_summary(self, text: str, llm, max_words=20, keep_language: bool = False): + """Generate text summary""" + if len(text) < max_words: + return text + if keep_language: + command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." + else: + command = f"Translate the above content into a summary of less than {max_words} words." + msg = text + "\n\n" + command + logger.debug(f"summary ask:{msg}") + response = await llm.aask(msg=msg, system_msgs=[]) + logger.debug(f"summary rsp: {response}") + return response + + async def get_title(self, text: str, llm, max_words=5, **kwargs) -> str: + """Generate text title""" + summary = await self.get_summary(text, max_words=500) + + language = CONFIG.language or DEFAULT_LANGUAGE + command = f"Translate the above summary into a {language} title of less than {max_words} words." + summaries = [summary, command] + msg = "\n".join(summaries) + logger.debug(f"title ask:{msg}") + response = await llm.aask(msg=msg, system_msgs=[]) + logger.debug(f"title rsp: {response}") + return response + + async def is_related(self, text1, text2, llm): + # command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]." + command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear." + rsp = await llm.aask(msg=command, system_msgs=[]) + result = True if "TRUE" in rsp else False + p2 = text2.replace("\n", "") + p1 = text1.replace("\n", "") + logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n") + return result + + async def rewrite(self, sentence: str, context: str, llm): + # command = ( + # f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}" + # ) + command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}" + rsp = await llm.aask(msg=command, system_msgs=[]) + logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n") + return rsp + + @staticmethod + def split_texts(text: str, window_size) -> List[str]: + """Splitting long text into sliding windows text""" + if window_size <= 0: + window_size = BrainMemory.DEFAULT_TOKEN_SIZE + total_len = len(text) + if total_len <= window_size: + return [text] + + padding_size = 20 if window_size > 20 else 0 + windows = [] + idx = 0 + data_len = window_size - padding_size + while idx < total_len: + if window_size + idx > total_len: # 不足一个滑窗 + windows.append(text[idx:]) + break + # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] + # window_size=3, padding_size=1: + # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... + # idx=2, | idx=5 | idx=8 | ... + w = text[idx : idx + window_size] + windows.append(w) + idx += data_len + + return windows + + @staticmethod + def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"): + match = re.match(pattern, input_string) + if match: + return match.group(1), match.group(2) + else: + return None, input_string + + DEFAULT_TOKEN_SIZE = 500 diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 64267975e..231b568c7 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -8,10 +8,8 @@ """ import asyncio import random -import re import time import traceback -from typing import List import openai from openai.error import APIConnectionError @@ -24,7 +22,6 @@ from tenacity import ( ) from metagpt.config import CONFIG -from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.utils.cost_manager import Costs @@ -223,112 +220,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return CONFIG.max_tokens_rsp return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) - async def get_summary(self, text: str, max_words=200, keep_language: bool = False, **kwargs): - max_token_count = DEFAULT_MAX_TOKENS - max_count = 100 - text_length = len(text) - while max_count > 0: - if text_length < max_token_count: - return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) - - padding_size = 20 if max_token_count > 20 else 0 - text_windows = self.split_texts(text, window_size=max_token_count - padding_size) - part_max_words = min(int(max_words / len(text_windows)) + 1, 100) - summaries = [] - for ws in text_windows: - response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) - summaries.append(response) - if len(summaries) == 1: - return summaries[0] - - # Merged and retry - text = "\n".join(summaries) - text_length = len(text) - - max_count -= 1 # safeguard - raise openai.error.InvalidRequestError("text too long") - - async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): - """Generate text summary""" - if len(text) < max_words: - return text - if keep_language: - command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." - else: - command = f"Translate the above content into a summary of less than {max_words} words." - msg = text + "\n\n" + command - logger.debug(f"summary ask:{msg}") - response = await self.aask(msg=msg, system_msgs=[]) - logger.debug(f"summary rsp: {response}") - return response - - async def get_context_title(self, text: str, max_words=5, **kwargs) -> str: - """Generate text title""" - summary = await self.get_summary(text, max_words=500) - - language = CONFIG.language or DEFAULT_LANGUAGE - command = f"Translate the above summary into a {language} title of less than {max_words} words." - summaries = [summary, command] - msg = "\n".join(summaries) - logger.debug(f"title ask:{msg}") - response = await self.aask(msg=msg, system_msgs=[]) - logger.debug(f"title rsp: {response}") - return response - - async def is_related(self, text1, text2): - # command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]." - command = f"{text2}\n\nIs there any sentence above related to the following sentence: {text1}.\nIf is there any relevance, return [TRUE] brief and clear. Otherwise, return [FALSE] brief and clear." - rsp = await self.aask(msg=command, system_msgs=[]) - result = True if "TRUE" in rsp else False - p2 = text2.replace("\n", "") - p1 = text1.replace("\n", "") - logger.info(f"IS_RELATED:\nParagraph 1: {p2}\nParagraph 2: {p1}\nRESULT: {result}\n") - return result - - async def rewrite(self, sentence: str, context: str): - # command = ( - # f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}" - # ) - command = f"{context}\n\nExtract relevant information from every preceding sentence and use it to succinctly supplement or rewrite the following text in brief and clear:\n{sentence}" - rsp = await self.aask(msg=command, system_msgs=[]) - logger.info(f"REWRITE:\nCommand: {command}\nRESULT: {rsp}\n") - return rsp - - @staticmethod - def split_texts(text: str, window_size) -> List[str]: - """Splitting long text into sliding windows text""" - if window_size <= 0: - window_size = OpenAIGPTAPI.DEFAULT_TOKEN_SIZE - total_len = len(text) - if total_len <= window_size: - return [text] - - padding_size = 20 if window_size > 20 else 0 - windows = [] - idx = 0 - data_len = window_size - padding_size - while idx < total_len: - if window_size + idx > total_len: # 不足一个滑窗 - windows.append(text[idx:]) - break - # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] - # window_size=3, padding_size=1: - # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... - # idx=2, | idx=5 | idx=8 | ... - w = text[idx : idx + window_size] - windows.append(w) - idx += data_len - - return windows - - @staticmethod - def extract_info(input_string, pattern=r"\[([A-Z]+)\]:\s*(.+)"): - match = re.match(pattern, input_string) - if match: - return match.group(1), match.group(2) - else: - return None, input_string - @staticmethod async def async_retry_call(func, *args, **kwargs): for i in range(OpenAIGPTAPI.MAX_TRY): @@ -371,7 +262,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): raise openai.error.OpenAIError("Exceeds the maximum retries") MAX_TRY = 5 - DEFAULT_TOKEN_SIZE = 500 if __name__ == "__main__": diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py index 2fcb6f584..d5467cafb 100644 --- a/metagpt/roles/assistant.py +++ b/metagpt/roles/assistant.py @@ -121,23 +121,23 @@ class Assistant(Role): return None if history_text == "": return last_talk - history_summary = await self._llm.get_summary( - text=history_text, max_words=800, keep_language=True, memory=self.memory + history_summary = await self.memory.get_summary( + text=history_text, max_words=800, keep_language=True, llm=self._llm ) - await self.memory.set_history_summary( - history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS - ) - if last_talk and await self._llm.is_related(last_talk, history_summary): # Merge relevant content. - last_talk = await self._llm.rewrite(sentence=last_talk, context=history_text) + # await self.memory.set_history_summary( + # history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS + # ) + if last_talk and await self.memory.is_related( + text1=last_talk, text2=history_summary, llm=self._llm + ): # Merge relevant content. + last_talk = await self.memory.rewrite(sentence=last_talk, context=history_text, llm=self._llm) return last_talk return last_talk @staticmethod def extract_info(input_string): - from metagpt.provider.openai_api import OpenAIGPTAPI - - return OpenAIGPTAPI.extract_info(input_string) + return BrainMemory.extract_info(input_string) def get_memory(self) -> str: return self.memory.json()