diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index a5a3dbfc7..275cd14df 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -6,7 +6,7 @@ @File : brain_memory.py @Desc : Support memory for multiple tasks and multiple mainlines. """ - +import hashlib from enum import Enum from typing import Dict, List @@ -28,6 +28,10 @@ class BrainMemory(pydantic.BaseModel): stack: List[Dict] = [] solution: List[Dict] = [] knowledge: List[Dict] = [] + # If the fingerprint of the history text is found in the `historical_summary_fingerprint`, + # it indicates that the text has already been incorporated into the `history summary`. + historical_summary_fingerprint: List[str] = [] + historical_summary: str = "" def add_talk(self, msg: Message): msg.add_tag(MessageType.Talk.value) @@ -58,17 +62,19 @@ class BrainMemory(pydantic.BaseModel): return "\n".join(texts) def move_to_solution(self, history_summary): - """放入solution队列,以备后续长程检索。目前还未加此功能,先用history_summary顶替""" - if len(self.history) < 2: - return - msgs = self.history[:-1] - self.solution.extend(msgs) - if not Message(**self.history[-1]).is_contain(MessageType.Talk.value): - self.solution.append(self.history[-1]) - self.history = [] - else: - self.history = self.history[-1:] - self.history.insert(0, Message(content="RESOLVED: " + history_summary)) + """Put it in the solution queue for future long-term retrieval. + This functionality hasn't been added yet, so use the history summary as a temporary substitute for now.""" + pass + # if len(self.history) < 2: + # return + # msgs = self.history[:-1] + # self.solution.extend(msgs) + # if not Message(**self.history[-1]).is_contain(MessageType.Talk.value): + # self.solution.append(self.history[-1]) + # self.history = [] + # else: + # self.history = self.history[-1:] + # self.history.insert(0, Message(content="RESOLVED: " + history_summary)) @property def last_talk(self): @@ -78,3 +84,7 @@ class BrainMemory(pydantic.BaseModel): if not last_msg.is_contain(MessageType.Talk.value): return None return last_msg.content + + @staticmethod + def get_md5(text: str) -> str: + return hashlib.md5(text.encode()).hexdigest() diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py index e69de29bb..f2ae3222a 100644 --- a/metagpt/utils/redis.py +++ b/metagpt/utils/redis.py @@ -0,0 +1,198 @@ +# !/usr/bin/python3 +# -*- coding: utf-8 -*- +# @Author: Hui +# @Desc: { redis client } +# @Date: 2022/11/28 10:12 +import json +from datetime import timedelta +from enum import Enum +from typing import Awaitable, Callable, Optional, Union + +from redis import asyncio as aioredis + +from metagpt.config import CONFIG +from metagpt.logs import logger + + +class RedisTypeEnum(Enum): + """Redis 数据类型""" + + String = "String" + List = "List" + Hash = "Hash" + Set = "Set" + ZSet = "ZSet" + + +def make_url( + dialect: str, + *, + user: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[Union[str, int]] = None, + name: Optional[Union[str, int]] = None, +) -> str: + url_parts = [f"{dialect}://"] + if user or password: + if user: + url_parts.append(user) + if password: + url_parts.append(f":{password}") + url_parts.append("@") + + if not host and not dialect.startswith("sqlite"): + host = "127.0.0.1" + + if host: + url_parts.append(f"{host}") + if port: + url_parts.append(f":{port}") + + # 比如redis可能传入0 + if name is not None: + url_parts.append(f"/{name}") + return "".join(url_parts) + + +class RedisAsyncClient(aioredis.Redis): + """异步的客户端 + 例子:: + + rdb = RedisAsyncClient() + print(rdb.url) + + Args: + host: 服务器地址 + port: 服务器端口 + user: 用户名 + db: 数据库 + password: 密码 + decode_responses: 字符串输入被编码成utf8存储在Redis里了,而取出来的时候还是被编码后的bytes,需要显示的decode才能变成字符串 + health_check_interval: 定时检测连接,防止出现ConnectionErrors (104, Connection reset by peer) + """ + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: int = 0, + password: str = None, + decode_responses=True, + health_check_interval=10, + socket_connect_timeout=5, + retry_on_timeout=True, + socket_keepalive=True, + **kwargs, + ): + super().__init__( + host=host, + port=port, + db=db, + password=password, + decode_responses=decode_responses, + health_check_interval=health_check_interval, + socket_connect_timeout=socket_connect_timeout, + retry_on_timeout=retry_on_timeout, + socket_keepalive=socket_keepalive, + **kwargs, + ) + self.url = make_url("redis", host=host, port=port, name=db, password=password) + + +class RedisCacheInfo(object): + """统一缓存信息类""" + + def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String): + """ + 缓存信息类初始化 + Args: + key: 缓存的key + timeout: 缓存过期时间, 单位秒 + data_type: 缓存采用的数据结构 (不传并不影响,用于标记业务采用的是什么数据结构) + """ + self.key = key + self.timeout = timeout + self.data_type = data_type + + def __str__(self): + return f"cache key {self.key} timeout {self.timeout}s" + + +class RedisManager: + client: RedisAsyncClient = None + + @classmethod + def init_redis_conn(cls, host, port, password, db): + """初始化redis 连接""" + if cls.client is None: + cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db) + + @classmethod + async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value): + """ + 根据 RedisCacheInfo 设置 Redis 缓存 + :param redis_cache_info: RedisCacheInfo缓存信息对象 + :param value: 缓存的值 + :return: + """ + await cls.client.setex(redis_cache_info.key, redis_cache_info.timeout, value) + + @classmethod + async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo): + """ + 根据 RedisCacheInfo 获取 Redis 缓存 + :param redis_cache_info: RedisCacheInfo 缓存信息对象 + :return: + """ + cache_info = await cls.client.get(redis_cache_info.key) + return cache_info + + @classmethod + async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo): + """ + 根据 RedisCacheInfo 删除 Redis 缓存 + :param redis_cache_info: RedisCacheInfo缓存信息对象 + :return: + """ + await cls.client.delete(redis_cache_info.key) + + @staticmethod + async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict: + """ + 获取缓存数据,如果缓存不存在,则从提供的函数中获取并设置缓存 + 当前版本仅支持 json 形式的 string 格式数据 + """ + + serialized_data = await RedisManager.get_with_cache_info(cache_info) + + if serialized_data: + return json.loads(serialized_data) + + data = await fetch_data_func() + try: + serialized_data = json.dumps(data) + await RedisManager.set_with_cache_info(cache_info, serialized_data) + except Exception as e: + logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}") + + return data + + @classmethod + def is_valid(cls): + return cls.client is not None + + +class Redis: + def __init__(self): + self._config = CONFIG.REDIS + if not self._config: + return + try: + host = self._config["host"] + port = int(self._config["port"]) + pwd = self._config["password"] + db = int(self._config["db"]) + RedisManager.init_redis_conn(host=host, port=port, password=pwd, db=db) + except Exception as e: + logger.warning(f"Redis initialization has failed:{e}") diff --git a/requirements.txt b/requirements.txt index 5daf710c7..588b29e0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,4 +41,5 @@ qdrant-client==1.4.0 connexion[swagger-ui] aiohttp_jinja2 azure-cognitiveservices-speech==1.31.0 -aioboto3~=11.3.0 \ No newline at end of file +aioboto3~=11.3.0 +redis==4.3.5 \ No newline at end of file