From 4c82298e8864f9e8f3712aa9bb6333079a015749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 7 Sep 2023 18:21:10 +0800 Subject: [PATCH] feat: truncated history --- metagpt/memory/brain_memory.py | 62 ++++++++++++++++++++++++----- metagpt/provider/metagpt_llm_api.py | 41 ++----------------- metagpt/roles/assistant.py | 2 +- 3 files changed, 56 insertions(+), 49 deletions(-) diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 04ae6593a..e8a98c55b 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -15,6 +15,7 @@ import pydantic from metagpt import Message from metagpt.logs import logger +from metagpt.schema import RawMessage from metagpt.utils.redis import Redis @@ -54,17 +55,21 @@ class BrainMemory(pydantic.BaseModel): def history_text(self): if len(self.history) == 0 and not self.historical_summary: return "" - texts = [self.historical_summary] if self.historical_summary else [] - for m in self.history[:-1]: - if isinstance(m, Dict): - t = Message(**m).content - elif isinstance(m, Message): - t = m.content - else: - continue - texts.append(t) + try: + self.loads_raw_messages() + return self.dumps_raw_messages() + except: + texts = [self.historical_summary] if self.historical_summary else [] + for m in self.history[:-1]: + if isinstance(m, Dict): + t = Message(**m).content + elif isinstance(m, Message): + t = m.content + else: + continue + texts.append(t) - return "\n".join(texts) + return "\n".join(texts) @staticmethod async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory": @@ -130,3 +135,40 @@ class BrainMemory(pydantic.BaseModel): v = self.last_talk self.last_talk = None return v + + def loads_raw_messages(self): + if not self.historical_summary: + return + vv = json.loads(self.historical_summary) + msgs = [] + for v in vv: + tag = set([MessageType.Talk.value]) if v.get("role") == "user" else set([MessageType.Answer.value]) + m = Message(content=v.get("content"), tags=tag) + msgs.append(m) + msgs.extend(self.history) + self.history = msgs + self.is_dirty = True + + def dumps_raw_messages(self, max_length: int = 0) -> str: + summary = [] + + total_length = 0 + for m in reversed(self.history): + msg = Message(**m) + c = RawMessage(role="user" if MessageType.Talk.value in msg.tags else "assistant", content=msg.content) + length_delta = len(msg.content) + if max_length > 0: + if total_length + length_delta > max_length: + left = max_length - total_length + if left > 0: + c.content = msg.content[0:left] + summary.insert(0, c) + break + + total_length += length_delta + summary.insert(0, c) + + self.historical_summary = json.dumps(summary) + self.history = [] + self.is_dirty = True + return self.historical_summary diff --git a/metagpt/provider/metagpt_llm_api.py b/metagpt/provider/metagpt_llm_api.py index d8d06aeaa..3ae65a623 100644 --- a/metagpt/provider/metagpt_llm_api.py +++ b/metagpt/provider/metagpt_llm_api.py @@ -5,35 +5,18 @@ @File : metagpt_llm_api.py @Desc : MetaGPT LLM related APIs """ -import json -from typing import Dict, List -from pydantic import BaseModel - -from metagpt.memory.brain_memory import MessageType +from metagpt.memory.brain_memory import BrainMemory from metagpt.provider import OpenAIGPTAPI -class HisMsg(BaseModel): - content: str - tags: set - id: str - - -class Conversion(BaseModel): - """See: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - - role: str - content: str - - class MetaGPTLLMAPI(OpenAIGPTAPI): """MetaGPT LLM api""" def __init__(self): super().__init__() - async def get_summary(self, history: List[Dict], max_words=200, keep_language: bool = False, **kwargs) -> str: + async def get_summary(self, memory: BrainMemory, max_words=200, keep_language: bool = False, **kwargs) -> str: """ Return string in the following format: [ @@ -43,22 +26,4 @@ class MetaGPTLLMAPI(OpenAIGPTAPI): {"role": "user", "content": "Orange."}, ] """ - summary = [] - - total_length = 0 - for m in reversed(history): - msg = HisMsg(**m) - c = Conversion(role="user" if MessageType.Talk.value in msg.tags else "assistant", content=msg.content) - length_delta = len(msg.content) - if total_length + length_delta > max_words: - left = max_words - total_length - if left > 0: - c.content = msg.content[0:left] - summary.insert(0, c.dict()) - break - - total_length += length_delta - summary.insert(0, c.dict()) - - data = json.dumps(summary) - return data + return memory.dumps_raw_messages(max_length=max_words) diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py index 2f9059210..2fcb6f584 100644 --- a/metagpt/roles/assistant.py +++ b/metagpt/roles/assistant.py @@ -122,7 +122,7 @@ class Assistant(Role): if history_text == "": return last_talk history_summary = await self._llm.get_summary( - text=history_text, max_words=800, keep_language=True, history=self.memory.history + text=history_text, max_words=800, keep_language=True, memory=self.memory ) await self.memory.set_history_summary( history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS