replace werewolf experience chromadb with rag

This commit is contained in:
better629 2024-04-10 19:35:46 +08:00
parent 835d6987b9
commit 5caaea2aeb
4 changed files with 57 additions and 94 deletions

View file

@ -1,18 +1,18 @@
import glob
import json
import os
from typing import Optional
import chromadb
from chromadb import Collection
from chromadb.utils import embedding_functions
from pydantic import model_validator
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.schema import RoleExperience
from metagpt.logs import logger
from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.schema import ChromaIndexConfig, ChromaRetrieverConfig
from metagpt.utils.common import read_json_file, write_json_file
DEFAULT_COLLECTION_NAME = "role_reflection" # FIXME: some hard code for now
EMB_FN = embedding_functions.OpenAIEmbeddingFunction(
@ -21,72 +21,62 @@ EMB_FN = embedding_functions.OpenAIEmbeddingFunction(
api_type=config.llm.api_type,
model_name="text-embedding-ada-002",
)
# Directory holding the persisted Chroma collections for the werewolf game.
# The joined segment must be relative: `Path.joinpath` discards everything
# before an absolute component, so a leading "/" would resolve to
# `/werewolf_game/chroma` at the filesystem root instead of a folder inside
# the workspace.
PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("werewolf_game/chroma")
PERSIST_PATH.mkdir(parents=True, exist_ok=True)
class AddNewExperiences(Action):
    """Record role experiences to a local JSON file and add them to the RAG engine."""

    name: str = "AddNewExperience"
    # Name of the Chroma collection the experiences are stored in.
    collection_name: str = DEFAULT_COLLECTION_NAME
    # When true, a collection with the same name is dropped before the engine is built.
    delete_existing: bool = False
    # Engine used for indexing; built by `validate_collection` unless injected by the caller.
    engine: Optional[SimpleEngine] = None

    @model_validator(mode="after")
    def validate_collection(self):
        """Build the Chroma-backed engine unless one was supplied.

        Returns:
            The validated instance, as pydantic expects from an "after" validator.
        """
        if self.engine:
            return self

        # Drop the old collection *before* the engine binds to it, so the
        # engine never holds a handle to a collection that was just deleted.
        if self.delete_existing:
            self._delete_existing_collection()

        self.engine = SimpleEngine.from_objs(
            retriever_configs=[
                ChromaRetrieverConfig(
                    persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
                )
            ]
        )
        return self

    def _delete_existing_collection(self):
        """Best-effort removal of `collection_name`; failures are logged, never raised."""
        # Function-scope import keeps this helper independent of the module's import list.
        import chromadb

        try:
            # implement engine `DELETE` method later; until then go through a
            # Chroma client, since deleting a collection is a client-level call.
            chromadb.PersistentClient(path=str(PERSIST_PATH)).delete_collection(name=self.collection_name)
            logger.info(f"existing collection `{self.collection_name}` deleted")
        except Exception as exp:
            logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}")

    def run(self, experiences: list[RoleExperience]):
        """Assign ids to `experiences`, save them locally, and add them to the engine.

        Args:
            experiences: Experiences collected during one round; nothing happens when empty.
        """
        if not experiences:
            return
        for i, exp in enumerate(experiences):
            exp.id = f"{exp.profile}-{exp.name}-step{i}-round_{exp.round_id}"
        AddNewExperiences._record_experiences_local(experiences)
        self.engine.add_objs(experiences)

    def add_from_file(self, file_path):
        """Load experiences from a JSON file and add them to the engine.

        Args:
            file_path: Path of a JSON file readable by `read_json_file`, holding
                a list of items accepted by `RoleExperience.model_validate`.
        """
        experiences = read_json_file(file_path)
        experiences = [RoleExperience.model_validate(item) for item in experiences]
        experiences = [exp for exp in experiences if len(exp.reflection) > 2]  # not "" or not '""'
        if not experiences:
            return
        # Same engine entry point as `run`, so both ingestion paths index identically.
        self.engine.add_objs(experiences)

    @staticmethod
    def _record_experiences_local(experiences: list[RoleExperience]):
        """Dump `experiences` to `werewolf_game/experiences/<version>/<round_id>.json`.

        The round id and version are taken from the first experience, so the
        list must be non-empty; an empty version falls back to "test".
        """
        round_id = experiences[0].round_id
        version = experiences[0].version or "test"
        experience_path = DEFAULT_WORKSPACE_ROOT.joinpath(f"werewolf_game/experiences/{version}")
        experience_path.mkdir(parents=True, exist_ok=True)
        save_path = f"{experience_path}/{round_id}.json"
        write_json_file(save_path, [exp.model_dump() for exp in experiences])
        logger.info(f"experiences saved to {save_path}")
@ -94,56 +84,49 @@ class RetrieveExperiences(Action):
name: str = "RetrieveExperiences"
collection_name: str = DEFAULT_COLLECTION_NAME
has_experiences: bool = True
collection: Optional[Collection] = None
engine: Optional[SimpleEngine] = None
topk: int = 5
@model_validator(mode="after")
def validate_collection(self):
    """Load the engine from the persisted Chroma index unless one was supplied.

    When the index cannot be loaded, a warning is logged, `engine` stays
    `None`, and `has_experiences` is set to False.

    Returns:
        The validated instance, as pydantic expects from an "after" validator.
    """
    if self.engine:
        return self
    try:
        # `from_index` is a constructor: call it on the class and keep the
        # result. Calling it on `self.engine` (still None here) only raises
        # and leaves the action permanently without an engine.
        self.engine = SimpleEngine.from_index(
            index_config=ChromaIndexConfig(
                persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
            ),
            # Passed as a list, matching the `from_objs` call in AddNewExperiences.
            retriever_configs=[ChromaRetrieverConfig(similarity_top_k=self.topk)],
        )
    except Exception as exp:
        logger.warning(f"No experience pool: {self.collection_name}, exp: {exp}")
        # Keep the declared flag in step with the missing engine.
        self.has_experiences = False
    return self
def run(self, query: str, profile: str, topk: int = 5, excluded_version: str = "", verbose: bool = False) -> str:
def run(self, query: str, profile: str, excluded_version: str = "", verbose: bool = False) -> str:
"""_summary_
Args:
query (str): 用当前的reflection作为query去检索过去相似的reflection
profile (str): _description_
topk (int, optional): _description_. Defaults to 5.
Returns:
_type_: _description_
"""
if not self.has_experiences or len(query) <= 2: # not "" or not '""'
if not self.engine or len(query) <= 2: # not "" or not '""'
logger.warning("engine is None or query too short")
return ""
filters = {"profile": profile}
### 消融实验逻辑 ###
if profile == "Werewolf": # 狼人作为基线,不用经验
# ablation experiment logic
if profile == RoleType.WEREWOLF.value: # role werewolf as baseline, don't use experiences
logger.warning("Disable werewolves' experiences")
return ""
if excluded_version:
filters = {"$and": [{"profile": profile}, {"version": {"$ne": excluded_version}}]} # 不用同一版本的经验,只用之前的
#################
results = self.collection.query(
query_texts=[query],
n_results=topk,
where=filters,
)
results = self.engine.retrieve(query)
logger.info(f"retrieve {profile}'s experiences")
past_experiences = [RoleExperience(**res) for res in results["metadatas"][0]]
past_experiences = [res.metadata["obj"] for res in results]
if verbose:
logger.info("past_experiences: {}".format("\n\n".join(past_experiences)))
distances = results["distances"][0]
distances = results[0].score
logger.info(f"distances: {distances}")
template = """
@ -170,30 +153,3 @@ class RetrieveExperiences(Action):
logger.info("retrieval done")
return json.dumps(past_experiences)
# FIXME: below are some utility functions, should be moved to appropriate places
def delete_collection(name):
    """Delete the Chroma collection called `name` from the werewolf persist directory."""
    persist_dir = f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma"
    client = chromadb.PersistentClient(path=persist_dir)
    client.delete_collection(name=name)
def add_file_batch(folder, **kwargs):
    """Feed every entry directly under `folder` to a single `AddNewExperiences` action.

    Args:
        folder: Directory whose direct children are passed to `add_from_file`.
        **kwargs: Forwarded unchanged to the `AddNewExperiences` constructor.
    """
    loader = AddNewExperiences(**kwargs)
    pattern = str(folder) + "/*"
    for path in glob.glob(pattern):
        logger.info(f"file_path: {path}")
        loader.add_from_file(path)
def modify_collection():
    """Rename the default collection to `<name>_backup` and log its last five documents."""
    chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma")
    collection = chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
    updated_name = DEFAULT_COLLECTION_NAME + "_backup"
    collection.modify(name=updated_name)
    # Sanity check: after the rename the old name should no longer resolve.
    try:
        chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
    except Exception:
        # `Exception` rather than a bare `except`, so Ctrl-C / SystemExit still propagate.
        logger.info(f"collection {DEFAULT_COLLECTION_NAME} not found")
    updated_collection = chroma_client.get_collection(name=updated_name)
    logger.info(f"updated_collection top5 documents {updated_collection.get()['documents'][-5:]}")

View file

@ -11,8 +11,9 @@ class InstructSpeak(Action):
)
content = instruction_info["content"]
if "{living_players}" in content and "{werewolf_players}" in content:
content = content.format(living_players=living_players, werewolf_players=werewolf_players,
werewolf_num=len(werewolf_players))
content = content.format(
living_players=living_players, werewolf_players=werewolf_players, werewolf_num=len(werewolf_players)
)
if "{living_players}" in content:
content = content.format(living_players=living_players)
if "{werewolf_players}" in content:

View file

@ -18,6 +18,10 @@ class RoleExperience(BaseModel):
game_setup: str = ""
version: str = ""
def rag_key(self) -> str:
    """Return the text this experience is searched by: its reflection."""
    search_text = self.reflection
    return search_text
class WwMessage(Message):
# Werewolf Message

View file

@ -722,7 +722,9 @@ def list_files(root: str | Path) -> List[Path]:
def parse_json_code_block(markdown_text: str) -> List[str]:
    """Return the stripped contents of every ```json fenced block in the text.

    If the text contains no ```json fence at all, the whole text is treated
    as one block. An opened fence that is never closed yields no blocks.
    """
    if "```json" not in markdown_text:
        return [markdown_text.strip()]
    fenced = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL)
    return [block.strip() for block in fenced]