From f2da313548b07f81ce8e9299b2d96bb067ba7e4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 7 Sep 2023 22:58:00 +0800 Subject: [PATCH] refactor: brain memory --- metagpt/actions/talk_action.py | 11 +++++++++++ metagpt/memory/brain_memory.py | 24 ++++++++++++++++++++++++ metagpt/provider/base_gpt_api.py | 19 +++++++++++++++---- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py index 0e3762798..baef47eeb 100644 --- a/metagpt/actions/talk_action.py +++ b/metagpt/actions/talk_action.py @@ -6,10 +6,12 @@ @File : talk_action.py @Desc : Act as it’s a talk """ +import json from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG from metagpt.const import DEFAULT_LANGUAGE +from metagpt.llm import LLMType from metagpt.logs import logger @@ -63,6 +65,15 @@ class TalkAction(Action): return prompt async def run(self, *args, **kwargs) -> ActionOutput: + if CONFIG.LLM_TYPE == LLMType.METAGPT.value: + rsp = await self.llm.aask( + msg=self._talk, + knowledge_msgs=[{"knowledge": self._knowledge}] if self._knowledge else None, + history_msgs=json.loads(self._history_summary) if self._history_summary else None, + ) + self._rsp = ActionOutput(content=rsp) + return self._rsp + prompt = self.prompt rsp = await self.llm.aask(msg=prompt, system_msgs=[]) logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n") diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index e0e2ae1a0..0f9c1dbb6 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -309,4 +309,28 @@ class BrainMemory(pydantic.BaseModel): def is_history_available(self): return bool(self.history or self.historical_summary) + @property + def history_text(self): + if self.llm_type == LLMType.METAGPT.value: + return self._get_metagpt_history_text() + return self._get_openai_history_text() + + def _get_metagpt_history_text(self): + return BrainMemory.to_metagpt_history_format(self.history) + + def _get_openai_history_text(self): + if len(self.history) == 0 and not self.historical_summary: + return "" + texts = [self.historical_summary] if self.historical_summary else [] + for m in self.history[:-1]: + if isinstance(m, Dict): + t = Message(**m).content + elif isinstance(m, Message): + t = m.content + else: + continue + texts.append(t) + + return "\n".join(texts) + DEFAULT_TOKEN_SIZE = 500 diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index 7351e6916..f405ae902 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -38,11 +38,22 @@ class BaseGPTAPI(BaseChatbot): rsp = self.completion(message) return self.get_choice_text(rsp) - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, generator: bool = False) -> str: + async def aask( + self, + msg: str, + system_msgs: Optional[list[str]] = None, + history_msgs: Optional[list[dict[str, str]]] = None, + knowledge_msgs: Optional[list[dict[str, str]]] = None, + generator: bool = False, + ) -> str: + message = [] if system_msgs: - message = self._system_msgs(system_msgs) + [self._user_msg(msg)] - else: - message = [self._default_system_msg(), self._user_msg(msg)] + message = self._system_msgs(system_msgs) + if knowledge_msgs: + message.extend(knowledge_msgs) + if history_msgs: + message.extend(history_msgs) + message.append(self._user_msg(msg)) try: rsp = await self.acompletion_text(message, stream=True, generator=generator) except Exception as e: