revert: 改用CONFIG

This commit is contained in:
莘权 马 2023-09-05 12:26:36 +08:00
parent 54120e7356
commit dec135ec83
4 changed files with 44 additions and 45 deletions

View file

@ -14,6 +14,7 @@ import faiss
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from metagpt.config import CONFIG
from metagpt.const import DATA_PATH
from metagpt.document_store.base_store import LocalStore
from metagpt.document_store.document import Document
@ -21,7 +22,7 @@ from metagpt.logs import logger
class FaissStore(LocalStore):
def __init__(self, raw_data: Path, cache_dir=None, meta_col='source', content_col='output'):
def __init__(self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output"):
self.meta_col = meta_col
self.content_col = content_col
super().__init__(raw_data, cache_dir)
@ -37,11 +38,12 @@ class FaissStore(LocalStore):
store.index = index
return store
def _write(self, docs, metadatas, **kwargs):
store = FAISS.from_texts(docs,
OpenAIEmbeddings(openai_api_version="2020-11-07",
openai_api_key=kwargs.get("OPENAI_API_KEY")),
metadatas=metadatas)
def _write(self, docs, metadatas):
store = FAISS.from_texts(
docs,
OpenAIEmbeddings(openai_api_version="2020-11-07", openai_api_key=CONFIG.OPENAI_API_KEY),
metadatas=metadatas,
)
return store
def persist(self):
@ -54,7 +56,7 @@ class FaissStore(LocalStore):
pickle.dump(store, f)
store.index = index
def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs):
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
rsp = self.store.similarity_search(query, k=k, **kwargs)
logger.debug(rsp)
if expand_cols:
@ -82,8 +84,8 @@ class FaissStore(LocalStore):
raise NotImplementedError
if __name__ == '__main__':
faiss_store = FaissStore(DATA_PATH / 'qcs/qcs_4w.json')
logger.info(faiss_store.search('油皮洗面奶'))
faiss_store.add([f'油皮洗面奶-{i}' for i in range(3)])
logger.info(faiss_store.search('油皮洗面奶'))
if __name__ == "__main__":
faiss_store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
logger.info(faiss_store.search("油皮洗面奶"))
faiss_store.add([f"油皮洗面奶-{i}" for i in range(3)])
logger.info(faiss_store.search("油皮洗面奶"))

View file

@ -37,13 +37,13 @@ class LongTermMemory(Memory):
self.add_batch(messages)
self.msg_from_recover = False
def add(self, message: Message, **kwargs):
def add(self, message: Message):
super(LongTermMemory, self).add(message)
for action in self.rc.watch:
if message.cause_by == action and not self.msg_from_recover:
# currently, only add role's watching messages to its memory_storage
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message, **kwargs)
self.memory_storage.add(message)
def remember(self, observed: list[Message], k=0) -> list[Message]:
"""

View file

@ -5,16 +5,16 @@
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from typing import List
from pathlib import Path
from typing import List
from langchain.vectorstores.faiss import FAISS
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.serialize import serialize_message, deserialize_message
from metagpt.document_store.faiss_store import FaissStore
from metagpt.utils.serialize import deserialize_message, serialize_message
class MemoryStorage(FaissStore):
@ -37,7 +37,7 @@ class MemoryStorage(FaissStore):
def recover_memory(self, role_id: str) -> List[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f'role_mem/{self.role_id}/')
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
self.store = self._load()
@ -54,23 +54,23 @@ class MemoryStorage(FaissStore):
def _get_index_and_store_fname(self):
if not self.role_mem_path:
logger.error(f'You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory')
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f'{self.role_id}.index')
storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl')
index_fpath = Path(self.role_mem_path / f"{self.role_id}.index")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}.pkl")
return index_fpath, storage_fpath
def persist(self):
super(MemoryStorage, self).persist()
logger.debug(f'Agent {self.role_id} persist memory into local')
logger.debug(f"Agent {self.role_id} persist memory into local")
def add(self, message: Message, **kwargs) -> bool:
""" add message into memory storage"""
def add(self, message: Message) -> bool:
"""add message into memory storage"""
docs = [message.content]
metadatas = [{"message_ser": serialize_message(message)}]
if not self.store:
# init Faiss
self.store = self._write(docs, metadatas, **kwargs)
self.store = self._write(docs, metadatas)
self._initialized = True
else:
self.store.add_texts(texts=docs, metadatas=metadatas)
@ -82,10 +82,7 @@ class MemoryStorage(FaissStore):
if not self.store:
return []
resp = self.store.similarity_search_with_score(
query=message.content,
k=k
)
resp = self.store.similarity_search_with_score(query=message.content, k=k)
# filter the result which score is smaller than the threshold
filtered_resp = []
for item, score in resp:

View file

@ -4,11 +4,11 @@
@Desc : unittest of `metagpt/memory/longterm_memory.py`
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from metagpt.config import Config
from metagpt.schema import Message
from metagpt.actions import BossRequirement
from metagpt.roles.role import RoleContext
from metagpt.config import Config
from metagpt.memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
def test_ltm_search():
@ -17,28 +17,28 @@ def test_ltm_search():
openai_api_key = conf.openai_api_key
assert len(openai_api_key) > 20
role_id = 'UTUserLtm(Product Manager)'
rc = RoleContext(options=conf.runtime_options, watch=[BossRequirement])
role_id = "UTUserLtm(Product Manager)"
rc = RoleContext(watch=[BossRequirement])
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
idea = 'Write a cli snake game'
message = Message(role='BOSS', content=idea, cause_by=BossRequirement)
idea = "Write a cli snake game"
message = Message(role="BOSS", content=idea, cause_by=BossRequirement)
news = ltm.remember([message])
assert len(news) == 1
ltm.add(message, **conf.runtime_options)
ltm.add(message)
sim_idea = 'Write a game of cli snake'
sim_message = Message(role='BOSS', content=sim_idea, cause_by=BossRequirement)
sim_idea = "Write a game of cli snake"
sim_message = Message(role="BOSS", content=sim_idea, cause_by=BossRequirement)
news = ltm.remember([sim_message])
assert len(news) == 0
ltm.add(sim_message, **conf.runtime_options)
ltm.add(sim_message)
new_idea = 'Write a 2048 web game'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a 2048 web game"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm.remember([new_message])
assert len(news) == 1
ltm.add(new_message, **conf.runtime_options)
ltm.add(new_message)
# restore from local index
ltm_new = LongTermMemory()
@ -50,8 +50,8 @@ def test_ltm_search():
news = ltm_new.remember([sim_message])
assert len(news) == 0
new_idea = 'Write a Battle City'
new_message = Message(role='BOSS', content=new_idea, cause_by=BossRequirement)
new_idea = "Write a Battle City"
new_message = Message(role="BOSS", content=new_idea, cause_by=BossRequirement)
news = ltm_new.remember([new_message])
assert len(news) == 1