diff --git a/config/config.yaml b/config/config.yaml index 7c3d212f6..5c8dea03e 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -83,4 +83,11 @@ MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k S3: access_key: "YOUR_S3_ACCESS_KEY" secret_key: "YOUR_S3_SECRET_KEY" - endpoint_url: "YOUR_S3_ENDPOINT_URL" \ No newline at end of file + endpoint_url: "YOUR_S3_ENDPOINT_URL" + +### Redis config +REDIS: + host: "YOUR_REDIS_HOST" + port: YOUR_REDIS_PORT, int + password: "YOUR_REDIS_PASSWORD" + db: "YOUR_REDIS_DB_INDEX, str, 0-based" \ No newline at end of file diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py index 4eed0d4f8..83504b62d 100644 --- a/metagpt/actions/talk_action.py +++ b/metagpt/actions/talk_action.py @@ -45,6 +45,20 @@ class TalkAction(Action): ) return prompt + @property + def formation_prompt(self): + kvs = { + "{role}": CONFIG.agent_description or "", + "{history}": self._history_summary or "", + "{knowledge}": self._knowledge or "", + "{language}": CONFIG.language or DEFAULT_LANGUAGE, + "{ask}": self._talk, + } + prompt = TalkAction.__FORMATION_LOOSE__ + for k, v in kvs.items(): + prompt = prompt.replace(k, v) + return prompt + async def run(self, *args, **kwargs) -> ActionOutput: prompt = self.prompt logger.info(prompt) @@ -52,3 +66,60 @@ class TalkAction(Action): logger.info(rsp) self._rsp = ActionOutput(content=rsp) return self._rsp + + __FORMATION__ = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "[ASK_BEGIN]" and [ASK_END] tags enclose the requirements for your to respond; + "Constraint" defines the conditions that your responses must comply with. + +Capacity and role: {role} +Statement: Your responses should align with the role-play agreement, maintaining the + character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing + your AI nature to preserve the character's image. + +[HISTORY_BEGIN] +{history} +[HISTORY_END] + +[KNOWLEDGE_BEGIN] +{knowledge} +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Answer the following questions in {language}, and the answers must follow the Markdown format + , excluding any tag likes "[HISTORY_BEGIN]", "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]", "[ASK_BEGIN]" + , "[ASK_END]" + +[ASK_BEGIN] +{ask} +[ASK_END]""" + + __FORMATION_LOOSE__ = """Formation: "Capacity and role" defines the role you are currently playing; + "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation; + "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses; + "Statement" defines the work detail you need to complete at this stage; + "[ASK_BEGIN]" and [ASK_END] tags enclose the requirements for your to respond; + "Constraint" defines the conditions that your responses must comply with. + +Capacity and role: {role} +Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions +, playfully decline to answer without revealing your AI nature to preserve the character's image. + +[HISTORY_BEGIN] +{history} +[HISTORY_END] + +[KNOWLEDGE_BEGIN] +{knowledge} +[KNOWLEDGE_END] + +Statement: If the information is insufficient, you can search in the historical conversation or knowledge. +Statement: Answer the following questions in {language}, and the answers must follow the Markdown format + , excluding any tag likes "[HISTORY_BEGIN]", "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]", "[ASK_BEGIN]" + , "[ASK_END]" + +[ASK_BEGIN] +{ask} +[ASK_END]""" diff --git a/metagpt/const.py b/metagpt/const.py index fbc2c928a..e9fa118d7 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -57,3 +57,6 @@ METAGPT_API_VERSION = "METAGPT_API_VERSION" # format BASE64_FORMAT = "base64" + +# REDIS +REDIS_KEY = "REDIS_KEY" diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index a5a3dbfc7..dedea3b41 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -5,14 +5,17 @@ @Author : mashenquan @File : brain_memory.py @Desc : Support memory for multiple tasks and multiple mainlines. +@Modified By: mashenquan, 2023/9/4. + redis memory cache. """ - +import json from enum import Enum from typing import Dict, List import pydantic from metagpt import Message +from metagpt.logs import logger +from metagpt.utils.redis import Redis class MessageType(Enum): @@ -28,14 +31,19 @@ class BrainMemory(pydantic.BaseModel): stack: List[Dict] = [] solution: List[Dict] = [] knowledge: List[Dict] = [] + historical_summary: str = "" + last_history_id: str = "" + is_dirty: bool = False def add_talk(self, msg: Message): msg.add_tag(MessageType.Talk.value) self.history.append(msg.dict()) + self.is_dirty = True def add_answer(self, msg: Message): msg.add_tag(MessageType.Answer.value) self.history.append(msg.dict()) + self.is_dirty = True def get_knowledge(self) -> str: texts = [Message(**m).content for m in self.knowledge] @@ -43,9 +51,9 @@ class BrainMemory(pydantic.BaseModel): @property def history_text(self): - if len(self.history) == 0: + if len(self.history) == 0 and not self.historical_summary: return "" - texts = [] + texts = [self.historical_summary] if self.historical_summary else [] for m in self.history[:-1]: if isinstance(m, Dict): t = Message(**m).content @@ -57,19 +65,6 @@ class BrainMemory(pydantic.BaseModel): return "\n".join(texts) - def move_to_solution(self, history_summary): - """放入solution队列,以备后续长程检索。目前还未加此功能,先用history_summary顶替""" - if len(self.history) < 2: - return - msgs = self.history[:-1] - self.solution.extend(msgs) - if not Message(**self.history[-1]).is_contain(MessageType.Talk.value): - self.solution.append(self.history[-1]) - self.history = [] - else: - self.history = self.history[-1:] - self.history.insert(0, Message(content="RESOLVED: " + history_summary)) - @property def last_talk(self): if len(self.history) == 0: @@ -78,3 +73,55 @@ class BrainMemory(pydantic.BaseModel): if not last_msg.is_contain(MessageType.Talk.value): return None return last_msg.content + + @staticmethod + async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory": + redis = Redis(conf=redis_conf) + if not redis.is_valid() or not redis_key: + return BrainMemory() + v = await redis.get(key=redis_key) + logger.info(f"REDIS GET {redis_key} {v}") + if v: + data = json.loads(v) + bm = BrainMemory(**data) + bm.is_dirty = False + return bm + return BrainMemory() + + async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None): + redis = Redis(conf=redis_conf) + if not redis.is_valid() or not redis_key: + return False + v = self.json() + await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec) + logger.info(f"REDIS SET {redis_key} {v}") + self.is_dirty = False + + @staticmethod + def to_redis_key(prefix: str, user_id: str, chat_id: str): + return f"{prefix}:{chat_id}:{user_id}" + + async def set_history_summary(self, history_summary, redis_key, redis_conf): + if self.historical_summary == history_summary: + if self.is_dirty: + await self.dumps(redis_key=redis_key, redis_conf=redis_conf) + self.is_dirty = False + return + + self.historical_summary = history_summary + self.history = [] + await self.dumps(redis_key=redis_key, redis_conf=redis_conf) + self.is_dirty = False + + def add_history(self, msg: Message): + if msg.id: + if int(msg.id) < int(self.last_history_id): + return + self.history.append(msg.dict()) + self.is_dirty = True + + def exists(self, text) -> bool: + for m in reversed(self.history): + if m.get("content") == text: + return True + return False diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py index 0bce4a3f9..9c80593f6 100644 --- a/metagpt/roles/assistant.py +++ b/metagpt/roles/assistant.py @@ -126,11 +126,13 @@ class Assistant(Role): if history_text == "": return last_talk history_summary = await self._llm.get_summary(history_text, max_words=500) + await self.memory.set_history_summary( + history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS + ) if last_talk and await self._llm.is_related(last_talk, history_summary): # Merge relevant content. last_talk = await self._llm.rewrite(sentence=last_talk, context=history_text) return last_talk - self.memory.move_to_solution(history_summary) # Promptly clear memory after the issue is resolved. return last_talk @staticmethod diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 2f0f713f8..b1ace19fa 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -97,8 +97,9 @@ class RoleContext(BaseModel): def prerequisite(self): """Retrieve information with `prerequisite` tag""" if self.memory and hasattr(self.memory, "get_by_tags"): - return self.memory.get_by_tags([MessageTag.Prerequisite.value]) - return "" + vv = self.memory.get_by_tags([MessageTag.Prerequisite.value]) + return vv[-1:] if len(vv) > 1 else vv + return [] class Role: diff --git a/metagpt/schema.py b/metagpt/schema.py index ce08455fc..8f8e4030f 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -10,7 +10,7 @@ from __future__ import annotations from dataclasses import dataclass, field from enum import Enum -from typing import Type, TypedDict, Set, Optional, List +from typing import Optional, Set, Type, TypedDict from pydantic import BaseModel @@ -29,13 +29,15 @@ class RawMessage(TypedDict): @dataclass class Message: """list[: ]""" + content: str instruct_content: BaseModel = field(default=None) - role: str = field(default='user') # system / user / assistant + role: str = field(default="user") # system / user / assistant cause_by: Type["Action"] = field(default="") sent_from: str = field(default="") send_to: str = field(default="") tags: Optional[Set] = field(default=None) + id: str = None def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -45,10 +47,7 @@ class Message: return self.__str__() def to_dict(self) -> dict: - return { - "role": self.role, - "content": self.content - } + return {"role": self.role, "content": self.content} def add_tag(self, tag): if self.tags is None: @@ -64,7 +63,7 @@ class Message: """Determine whether the message contains tags.""" if not tags or not self.tags: return False - intersection = set(tags) & self.tags + intersection = set(tags) & set(self.tags) return len(intersection) > 0 def is_contain(self, tag): @@ -76,7 +75,7 @@ class Message: "instruct_content": self.instruct_content, "sent_from": self.sent_from, "send_to": self.send_to, - "tags": self.tags + "tags": self.tags, } m = {"content": self.content} @@ -89,39 +88,39 @@ class Message: @dataclass class UserMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ def __init__(self, content: str): - super().__init__(content, 'user') + super().__init__(content, "user") @dataclass class SystemMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ def __init__(self, content: str): - super().__init__(content, 'system') + super().__init__(content, "system") @dataclass class AIMessage(Message): """便于支持OpenAI的消息 - Facilitate support for OpenAI messages + Facilitate support for OpenAI messages """ def __init__(self, content: str): - super().__init__(content, 'assistant') + super().__init__(content, "assistant") -if __name__ == '__main__': - test_content = 'test_message' +if __name__ == "__main__": + test_content = "test_message" msgs = [ UserMessage(test_content), SystemMessage(test_content), AIMessage(test_content), - Message(test_content, role='QA') + Message(test_content, role="QA"), ] logger.info(msgs) diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py new file mode 100644 index 000000000..b94eee8e2 --- /dev/null +++ b/metagpt/utils/redis.py @@ -0,0 +1,214 @@ +# !/usr/bin/python3 +# -*- coding: utf-8 -*- +# @Author: Hui +# @Desc: { redis client } +# @Date: 2022/11/28 10:12 +import json +from datetime import timedelta +from enum import Enum +from typing import Awaitable, Callable, Dict, Optional, Union + +from redis import asyncio as aioredis + +from metagpt.config import CONFIG +from metagpt.logs import logger + + +class RedisTypeEnum(Enum): + """Redis 数据类型""" + + String = "String" + List = "List" + Hash = "Hash" + Set = "Set" + ZSet = "ZSet" + + +def make_url( + dialect: str, + *, + user: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[Union[str, int]] = None, + name: Optional[Union[str, int]] = None, +) -> str: + url_parts = [f"{dialect}://"] + if user or password: + if user: + url_parts.append(user) + if password: + url_parts.append(f":{password}") + url_parts.append("@") + + if not host and not dialect.startswith("sqlite"): + host = "127.0.0.1" + + if host: + url_parts.append(f"{host}") + if port: + url_parts.append(f":{port}") + + # 比如redis可能传入0 + if name is not None: + url_parts.append(f"/{name}") + return "".join(url_parts) + + +class RedisAsyncClient(aioredis.Redis): + """异步的客户端 + 例子:: + + rdb = RedisAsyncClient() + print(rdb.url) + + Args: + host: 服务器地址 + port: 服务器端口 + user: 用户名 + db: 数据库 + password: 密码 + decode_responses: 字符串输入被编码成utf8存储在Redis里了,而取出来的时候还是被编码后的bytes,需要显示的decode才能变成字符串 + health_check_interval: 定时检测连接,防止出现ConnectionErrors (104, Connection reset by peer) + """ + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: str = None, + decode_responses=True, + health_check_interval=10, + socket_connect_timeout=5, + retry_on_timeout=True, + socket_keepalive=True, + **kwargs, + ): + super().__init__( + host=host, + port=port, + db=db, + password=password, + decode_responses=decode_responses, + health_check_interval=health_check_interval, + socket_connect_timeout=socket_connect_timeout, + retry_on_timeout=retry_on_timeout, + socket_keepalive=socket_keepalive, + **kwargs, + ) + self.url = make_url("redis", host=host, port=port, name=db, password=password) + + +class RedisCacheInfo(object): + """统一缓存信息类""" + + def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String): + """ + 缓存信息类初始化 + Args: + key: 缓存的key + timeout: 缓存过期时间, 单位秒 + data_type: 缓存采用的数据结构 (不传并不影响,用于标记业务采用的是什么数据结构) + """ + self.key = key + self.timeout = timeout + self.data_type = data_type + + def __str__(self): + return f"cache key {self.key} timeout {self.timeout}s" + + +class RedisManager: + client: RedisAsyncClient = None + + @classmethod + def init_redis_conn(cls, host, port, password, db): + """初始化redis 连接""" + if cls.client is None: + cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db) + + @classmethod + async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value): + """ + 根据 RedisCacheInfo 设置 Redis 缓存 + :param redis_cache_info: RedisCacheInfo缓存信息对象 + :param value: 缓存的值 + :return: + """ + await cls.client.setex(redis_cache_info.key, redis_cache_info.timeout, value) + + @classmethod + async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo): + """ + 根据 RedisCacheInfo 获取 Redis 缓存 + :param redis_cache_info: RedisCacheInfo 缓存信息对象 + :return: + """ + cache_info = await cls.client.get(redis_cache_info.key) + return cache_info + + @classmethod + async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo): + """ + 根据 RedisCacheInfo 删除 Redis 缓存 + :param redis_cache_info: RedisCacheInfo缓存信息对象 + :return: + """ + await cls.client.delete(redis_cache_info.key) + + @staticmethod + async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict: + """ + 获取缓存数据,如果缓存不存在,则从提供的函数中获取并设置缓存 + 当前版本仅支持 json 形式的 string 格式数据 + """ + + serialized_data = await RedisManager.get_with_cache_info(cache_info) + + if serialized_data: + return json.loads(serialized_data) + + data = await fetch_data_func() + try: + serialized_data = json.dumps(data) + await RedisManager.set_with_cache_info(cache_info, serialized_data) + except Exception as e: + logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}") + + return data + + @classmethod + def is_valid(cls): + return cls.client is not None + + +class Redis: + def __init__(self, conf: Dict = None): + self._config = conf or CONFIG.REDIS + if not self._config: + return + try: + host = self._config["host"] + port = int(self._config["port"]) + pwd = self._config["password"] + db = self._config["db"] + RedisManager.init_redis_conn(host=host, port=port, password=pwd, db=db) + except Exception as e: + logger.warning(f"Redis initialization has failed:{e}") + + def is_valid(self): + return RedisManager.is_valid() + + async def get(self, key: str) -> str: + if not self.is_valid() or not key: + return None + v = await RedisManager.get_with_cache_info(redis_cache_info=RedisCacheInfo(key=key)) + return v + + async def set(self, key: str, data: str, timeout_sec: int): + if not self.is_valid() or not key: + return + await RedisManager.set_with_cache_info( + redis_cache_info=RedisCacheInfo(key=key, timeout=timeout_sec), value=data + ) diff --git a/requirements.txt b/requirements.txt index 5daf710c7..588b29e0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,4 +41,5 @@ qdrant-client==1.4.0 connexion[swagger-ui] aiohttp_jinja2 azure-cognitiveservices-speech==1.31.0 -aioboto3~=11.3.0 \ No newline at end of file +aioboto3~=11.3.0 +redis==4.3.5 \ No newline at end of file