From 5caaea2aeb05edcc35a1c4727910c01644c0a5dc Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 10 Apr 2024 19:35:46 +0800 Subject: [PATCH] replace werewolf experience chromadb with rag --- .../werewolf/actions/experience_operation.py | 138 ++++++------------ .../ext/werewolf/actions/moderator_actions.py | 5 +- metagpt/ext/werewolf/schema.py | 4 + metagpt/utils/common.py | 4 +- 4 files changed, 57 insertions(+), 94 deletions(-) diff --git a/metagpt/ext/werewolf/actions/experience_operation.py b/metagpt/ext/werewolf/actions/experience_operation.py index 73e360470..a32549a2a 100644 --- a/metagpt/ext/werewolf/actions/experience_operation.py +++ b/metagpt/ext/werewolf/actions/experience_operation.py @@ -1,18 +1,18 @@ -import glob import json -import os from typing import Optional -import chromadb -from chromadb import Collection from chromadb.utils import embedding_functions from pydantic import model_validator from metagpt.actions import Action from metagpt.config2 import config from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.environment.werewolf.const import RoleType from metagpt.ext.werewolf.schema import RoleExperience from metagpt.logs import logger +from metagpt.rag.engines.simple import SimpleEngine +from metagpt.rag.schema import ChromaIndexConfig, ChromaRetrieverConfig +from metagpt.utils.common import read_json_file, write_json_file DEFAULT_COLLECTION_NAME = "role_reflection" # FIXME: some hard code for now EMB_FN = embedding_functions.OpenAIEmbeddingFunction( @@ -21,72 +21,62 @@ EMB_FN = embedding_functions.OpenAIEmbeddingFunction( api_type=config.llm.api_type, model_name="text-embedding-ada-002", ) +PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("/werewolf_game/chroma") +PERSIST_PATH.mkdir(parents=True, exist_ok=True) class AddNewExperiences(Action): name: str = "AddNewExperience" collection_name: str = DEFAULT_COLLECTION_NAME delete_existing: bool = False - collection: Optional[Collection] = None + engine: Optional[SimpleEngine] = None @model_validator(mode="after") def validate_collection(self): - if self.collection: + if self.engine: return - - chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma") + self.engine = SimpleEngine.from_objs( + retriever_configs=[ + ChromaRetrieverConfig( + persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"} + ) + ] + ) if self.delete_existing: try: - chroma_client.get_collection(name=self.collection_name) - chroma_client.delete_collection(name=self.collection_name) - logger.info(f"existing collection `{self.collection_name}` deleted") - except: - pass - - self.collection = chroma_client.get_or_create_collection( - name=self.collection_name, - metadata={"hnsw:space": "cosine"}, - embedding_function=EMB_FN, - ) + # implement engine `DELETE` method later + self.engine.retriever._index._vector_store._collection.delete_collection(name=self.collection_name) + except Exception as exp: + logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}") def run(self, experiences: list[RoleExperience]): if not experiences: return for i, exp in enumerate(experiences): exp.id = f"{exp.profile}-{exp.name}-step{i}-round_{exp.round_id}" - ids = [exp.id for exp in experiences] - documents = [exp.reflection for exp in experiences] - metadatas = [exp.model_dump() for exp in experiences] AddNewExperiences._record_experiences_local(experiences) - self.collection.add(documents=documents, metadatas=metadatas, ids=ids) + self.engine.add_objs(experiences) def add_from_file(self, file_path): - with open(file_path, "r") as fl: - lines = fl.readlines() - experiences = [RoleExperience.model_validate_json(line) for line in lines] + experiences = read_json_file(file_path) + experiences = [RoleExperience.model_validate(item) for item in experiences] experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""' - ids = [exp.id for exp in experiences] - documents = [exp.reflection for exp in experiences] - metadatas = [exp.model_dump() for exp in experiences] - - self.collection.add(documents=documents, metadatas=metadatas, ids=ids) + self.engine.add(experiences) @staticmethod def _record_experiences_local(experiences: list[RoleExperience]): round_id = experiences[0].round_id version = experiences[0].version version = "test" if not version else version - experiences = [exp.model_dump_json() for exp in experiences] - experience_folder = DEFAULT_WORKSPACE_ROOT / f"werewolf_game/experiences/{version}" - if not os.path.exists(experience_folder): - os.makedirs(experience_folder) - save_path = f"{experience_folder}/{round_id}.json" - with open(save_path, "a") as fl: - fl.write("\n".join(experiences)) - fl.write("\n") + experiences = [exp.model_dump() for exp in experiences] + + experience_path = DEFAULT_WORKSPACE_ROOT.joinpath(f"werewolf_game/experiences/{version}") + experience_path.mkdir(parents=True, exist_ok=True) + save_path = f"{experience_path}/{round_id}.json" + write_json_file(save_path, experiences) logger.info(f"experiences saved to {save_path}") @@ -94,56 +84,49 @@ class RetrieveExperiences(Action): name: str = "RetrieveExperiences" collection_name: str = DEFAULT_COLLECTION_NAME has_experiences: bool = True - collection: Optional[Collection] = None + engine: Optional[SimpleEngine] = None + topk: int = 5 @model_validator(mode="after") def validate_collection(self): - if self.collection: + if self.engine: return - chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma") try: - self.collection = chroma_client.get_collection( - name=self.collection_name, - embedding_function=EMB_FN, + self.engine.from_index( + index_config=ChromaIndexConfig( + persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"} + ), + retriever_configs=ChromaRetrieverConfig(similarity_top_k=self.topk), ) - except: - logger.warning(f"No experience pool {self.collection_name}") - self.has_experiences = False + except Exception as exp: + logger.warning(f"No experience pool: {self.collection_name}, exp: {exp}") - def run(self, query: str, profile: str, topk: int = 5, excluded_version: str = "", verbose: bool = False) -> str: + def run(self, query: str, profile: str, excluded_version: str = "", verbose: bool = False) -> str: """_summary_ Args: query (str): 用当前的reflection作为query去检索过去相似的reflection profile (str): _description_ - topk (int, optional): _description_. Defaults to 5. Returns: _type_: _description_ """ - if not self.has_experiences or len(query) <= 2: # not "" or not '""' + if not self.engine or len(query) <= 2: # not "" or not '""' + logger.warning("engine is None or query too short") return "" - filters = {"profile": profile} - ### 消融实验逻辑 ### - if profile == "Werewolf": # 狼人作为基线,不用经验 + # ablation experiment logic + if profile == RoleType.WEREWOLF.value: # role werewolf as baseline, don't use experiences logger.warning("Disable werewolves' experiences") return "" - if excluded_version: - filters = {"$and": [{"profile": profile}, {"version": {"$ne": excluded_version}}]} # 不用同一版本的经验,只用之前的 - ################# - results = self.collection.query( - query_texts=[query], - n_results=topk, - where=filters, - ) + results = self.engine.retrieve(query) logger.info(f"retrieve {profile}'s experiences") - past_experiences = [RoleExperience(**res) for res in results["metadatas"][0]] + past_experiences = [res.metadata["obj"] for res in results] if verbose: logger.info("past_experiences: {}".format("\n\n".join(past_experiences))) - distances = results["distances"][0] + distances = results[0].score logger.info(f"distances: {distances}") template = """ @@ -170,30 +153,3 @@ class RetrieveExperiences(Action): logger.info("retrieval done") return json.dumps(past_experiences) - - -# FIXME: below are some utility functions, should be moved to appropriate places -def delete_collection(name): - chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma") - chroma_client.delete_collection(name=name) - - -def add_file_batch(folder, **kwargs): - action = AddNewExperiences(**kwargs) - file_paths = glob.glob(str(folder) + "/*") - for fp in file_paths: - logger.info(f"file_path: {fp}") - action.add_from_file(fp) - - -def modify_collection(): - chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma") - collection = chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME) - updated_name = DEFAULT_COLLECTION_NAME + "_backup" - collection.modify(name=updated_name) - try: - chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME) - except: - logger.info(f"collection {DEFAULT_COLLECTION_NAME} not found") - updated_collection = chroma_client.get_collection(name=updated_name) - logger.info(f"updated_collection top5 documents {updated_collection.get()['documents'][-5:]}") diff --git a/metagpt/ext/werewolf/actions/moderator_actions.py b/metagpt/ext/werewolf/actions/moderator_actions.py index 8f37e3bc9..ba5d13e64 100644 --- a/metagpt/ext/werewolf/actions/moderator_actions.py +++ b/metagpt/ext/werewolf/actions/moderator_actions.py @@ -11,8 +11,9 @@ class InstructSpeak(Action): ) content = instruction_info["content"] if "{living_players}" in content and "{werewolf_players}" in content: - content = content.format(living_players=living_players, werewolf_players=werewolf_players, - werewolf_num=len(werewolf_players)) + content = content.format( + living_players=living_players, werewolf_players=werewolf_players, werewolf_num=len(werewolf_players) + ) if "{living_players}" in content: content = content.format(living_players=living_players) if "{werewolf_players}" in content: diff --git a/metagpt/ext/werewolf/schema.py b/metagpt/ext/werewolf/schema.py index ad55da516..1502a2391 100644 --- a/metagpt/ext/werewolf/schema.py +++ b/metagpt/ext/werewolf/schema.py @@ -18,6 +18,10 @@ class RoleExperience(BaseModel): game_setup: str = "" version: str = "" + def rag_key(self) -> str: + """For search""" + return self.reflection + class WwMessage(Message): # Werewolf Message diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index bd8d25013..0876b85ad 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -722,7 +722,9 @@ def list_files(root: str | Path) -> List[Path]: def parse_json_code_block(markdown_text: str) -> List[str]: - json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text] + json_blocks = ( + re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text] + ) return [v.strip() for v in json_blocks]