refactor: brain memory

This commit is contained in:
莘权 马 2023-09-07 19:13:23 +08:00
parent 530d2f5b30
commit 4c873a9158
2 changed files with 16 additions and 18 deletions

View file

@ -178,13 +178,16 @@ class BrainMemory(pydantic.BaseModel):
self.is_dirty = True
return self.historical_summary
async def get_summary(self, text: str, llm, max_words=200, keep_language: bool = False, **kwargs):
async def summerize(self, llm, max_words=200, keep_language: bool = False, **kwargs):
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
text = self.history_text
text_length = len(text)
summary = ""
while max_count > 0:
if text_length < max_token_count:
return await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language)
summary = await self._get_summary(text=text, llm=llm, max_words=max_words, keep_language=keep_language)
break
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
@ -194,13 +197,18 @@ class BrainMemory(pydantic.BaseModel):
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
return summaries[0]
summary = summaries[0]
break
# Merged and retry
text = "\n".join(summaries)
text_length = len(text)
max_count -= 1 # safeguard
if not summary:
await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
return summary
raise openai.error.InvalidRequestError("text too long")
async def _get_summary(self, text: str, llm, max_words=20, keep_language: bool = False):

View file

@ -45,7 +45,7 @@ class Assistant(Role):
name=name, profile=profile, goal=goal, constraints=constraints, desc=desc, *args, **kwargs
)
brain_memory = CONFIG.BRAIN_MEMORY
self.memory = BrainMemory(**brain_memory) if brain_memory else BrainMemory()
self.memory = BrainMemory(**brain_memory) if brain_memory else BrainMemory(llm_type=CONFIG.LLM_TYPE)
skill_path = Path(CONFIG.SKILL_PATH) if CONFIG.SKILL_PATH else None
self.skills = SkillLoader(skill_yaml_file_name=skill_path)
@ -83,7 +83,7 @@ class Assistant(Role):
self.memory.add_talk(Message(content=text))
async def _plan(self, rsp: str, **kwargs) -> bool:
skill, text = Assistant.extract_info(input_string=rsp)
skill, text = BrainMemory.extract_info(input_string=rsp)
handlers = {
MessageType.Talk.value: self.talk_handler,
MessageType.Skill.value: self.skill_handler,
@ -121,24 +121,14 @@ class Assistant(Role):
return None
if history_text == "":
return last_talk
history_summary = await self.memory.get_summary(
text=history_text, max_words=800, keep_language=True, llm=self._llm
)
# await self.memory.set_history_summary(
# history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
# )
if last_talk and await self.memory.is_related(
text1=last_talk, text2=history_summary, llm=self._llm
): # Merge relevant content.
history_summary = await self.memory.summerize(max_words=800, keep_language=True, llm=self._llm)
if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self._llm):
# Merge relevant content.
last_talk = await self.memory.rewrite(sentence=last_talk, context=history_text, llm=self._llm)
return last_talk
return last_talk
@staticmethod
def extract_info(input_string):
return BrainMemory.extract_info(input_string)
def get_memory(self) -> str:
return self.memory.json()