Merge pull request #30 from iorisa/feature/cache

feat: +memory cache
This commit is contained in:
send18 2023-09-04 20:17:13 +08:00 committed by GitHub
commit 543f6d4900
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 383 additions and 38 deletions

View file

@ -83,4 +83,11 @@ MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
S3:
access_key: "YOUR_S3_ACCESS_KEY"
secret_key: "YOUR_S3_SECRET_KEY"
endpoint_url: "YOUR_S3_ENDPOINT_URL"
endpoint_url: "YOUR_S3_ENDPOINT_URL"
### Redis config
REDIS:
host: "YOUR_REDIS_HOST"
port: YOUR_REDIS_PORT, int
password: "YOUR_REDIS_PASSWORD"
db: "YOUR_REDIS_DB_INDEX, str, 0-based"

View file

@ -45,6 +45,20 @@ class TalkAction(Action):
)
return prompt
@property
def formation_prompt(self):
    """Render the loose role-play prompt for the current talk.

    Each placeholder of ``TalkAction.__FORMATION_LOOSE__`` is filled in by
    plain string replacement. Role, history and knowledge fall back to an
    empty string, and the language falls back to ``DEFAULT_LANGUAGE``.
    ``{ask}`` is replaced with ``self._talk`` as-is (no fallback).
    """
    substitutions = {
        "{role}": CONFIG.agent_description or "",
        "{history}": self._history_summary or "",
        "{knowledge}": self._knowledge or "",
        "{language}": CONFIG.language or DEFAULT_LANGUAGE,
        "{ask}": self._talk,
    }
    rendered = TalkAction.__FORMATION_LOOSE__
    for placeholder, value in substitutions.items():
        rendered = rendered.replace(placeholder, value)
    return rendered
async def run(self, *args, **kwargs) -> ActionOutput:
prompt = self.prompt
logger.info(prompt)
@ -52,3 +66,60 @@ class TalkAction(Action):
logger.info(rsp)
self._rsp = ActionOutput(content=rsp)
return self._rsp
# Strict role-play prompt template. It carries the placeholders {role},
# {history}, {knowledge}, {language} and {ask}; its first "Statement" asks the
# model to align with the role-play agreement.
# NOTE(review): the code that fills these placeholders is not visible here;
# presumably it mirrors how __FORMATION_LOOSE__ is rendered -- confirm.
__FORMATION__ = """Formation: "Capacity and role" defines the role you are currently playing;
"[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
"[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
"Statement" defines the work detail you need to complete at this stage;
"[ASK_BEGIN]" and [ASK_END] tags enclose the requirements for your to respond;
"Constraint" defines the conditions that your responses must comply with.
Capacity and role: {role}
Statement: Your responses should align with the role-play agreement, maintaining the
character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing
your AI nature to preserve the character's image.
[HISTORY_BEGIN]
{history}
[HISTORY_END]
[KNOWLEDGE_BEGIN]
{knowledge}
[KNOWLEDGE_END]
Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
Statement: Answer the following questions in {language}, and the answers must follow the Markdown format
, excluding any tag likes "[HISTORY_BEGIN]", "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]", "[ASK_BEGIN]"
, "[ASK_END]"
[ASK_BEGIN]
{ask}
[ASK_END]"""
# Loose role-play prompt template rendered by `formation_prompt`, which
# replaces the placeholders {role}, {history}, {knowledge}, {language} and
# {ask}. Its first "Statement" only asks the model to keep the character's
# persona and habits (no mention of a role-play agreement).
__FORMATION_LOOSE__ = """Formation: "Capacity and role" defines the role you are currently playing;
"[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
"[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge may help for your responses;
"Statement" defines the work detail you need to complete at this stage;
"[ASK_BEGIN]" and [ASK_END] tags enclose the requirements for your to respond;
"Constraint" defines the conditions that your responses must comply with.
Capacity and role: {role}
Statement: Your responses should maintaining the character's persona and habits. When faced with unrelated questions
, playfully decline to answer without revealing your AI nature to preserve the character's image.
[HISTORY_BEGIN]
{history}
[HISTORY_END]
[KNOWLEDGE_BEGIN]
{knowledge}
[KNOWLEDGE_END]
Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
Statement: Answer the following questions in {language}, and the answers must follow the Markdown format
, excluding any tag likes "[HISTORY_BEGIN]", "[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]", "[ASK_BEGIN]"
, "[ASK_END]"
[ASK_BEGIN]
{ask}
[ASK_END]"""

View file

@ -57,3 +57,6 @@ METAGPT_API_VERSION = "METAGPT_API_VERSION"
# format
BASE64_FORMAT = "base64"
# REDIS
REDIS_KEY = "REDIS_KEY"

View file

@ -5,14 +5,17 @@
@Author : mashenquan
@File : brain_memory.py
@Desc : Support memory for multiple tasks and multiple mainlines.
@Modified By: mashenquan, 2023/9/4. + redis memory cache.
"""
import json
from enum import Enum
from typing import Dict, List
import pydantic
from metagpt import Message
from metagpt.logs import logger
from metagpt.utils.redis import Redis
class MessageType(Enum):
@ -28,14 +31,19 @@ class BrainMemory(pydantic.BaseModel):
stack: List[Dict] = []
solution: List[Dict] = []
knowledge: List[Dict] = []
historical_summary: str = ""
last_history_id: str = ""
is_dirty: bool = False
def add_talk(self, msg: Message):
    """Tag ``msg`` as a talk message, store its dict form and mark memory dirty."""
    msg.add_tag(MessageType.Talk.value)
    record = msg.dict()
    self.history.append(record)
    # Flag the in-memory state as changed since it was last marked clean.
    self.is_dirty = True
def add_answer(self, msg: Message):
    """Tag ``msg`` as an answer message, store its dict form and mark memory dirty."""
    msg.add_tag(MessageType.Answer.value)
    record = msg.dict()
    self.history.append(record)
    # Flag the in-memory state as changed since it was last marked clean.
    self.is_dirty = True
def get_knowledge(self) -> str:
texts = [Message(**m).content for m in self.knowledge]
@ -43,9 +51,9 @@ class BrainMemory(pydantic.BaseModel):
@property
def history_text(self):
if len(self.history) == 0:
if len(self.history) == 0 and not self.historical_summary:
return ""
texts = []
texts = [self.historical_summary] if self.historical_summary else []
for m in self.history[:-1]:
if isinstance(m, Dict):
t = Message(**m).content
@ -57,19 +65,6 @@ class BrainMemory(pydantic.BaseModel):
return "\n".join(texts)
def move_to_solution(self, history_summary):
"""放入solution队列以备后续长程检索。目前还未加此功能先用history_summary顶替"""
if len(self.history) < 2:
return
msgs = self.history[:-1]
self.solution.extend(msgs)
if not Message(**self.history[-1]).is_contain(MessageType.Talk.value):
self.solution.append(self.history[-1])
self.history = []
else:
self.history = self.history[-1:]
self.history.insert(0, Message(content="RESOLVED: " + history_summary))
@property
def last_talk(self):
if len(self.history) == 0:
@ -78,3 +73,55 @@ class BrainMemory(pydantic.BaseModel):
if not last_msg.is_contain(MessageType.Talk.value):
return None
return last_msg.content
@staticmethod
async def loads(redis_key: str, redis_conf: Dict = None) -> "BrainMemory":
    """Restore a ``BrainMemory`` from the JSON stored under ``redis_key``.

    Args:
        redis_key: Key to read; an empty key yields a fresh instance.
        redis_conf: Optional Redis configuration passed to ``Redis``.

    Returns:
        The deserialized memory with ``is_dirty`` reset to ``False``, or a
        new empty ``BrainMemory`` when Redis is not valid, the key is empty,
        or nothing is stored under the key.
    """
    redis = Redis(conf=redis_conf)
    if redis.is_valid() and redis_key:
        v = await redis.get(key=redis_key)
        logger.info(f"REDIS GET {redis_key} {v}")
        if v:
            restored = BrainMemory(**json.loads(v))
            # Freshly loaded state matches the cache, so it is not dirty.
            restored.is_dirty = False
            return restored
    return BrainMemory()
async def dumps(self, redis_key: str, timeout_sec: int = 30 * 60, redis_conf: Dict = None):
    """Serialize this memory to JSON and store it under ``redis_key``.

    Args:
        redis_key: Key to write; an empty key aborts the write.
        timeout_sec: Expiry passed to ``Redis.set`` (default 30 minutes).
        redis_conf: Optional Redis configuration passed to ``Redis``.

    Returns:
        ``True`` when the snapshot was written, ``False`` when Redis is not
        valid or ``redis_key`` is empty. (Previously the success path fell
        through and returned ``None``, which is indistinguishable from the
        failure value when tested for truthiness.)
    """
    redis = Redis(conf=redis_conf)
    if not redis.is_valid() or not redis_key:
        return False
    v = self.json()
    await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
    logger.info(f"REDIS SET {redis_key} {v}")
    # The cache now matches the in-memory state.
    self.is_dirty = False
    return True
@staticmethod
def to_redis_key(prefix: str, user_id: str, chat_id: str):
return f"{prefix}:{chat_id}:{user_id}"
async def set_history_summary(self, history_summary, redis_key, redis_conf):
    """Record a new history summary and persist the memory when needed.

    When the summary differs from the stored one, it replaces the stored
    summary and the raw ``history`` list is cleared. The memory is dumped
    when the summary changed or when there are unsaved changes; in every
    case ``is_dirty`` ends up ``False``.
    """
    summary_changed = not (self.historical_summary == history_summary)
    if summary_changed:
        self.historical_summary = history_summary
        self.history = []
    if summary_changed or self.is_dirty:
        await self.dumps(redis_key=redis_key, redis_conf=redis_conf)
    self.is_dirty = False
def add_history(self, msg: Message):
    """Append ``msg`` to the history unless it is older than the last known id.

    A message is skipped only when both ``msg.id`` and
    ``self.last_history_id`` are set and ``int(msg.id)`` is smaller than
    ``int(self.last_history_id)``. ``last_history_id`` defaults to an empty
    string, which ``int()`` cannot parse, so the comparison is only made once
    a last id is actually known.

    Raises:
        ValueError: If either id is set but is not a valid integer string.
    """
    # NOTE(review): this method does not update ``last_history_id``;
    # presumably the caller maintains it -- confirm.
    if msg.id and self.last_history_id:
        if int(msg.id) < int(self.last_history_id):
            return
    self.history.append(msg.dict())
    self.is_dirty = True
def exists(self, text) -> bool:
    """Return whether any history entry has ``content`` equal to ``text``.

    Entries are scanned from the most recent to the oldest.
    """
    return any(entry.get("content") == text for entry in reversed(self.history))

View file

@ -126,11 +126,13 @@ class Assistant(Role):
if history_text == "":
return last_talk
history_summary = await self._llm.get_summary(history_text, max_words=500)
await self.memory.set_history_summary(
history_summary=history_summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS
)
if last_talk and await self._llm.is_related(last_talk, history_summary): # Merge relevant content.
last_talk = await self._llm.rewrite(sentence=last_talk, context=history_text)
return last_talk
self.memory.move_to_solution(history_summary) # Promptly clear memory after the issue is resolved.
return last_talk
@staticmethod

View file

@ -97,8 +97,9 @@ class RoleContext(BaseModel):
def prerequisite(self):
"""Retrieve information with `prerequisite` tag"""
if self.memory and hasattr(self.memory, "get_by_tags"):
return self.memory.get_by_tags([MessageTag.Prerequisite.value])
return ""
vv = self.memory.get_by_tags([MessageTag.Prerequisite.value])
return vv[-1:] if len(vv) > 1 else vv
return []
class Role:

View file

@ -10,7 +10,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Type, TypedDict, Set, Optional, List
from typing import Optional, Set, Type, TypedDict
from pydantic import BaseModel
@ -29,13 +29,15 @@ class RawMessage(TypedDict):
@dataclass
class Message:
"""list[<role>: <content>]"""
content: str
instruct_content: BaseModel = field(default=None)
role: str = field(default='user') # system / user / assistant
role: str = field(default="user") # system / user / assistant
cause_by: Type["Action"] = field(default="")
sent_from: str = field(default="")
send_to: str = field(default="")
tags: Optional[Set] = field(default=None)
id: str = None
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
@ -45,10 +47,7 @@ class Message:
return self.__str__()
def to_dict(self) -> dict:
return {
"role": self.role,
"content": self.content
}
return {"role": self.role, "content": self.content}
def add_tag(self, tag):
if self.tags is None:
@ -64,7 +63,7 @@ class Message:
"""Determine whether the message contains tags."""
if not tags or not self.tags:
return False
intersection = set(tags) & self.tags
intersection = set(tags) & set(self.tags)
return len(intersection) > 0
def is_contain(self, tag):
@ -76,7 +75,7 @@ class Message:
"instruct_content": self.instruct_content,
"sent_from": self.sent_from,
"send_to": self.send_to,
"tags": self.tags
"tags": self.tags,
}
m = {"content": self.content}
@ -89,39 +88,39 @@ class Message:
@dataclass
class UserMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'user')
super().__init__(content, "user")
@dataclass
class SystemMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'system')
super().__init__(content, "system")
@dataclass
class AIMessage(Message):
"""便于支持OpenAI的消息
Facilitate support for OpenAI messages
Facilitate support for OpenAI messages
"""
def __init__(self, content: str):
super().__init__(content, 'assistant')
super().__init__(content, "assistant")
if __name__ == '__main__':
test_content = 'test_message'
if __name__ == "__main__":
test_content = "test_message"
msgs = [
UserMessage(test_content),
SystemMessage(test_content),
AIMessage(test_content),
Message(test_content, role='QA')
Message(test_content, role="QA"),
]
logger.info(msgs)

214
metagpt/utils/redis.py Normal file
View file

@ -0,0 +1,214 @@
# !/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author: Hui
# @Desc: { redis client }
# @Date: 2022/11/28 10:12
import json
from datetime import timedelta
from enum import Enum
from typing import Awaitable, Callable, Dict, Optional, Union
from redis import asyncio as aioredis
from metagpt.config import CONFIG
from metagpt.logs import logger
class RedisTypeEnum(Enum):
    """Redis data types."""

    # Each member's value is the same text as its name.
    String = "String"
    List = "List"
    Hash = "Hash"
    Set = "Set"
    ZSet = "ZSet"
def make_url(
    dialect: str,
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[Union[str, int]] = None,
    name: Optional[Union[str, int]] = None,
) -> str:
    """Assemble a connection URL of the form ``dialect://user:password@host:port/name``.

    Args:
        dialect: URL scheme, e.g. ``"redis"``.
        user: Optional user name.
        password: Optional password.
        host: Server address; defaults to ``127.0.0.1`` unless the dialect
            starts with ``"sqlite"``.
        port: Optional port, appended only when a host is present.
        name: Database name or index; appended whenever it is not ``None``,
            so a Redis index of ``0`` is kept.

    Returns:
        The assembled URL string. ``user`` and ``password`` are
        percent-encoded so characters such as ``@``, ``:`` or ``/`` inside
        the credentials cannot corrupt the URL structure.
    """
    # Function-scope import: the module does not import urllib at file level.
    from urllib.parse import quote

    url_parts = [f"{dialect}://"]
    if user or password:
        if user:
            url_parts.append(quote(str(user), safe=""))
        if password:
            encoded_password = quote(str(password), safe="")
            url_parts.append(f":{encoded_password}")
        url_parts.append("@")

    if not host and not dialect.startswith("sqlite"):
        host = "127.0.0.1"

    if host:
        url_parts.append(f"{host}")
        if port:
            url_parts.append(f":{port}")

    # e.g. redis may pass 0 as the database index
    if name is not None:
        url_parts.append(f"/{name}")
    return "".join(url_parts)
class RedisAsyncClient(aioredis.Redis):
    """Asynchronous Redis client.

    Example::

        rdb = RedisAsyncClient()
        print(rdb.url)

    Args:
        host: Server address.
        port: Server port.
        user: User name. NOTE(review): not an explicit parameter of this
            constructor; it could only arrive through ``**kwargs`` -- confirm.
        db: Database.
        password: Password.
        decode_responses: String input is encoded as UTF-8 when stored in
            Redis and comes back as encoded bytes; it needs an explicit
            decode to become a string again.
        health_check_interval: Periodically check the connection to prevent
            ConnectionErrors (104, Connection reset by peer).
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        password: str = None,
        decode_responses=True,
        health_check_interval=10,
        socket_connect_timeout=5,
        retry_on_timeout=True,
        socket_keepalive=True,
        **kwargs,
    ):
        super().__init__(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=decode_responses,
            health_check_interval=health_check_interval,
            socket_connect_timeout=socket_connect_timeout,
            retry_on_timeout=retry_on_timeout,
            socket_keepalive=socket_keepalive,
            **kwargs,
        )
        # NOTE(review): the URL embeds the password in plain text; avoid
        # printing or logging it outside of debugging.
        self.url = make_url("redis", host=host, port=port, name=db, password=password)
class RedisCacheInfo(object):
    """Unified description of a cache entry: key, expiry and data structure."""

    def __init__(self, key, timeout: Union[int, timedelta] = timedelta(seconds=60), data_type=RedisTypeEnum.String):
        """Initialize the cache description.

        Args:
            key: Cache key.
            timeout: Cache expiry, either a number of seconds or a
                ``timedelta``.
            data_type: Data structure used by the cache. Omitting it has no
                effect; it only marks which structure the business logic
                uses.
        """
        self.key = key
        self.timeout = timeout
        self.data_type = data_type

    def __str__(self):
        """Return ``cache key <key> timeout <seconds>s``.

        A ``timedelta`` timeout is converted to seconds first; formatting it
        directly would render e.g. ``0:01:00s``, which contradicts the
        trailing ``s`` unit.
        """
        timeout = self.timeout
        if isinstance(timeout, timedelta):
            seconds = timeout.total_seconds()
            # Keep whole-second values free of a trailing ".0".
            timeout = int(seconds) if seconds.is_integer() else seconds
        return f"cache key {self.key} timeout {timeout}s"
class RedisManager:
    """Class-level holder of one shared asynchronous Redis client."""

    client: RedisAsyncClient = None

    @classmethod
    def init_redis_conn(cls, host, port, password, db):
        """Create the shared Redis client; does nothing if one already exists."""
        if cls.client is not None:
            return
        cls.client = RedisAsyncClient(host=host, port=port, password=password, db=db)

    @classmethod
    async def set_with_cache_info(cls, redis_cache_info: RedisCacheInfo, value):
        """Store ``value`` under the key and expiry given by ``redis_cache_info``.

        :param redis_cache_info: cache description (key and timeout)
        :param value: value to cache
        :return: None
        """
        cache_key = redis_cache_info.key
        expiry = redis_cache_info.timeout
        await cls.client.setex(cache_key, expiry, value)

    @classmethod
    async def get_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
        """Fetch the value stored under ``redis_cache_info.key``.

        :param redis_cache_info: cache description (only the key is used)
        :return: whatever the client returns for that key
        """
        return await cls.client.get(redis_cache_info.key)

    @classmethod
    async def del_with_cache_info(cls, redis_cache_info: RedisCacheInfo):
        """Delete the entry stored under ``redis_cache_info.key``.

        :param redis_cache_info: cache description (only the key is used)
        :return: None
        """
        await cls.client.delete(redis_cache_info.key)

    @staticmethod
    async def get_or_set_cache(cache_info: RedisCacheInfo, fetch_data_func: Callable[[], Awaitable[dict]]) -> dict:
        """Return the cached data, fetching and caching it on a miss.

        On a miss the data comes from ``fetch_data_func`` and is cached as a
        JSON string. A failure while serializing or storing is only logged;
        the fetched data is still returned. Only JSON-style string data is
        supported in the current version.
        """
        cached = await RedisManager.get_with_cache_info(cache_info)
        if cached:
            return json.loads(cached)

        data = await fetch_data_func()
        try:
            await RedisManager.set_with_cache_info(cache_info, json.dumps(data))
        except Exception as e:
            logger.warning(f"数据 {data} 通过 json 进行序列化缓存失败:{e}")
        return data

    @classmethod
    def is_valid(cls):
        """Tell whether the shared client has been created."""
        return cls.client is not None
class Redis:
    """Small facade over ``RedisManager`` driven by a configuration mapping."""

    def __init__(self, conf: Dict = None):
        """Initialize the shared connection from ``conf`` or ``CONFIG.REDIS``.

        With no configuration available nothing is initialized. Any error
        while reading the settings or creating the connection is logged as a
        warning instead of being raised.
        """
        self._config = conf or CONFIG.REDIS
        if not self._config:
            return
        try:
            RedisManager.init_redis_conn(
                host=self._config["host"],
                port=int(self._config["port"]),
                password=self._config["password"],
                db=self._config["db"],
            )
        except Exception as e:
            logger.warning(f"Redis initialization has failed:{e}")

    def is_valid(self):
        """Report whether ``RedisManager`` holds a client."""
        return RedisManager.is_valid()

    async def get(self, key: str) -> str:
        """Return the value cached under ``key``; ``None`` if unusable or ``key`` is empty."""
        if self.is_valid() and key:
            cache_info = RedisCacheInfo(key=key)
            return await RedisManager.get_with_cache_info(redis_cache_info=cache_info)
        return None

    async def set(self, key: str, data: str, timeout_sec: int):
        """Cache ``data`` under ``key`` for ``timeout_sec``; no-op if unusable or ``key`` is empty."""
        if self.is_valid() and key:
            cache_info = RedisCacheInfo(key=key, timeout=timeout_sec)
            await RedisManager.set_with_cache_info(redis_cache_info=cache_info, value=data)

View file

@ -41,4 +41,5 @@ qdrant-client==1.4.0
connexion[swagger-ui]
aiohttp_jinja2
azure-cognitiveservices-speech==1.31.0
aioboto3~=11.3.0
aioboto3~=11.3.0
redis==4.3.5