import re


# The definitions below are methods of ``OpenAIGPTAPI(BaseGPTAPI, RateLimiter)``.
# The class header lies outside this excerpt, so they are laid out at the
# excerpt's own indentation.
async def is_related(self, text1, text2):
    """Ask the LLM whether two pieces of text are related.

    The prompt asks for a bare ``[TRUE]`` or ``[FALSE]`` marker, so the reply
    is scanned for that marker directly.  ``extract_info`` only recognises the
    ``[TAG]: text`` form, so a bare ``[TRUE]`` would never be reported as
    related if it were parsed that way.

    Args:
        text1: First piece of text.
        text2: Second piece of text.

    Returns:
        True only when the first marker found in the reply is ``[TRUE]``.
        False for ``[FALSE]``, for a reply without a marker, and for an empty
        reply.
    """
    command = f"{text1}\n{text2}\n\nIf the two sentences above are related, return [TRUE] brief and clear. Otherwise, return [FALSE]."
    rsp = await self.aask(msg=command, system_msgs=[])
    verdict = re.search(r"\[(TRUE|FALSE)\]", rsp or "")
    return verdict is not None and verdict.group(1) == "TRUE"


async def rewrite(self, sentence: str, context: str):
    """Ask the LLM to rewrite ``sentence`` briefly, given ``context``.

    Args:
        sentence: The sentence to rewrite.
        context: Text placed before the instruction in the prompt.

    Returns:
        The LLM reply, unmodified.
    """
    command = f"{context}\n\nConsidering the content above, rewrite and return this sentence brief and clear:\n{sentence}"
    return await self.aask(msg=command, system_msgs=[])


@staticmethod
def extract_info(input_string):
    """Split a reply of the form ``[TAG]: text`` into ``(TAG, text)``.

    ``TAG`` must be upper-case ASCII letters.  Leading whitespace before the
    tag is ignored.  ``text`` stops at the first line break because ``.`` does
    not match newlines.

    Args:
        input_string: The reply to parse.

    Returns:
        ``(TAG, text)`` when the reply starts with a tag, otherwise
        ``(None, input_string)``.  Non-string input is returned unparsed
        instead of raising ``TypeError``.
    """
    if not isinstance(input_string, str):
        return None, input_string
    match = re.match(r"\s*\[([A-Z]+)\]:\s*(.+)", input_string)
    if match:
        return match.group(1), match.group(2)
    return None, input_string
# The definitions below are methods of ``Assistant(Role)``.  The class header
# lies outside this excerpt, so they are laid out at the excerpt's own
# indentation.
async def think(self) -> bool:
    """Everything will be done part by part.

    Builds a classification prompt around the refined last utterance, sends it
    to the LLM and hands the reply to ``_plan``.

    Returns:
        Whatever ``_plan`` returns for the LLM reply.
    """
    last_talk = await self.refine_memory()
    parts = [f"Refer to this sentence:\n {last_talk}\n"]
    # One instruction per registered skill, so the LLM can answer with its name.
    for desc, name in self.skills.get_skill_list().items():
        parts.append(
            f"If want you to do {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: text_to_image\n"
        )
    parts.append(
        "If the preceding text presents a complete question and solution, rewrite and return `[SOLUTION]: {problem}` brief and clear. For instance: [SOLUTION]: Solution for distributing watermelon\n"
    )
    parts.append(
        "If the preceding text presents an unresolved issue and its corresponding discussion, rewrite and return `[PROBLEM]: {problem}` brief and clear. For instance: [PROBLEM]: How to distribute watermelon?\n"
    )
    parts.append(
        "Otherwise, rewrite and return `[TALK]: {talk}` brief and clear. For instance: [TALK]: distribute watermelon"
    )
    prompt = "".join(parts)
    logger.info(prompt)
    rsp = await self._llm.aask(prompt, [])
    logger.info(rsp)
    return await self._plan(rsp)


async def talk(self, text):
    """Record ``text`` in memory as a message tagged with the talk type."""
    self.memory.add_talk(Message(content=text, tags={MessageType.Talk.value}))


async def _plan(self, rsp: str, **kwargs) -> bool:
    """Dispatch an LLM reply to the handler selected by its ``[TAG]``.

    Args:
        rsp: Raw LLM reply, expected to look like ``[TAG]: text``.
        **kwargs: Forwarded unchanged to the selected handler.

    Returns:
        The selected handler's result.
    """
    skill, text = Assistant.extract_info(input_string=rsp)
    handlers = {
        MessageType.Talk.value: self.talk_handler,
        MessageType.Problem.value: self.talk_handler,
        MessageType.Skill.value: self.skill_handler,
    }
    # Unknown or missing tags (including [SOLUTION]) fall back to plain talk.
    handler = handlers.get(skill, self.talk_handler)
    return await handler(text, **kwargs)


async def talk_handler(self, text, **kwargs) -> bool:
    """Queue a ``TalkAction`` for ``text`` and report that work was planned."""
    action = TalkAction(options=self.options, talk=text, llm=self._llm, **kwargs)
    self.add_to_do(action)
    return True


async def skill_handler(self, text, **kwargs) -> bool:
    """Handle a ``[SKILL]`` reply.

    Skill dispatch is not implemented yet, so nothing is queued.

    Returns:
        Always False.
    """
    # TODO: queue the action for the skill named by ``text`` (presumably the
    # skill name returned by the LLM -- confirm against the skill registry).
    return False


async def refine_memory(self) -> str:
    """Return the last utterance, rewritten with history when they are related.

    With no history the last utterance is returned untouched.  Otherwise the
    history is condensed to a short title and compared with the last
    utterance.  When related, the utterance is rewritten using the full
    history as context.  When unrelated, the history is moved to the solution
    store and the utterance is returned as-is.
    """
    history_text = self.memory.history_text
    last_talk = self.memory.last_talk
    if not history_text:
        return last_talk

    history_summary = await self._llm.get_context_title(history_text, max_words=20)
    if await self._llm.is_related(last_talk, history_summary):
        # Merge the related history into the sentence.
        return await self._llm.rewrite(sentence=last_talk, context=history_text)

    # The earlier topic is finished: clear it from memory promptly.
    self.memory.move_to_solution()
    return last_talk


@staticmethod
def extract_info(input_string):
    """Parse ``[TAG]: text`` by delegating to ``OpenAIGPTAPI.extract_info``."""
    # Imported lazily, presumably to avoid a circular import -- confirm.
    from metagpt.provider.openai_api import OpenAIGPTAPI

    return OpenAIGPTAPI.extract_info(input_string)