diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 8b47ba79a..347e3e0fb 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -10,14 +10,15 @@ """ import json import re -from typing import Dict, List +from typing import Dict, List, Optional from pydantic import BaseModel, Field from metagpt.config import CONFIG -from metagpt.const import DEFAULT_LANGUAGE +from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger from metagpt.provider import MetaGPTAPI +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message, SimpleMessage from metagpt.utils.redis import Redis @@ -30,6 +31,7 @@ class BrainMemory(BaseModel): is_dirty: bool = False last_talk: str = None cacheable: bool = True + llm: Optional[BaseGPTAPI] = None def add_talk(self, msg: Message): """ @@ -120,6 +122,7 @@ class BrainMemory(BaseModel): if isinstance(llm, MetaGPTAPI): return await self._metagpt_summarize(max_words=max_words) + self.llm = llm return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit) async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1): @@ -131,7 +134,7 @@ class BrainMemory(BaseModel): text_length = len(text) if limit > 0 and text_length < limit: return text - summary = await llm.summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit) + summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit) if summary: await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS) return summary @@ -251,3 +254,74 @@ class BrainMemory(BaseModel): texts.append(t) return "\n".join(texts) + + async def _summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str: + max_token_count = DEFAULT_MAX_TOKENS + max_count = 100 + text_length = len(text) + if limit > 0 and text_length < limit: + return text + summary = "" + while max_count > 0: + if text_length < max_token_count: + summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) + break + + padding_size = 20 if max_token_count > 20 else 0 + text_windows = self.split_texts(text, window_size=max_token_count - padding_size) + part_max_words = min(int(max_words / len(text_windows)) + 1, 100) + summaries = [] + for ws in text_windows: + response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) + summaries.append(response) + if len(summaries) == 1: + summary = summaries[0] + break + + # Merged and retry + text = "\n".join(summaries) + text_length = len(text) + + max_count -= 1 # safeguard + return summary + + async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): + """Generate text summary""" + if len(text) < max_words: + return text + if keep_language: + command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." + else: + command = f"Translate the above content into a summary of less than {max_words} words." + msg = text + "\n\n" + command + logger.debug(f"summary ask:{msg}") + response = await self.llm.aask(msg=msg, system_msgs=[]) + logger.debug(f"summary rsp: {response}") + return response + + @staticmethod + def split_texts(text: str, window_size) -> List[str]: + """Splitting long text into sliding windows text""" + if window_size <= 0: + window_size = DEFAULT_TOKEN_SIZE + total_len = len(text) + if total_len <= window_size: + return [text] + + padding_size = 20 if window_size > 20 else 0 + windows = [] + idx = 0 + data_len = window_size - padding_size + while idx < total_len: + if window_size + idx > total_len: # 不足一个滑窗 + windows.append(text[idx:]) + break + # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] + # window_size=3, padding_size=1: + # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... + # idx=2, | idx=5 | idx=8 | ... + w = text[idx : idx + window_size] + windows.append(w) + idx += data_len + + return windows diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 405d523e5..b72eff0dc 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -12,7 +12,7 @@ import asyncio import json import time -from typing import AsyncIterator, List, Union +from typing import AsyncIterator, Union import openai from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI @@ -28,7 +28,6 @@ from tenacity import ( ) from metagpt.config import CONFIG, Config, LLMProviderEnum -from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE @@ -190,9 +189,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return self.get_choice_text(rsp) def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict: - """ - Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create - """ + """Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create""" if "tools" not in kwargs: configs = { "tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}], @@ -353,74 +350,3 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): if self.async_client: await self.async_client.close() self.async_client = None - - async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str: - max_token_count = DEFAULT_MAX_TOKENS - max_count = 100 - text_length = len(text) - if limit > 0 and text_length < limit: - return text - summary = "" - while max_count > 0: - if text_length < max_token_count: - summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) - break - - padding_size = 20 if max_token_count > 20 else 0 - text_windows = self.split_texts(text, window_size=max_token_count - padding_size) - part_max_words = min(int(max_words / len(text_windows)) + 1, 100) - summaries = [] - for ws in text_windows: - response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) - summaries.append(response) - if len(summaries) == 1: - summary = summaries[0] - break - - # Merged and retry - text = "\n".join(summaries) - text_length = len(text) - - max_count -= 1 # safeguard - return summary - - async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): - """Generate text summary""" - if len(text) < max_words: - return text - if keep_language: - command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." - else: - command = f"Translate the above content into a summary of less than {max_words} words." - msg = text + "\n\n" + command - logger.debug(f"summary ask:{msg}") - response = await self.aask(msg=msg, system_msgs=[]) - logger.debug(f"summary rsp: {response}") - return response - - @staticmethod - def split_texts(text: str, window_size) -> List[str]: - """Splitting long text into sliding windows text""" - if window_size <= 0: - window_size = DEFAULT_TOKEN_SIZE - total_len = len(text) - if total_len <= window_size: - return [text] - - padding_size = 20 if window_size > 20 else 0 - windows = [] - idx = 0 - data_len = window_size - padding_size - while idx < total_len: - if window_size + idx > total_len: # 不足一个滑窗 - windows.append(text[idx:]) - break - # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] - # window_size=3, padding_size=1: - # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... - # idx=2, | idx=5 | idx=8 | ... - w = text[idx : idx + window_size] - windows.append(w) - idx += data_len - - return windows