revert: 改用CONFIG

This commit is contained in:
莘权 马 2023-09-05 12:26:36 +08:00
parent 54120e7356
commit dec135ec83
4 changed files with 44 additions and 45 deletions

View file

@ -14,6 +14,7 @@ import faiss
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from metagpt.config import CONFIG
from metagpt.const import DATA_PATH
from metagpt.document_store.base_store import LocalStore
from metagpt.document_store.document import Document
@ -21,7 +22,7 @@ from metagpt.logs import logger
class FaissStore(LocalStore):
def __init__(self, raw_data: Path, cache_dir=None, meta_col='source', content_col='output'):
def __init__(self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output"):
self.meta_col = meta_col
self.content_col = content_col
super().__init__(raw_data, cache_dir)
@ -37,11 +38,12 @@ class FaissStore(LocalStore):
store.index = index
return store
def _write(self, docs, metadatas, **kwargs):
store = FAISS.from_texts(docs,
OpenAIEmbeddings(openai_api_version="2020-11-07",
openai_api_key=kwargs.get("OPENAI_API_KEY")),
metadatas=metadatas)
def _write(self, docs, metadatas):
store = FAISS.from_texts(
docs,
OpenAIEmbeddings(openai_api_version="2020-11-07", openai_api_key=CONFIG.OPENAI_API_KEY),
metadatas=metadatas,
)
return store
def persist(self):
@ -54,7 +56,7 @@ class FaissStore(LocalStore):
pickle.dump(store, f)
store.index = index
def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs):
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
rsp = self.store.similarity_search(query, k=k, **kwargs)
logger.debug(rsp)
if expand_cols:
@ -82,8 +84,8 @@ class FaissStore(LocalStore):
raise NotImplementedError
if __name__ == '__main__':
faiss_store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
logger.info(faiss_store.search('油皮洗面奶'))
faiss_store.add([f'油皮洗面奶-{i}' for i in range(3)])
logger.info(faiss_store.search('油皮洗面奶'))
if __name__ == "__main__":
faiss_store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
logger.info(faiss_store.search("油皮洗面奶"))
faiss_store.add([f"油皮洗面奶-{i}" for i in range(3)])
logger.info(faiss_store.search("油皮洗面奶"))

View file

@ -37,13 +37,13 @@ class LongTermMemory(Memory):
self.add_batch(messages)
self.msg_from_recover = False
def add(self, message: Message, **kwargs):
def add(self, message: Message):
super(LongTermMemory, self).add(message)
for action in self.rc.watch:
if message.cause_by == action and not self.msg_from_recover:
# currently, only add role's watching messages to its memory_storage
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message, **kwargs)
self.memory_storage.add(message)
def remember(self, observed: list[Message], k=0) -> list[Message]:
"""

View file

@ -5,16 +5,16 @@
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from typing import List
from pathlib import Path
from typing import List
from langchain.vectorstores.faiss import FAISS
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.serialize import serialize_message, deserialize_message
from metagpt.document_store.faiss_store import FaissStore
from metagpt.utils.serialize import deserialize_message, serialize_message
class MemoryStorage(FaissStore):
@ -37,7 +37,7 @@ class MemoryStorage(FaissStore):
def recover_memory(self, role_id: str) -> List[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f'role_mem/{self.role_id}/')
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
self.store = self._load()
@ -54,23 +54,23 @@ class MemoryStorage(FaissStore):
def _get_index_and_store_fname(self):
if not self.role_mem_path:
logger.error(f'You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory')
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f'{self.role_id}.index')
storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl')
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
return index_fpath, storage_fpath
def persist(self):
super(MemoryStorage, self).persist()
logger.debug(f'Agent {self.role_id} persist memory into local')
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message, **kwargs) -> bool:
""" add message into memory storage"""
def add(self, message: Message) -> bool:
"""add message into memory storage"""
docs = [message.content]
metadatas = [{"message_ser": serialize_message(message)}]
if not self.store:
# init Faiss
self.store = self._write(docs, metadatas, **kwargs)
self.store = self._write(docs, metadatas)
self._initialized = True
else:
self.store.add_texts(texts=docs, metadatas=metadatas)
@ -82,10 +82,7 @@ class MemoryStorage(FaissStore):
if not self.store:
return []
resp = self.store.similarity_search_with_score(
query=message.content,
k=k
)
resp = self.store.similarity_search_with_score(query=message.content, k=k)
# filter the result which score is smaller than the threshold
filtered_resp = []
for item, score in resp: