refactor openai api and brain memory

This commit is contained in:
geekan 2023-12-26 15:09:37 +08:00
parent 8351c8ec35
commit e15de55368
2 changed files with 79 additions and 79 deletions

View file

@ -10,14 +10,15 @@
"""
import json
import re
from typing import Dict, List
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.logs import logger
from metagpt.provider import MetaGPTAPI
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.schema import Message, SimpleMessage
from metagpt.utils.redis import Redis
@ -30,6 +31,7 @@ class BrainMemory(BaseModel):
is_dirty: bool = False
last_talk: str = None
cacheable: bool = True
llm: Optional[BaseGPTAPI] = None
def add_talk(self, msg: Message):
"""
@ -120,6 +122,7 @@ class BrainMemory(BaseModel):
if isinstance(llm, MetaGPTAPI):
return await self._metagpt_summarize(max_words=max_words)
self.llm = llm
return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit)
async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1):
@ -131,7 +134,7 @@ class BrainMemory(BaseModel):
text_length = len(text)
if limit > 0 and text_length < limit:
return text
summary = await llm.summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
if summary:
await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
return summary
@ -251,3 +254,74 @@ class BrainMemory(BaseModel):
texts.append(t)
return "\n".join(texts)
async def _summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
    """Condense `text` into a summary, summarizing in windows when it is long.

    Args:
        text: The content to summarize.
        max_words: Word budget handed to `_get_summary` for the final summary.
        keep_language: Forwarded to `_get_summary` unchanged.
        limit: When positive and `text` is shorter than it (in characters),
            `text` is returned as-is without any summarization.

    Returns:
        The summary string, or "" if 100 merge rounds pass without converging.
    """
    token_budget = DEFAULT_MAX_TOKENS
    text_length = len(text)
    if limit > 0 and text_length < limit:
        return text

    # The overlap only depends on the budget, so it is computed once up front.
    overlap = 20 if token_budget > 20 else 0

    # Bounded number of rounds acts as a safeguard against a non-shrinking text.
    for _ in range(100):
        if text_length < token_budget:
            # Short enough (character count vs. budget) to summarize in one request.
            return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)

        chunks = self.split_texts(text, window_size=token_budget - overlap)
        # Share the word budget across the chunks, capped at 100 words each.
        per_chunk_words = min(int(max_words / len(chunks)) + 1, 100)
        # Awaited one after another, in chunk order.
        partials = [
            await self._get_summary(text=chunk, max_words=per_chunk_words, keep_language=keep_language)
            for chunk in chunks
        ]
        if len(partials) == 1:
            return partials[0]

        # Merge the partial summaries and go around again on the shorter text.
        text = "\n".join(partials)
        text_length = len(text)
    return ""
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
    """Ask `self.llm` for a summary of `text`.

    Args:
        text: The content to summarize.
        max_words: Word limit quoted in the prompt. Text whose character
            count is already below this number is returned unchanged.
        keep_language: Selects the prompt variant that asks for the summary
            in the language of the content.

    Returns:
        `text` itself on the short-text path, otherwise the reply of `self.llm.aask`.
    """
    if len(text) < max_words:
        return text

    # NOTE(review): the keep-language prompt starts with a "." that the other
    # variant does not have -- looks unintentional, confirm before changing.
    command = (
        f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
        if keep_language
        else f"Translate the above content into a summary of less than {max_words} words."
    )
    prompt = "\n\n".join((text, command))
    logger.debug(f"summary ask:{prompt}")
    reply = await self.llm.aask(msg=prompt, system_msgs=[])
    logger.debug(f"summary rsp: {reply}")
    return reply
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
    """Split a long text into overlapping sliding windows.

    Args:
        text: The text to split.
        window_size: Maximum length of each window, in characters. A value
            <= 0 is replaced by `DEFAULT_TOKEN_SIZE`.

    Returns:
        `[text]` when the text fits in one window; otherwise the list of
        windows in order. When `window_size` is greater than 20, consecutive
        windows share 20 characters; smaller windows do not overlap.
    """
    if window_size <= 0:
        window_size = DEFAULT_TOKEN_SIZE
    total_len = len(text)
    if total_len <= window_size:
        return [text]

    padding_size = 20 if window_size > 20 else 0
    windows = []
    idx = 0
    # Advancing by less than a full window is what produces the overlap, e.g.
    # for [1, 2, 3, 4, 5, 6, 7, ...] with window_size=3, padding_size=1:
    #   [1, 2, 3], [3, 4, 5], [5, 6, 7], ...
    data_len = window_size - padding_size
    while idx < total_len:
        if window_size + idx >= total_len:
            # Last window (possibly shorter than a full one). `>=` rather than
            # `>`: a window ending exactly at the end of the text is final, so
            # no extra window holding only the already-covered overlap is added.
            windows.append(text[idx:])
            break
        windows.append(text[idx : idx + window_size])
        idx += data_len
    return windows

View file

@ -12,7 +12,7 @@
import asyncio
import json
import time
from typing import AsyncIterator, List, Union
from typing import AsyncIterator, Union
import openai
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
@ -28,7 +28,6 @@ from tenacity import (
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
@ -190,9 +189,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return self.get_choice_text(rsp)
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
"""
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
"""
"""Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
if "tools" not in kwargs:
configs = {
"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
@ -353,74 +350,3 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
if self.async_client:
await self.async_client.close()
self.async_client = None
async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
    """Summarize `text`, falling back to window-by-window summaries for long input.

    Args:
        text: The content to summarize.
        max_words: Word budget passed to `_get_summary` for the final summary.
        keep_language: Passed through to `_get_summary`.
        limit: If positive and the text has fewer characters than this value,
            the text is returned untouched.

    Returns:
        The summary, or "" when 100 merge rounds were used up without a result.
    """
    size = len(text)
    if limit > 0 and size < limit:
        return text

    ceiling = DEFAULT_MAX_TOKENS
    rounds_left = 100  # safeguard against looping forever
    while rounds_left > 0:
        if size < ceiling:
            return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)

        overlap = 20 if ceiling > 20 else 0
        pieces = self.split_texts(text, window_size=ceiling - overlap)
        # Each piece gets an equal share of the word budget, at most 100 words.
        words_each = min(int(max_words / len(pieces)) + 1, 100)
        digests = []
        for piece in pieces:
            digest = await self._get_summary(text=piece, max_words=words_each, keep_language=keep_language)
            digests.append(digest)
        if len(digests) == 1:
            return digests[0]

        # Join the per-piece summaries and retry on the merged text.
        text = "\n".join(digests)
        size = len(text)
        rounds_left -= 1
    return ""
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
    """Generate a summary of `text` through `self.aask`.

    Args:
        text: The content to summarize.
        max_words: Word limit quoted in the prompt; text with fewer characters
            than this is handed back unchanged.
        keep_language: When true, the prompt asks for the summary in the
            language of the content.

    Returns:
        `text` on the short-text path, otherwise whatever `self.aask` returns.
    """
    if len(text) < max_words:
        return text

    if not keep_language:
        command = f"Translate the above content into a summary of less than {max_words} words."
    else:
        # NOTE(review): this variant begins with a "." -- presumably a typo; confirm.
        command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."

    request = "\n\n".join([text, command])
    logger.debug(f"summary ask:{request}")
    answer = await self.aask(msg=request, system_msgs=[])
    logger.debug(f"summary rsp: {answer}")
    return answer
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
    """Split a long text into sliding windows.

    Args:
        text: The text to split.
        window_size: Maximum window length in characters; values <= 0 are
            replaced by `DEFAULT_TOKEN_SIZE`.

    Returns:
        A single-element list holding `text` when it fits in one window,
        otherwise the ordered windows. Windows larger than 20 characters
        overlap their predecessor by 20 characters; smaller ones do not overlap.
    """
    if window_size <= 0:
        window_size = DEFAULT_TOKEN_SIZE
    total_len = len(text)
    if total_len <= window_size:
        return [text]

    padding_size = 20 if window_size > 20 else 0
    # Stepping by `window_size - padding_size` yields the sliding overlap, e.g.
    # [1, 2, 3, 4, 5, 6, 7, ...] with window_size=3, padding_size=1 gives
    #   [1, 2, 3], [3, 4, 5], [5, 6, 7], ...
    data_len = window_size - padding_size
    windows = []
    idx = 0
    while idx < total_len:
        if window_size + idx >= total_len:
            # Remaining text fits in one (possibly partial) window. Using `>=`
            # stops here when a full window ends exactly at the end of the
            # text, instead of appending a tail made only of the overlap.
            windows.append(text[idx:])
            break
        windows.append(text[idx : idx + window_size])
        idx += data_len
    return windows