mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
replace werewolf experience chromadb with rag
This commit is contained in:
parent
835d6987b9
commit
5caaea2aeb
4 changed files with 57 additions and 94 deletions
|
|
@ -1,18 +1,18 @@
|
|||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb import Collection
|
||||
from chromadb.utils import embedding_functions
|
||||
from pydantic import model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.environment.werewolf.const import RoleType
|
||||
from metagpt.ext.werewolf.schema import RoleExperience
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
from metagpt.rag.schema import ChromaIndexConfig, ChromaRetrieverConfig
|
||||
from metagpt.utils.common import read_json_file, write_json_file
|
||||
|
||||
DEFAULT_COLLECTION_NAME = "role_reflection" # FIXME: some hard code for now
|
||||
EMB_FN = embedding_functions.OpenAIEmbeddingFunction(
|
||||
|
|
@ -21,72 +21,62 @@ EMB_FN = embedding_functions.OpenAIEmbeddingFunction(
|
|||
api_type=config.llm.api_type,
|
||||
model_name="text-embedding-ada-002",
|
||||
)
|
||||
PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("/werewolf_game/chroma")
|
||||
PERSIST_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class AddNewExperiences(Action):
|
||||
name: str = "AddNewExperience"
|
||||
collection_name: str = DEFAULT_COLLECTION_NAME
|
||||
delete_existing: bool = False
|
||||
collection: Optional[Collection] = None
|
||||
engine: Optional[SimpleEngine] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_collection(self):
|
||||
if self.collection:
|
||||
if self.engine:
|
||||
return
|
||||
|
||||
chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma")
|
||||
self.engine = SimpleEngine.from_objs(
|
||||
retriever_configs=[
|
||||
ChromaRetrieverConfig(
|
||||
persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
]
|
||||
)
|
||||
if self.delete_existing:
|
||||
try:
|
||||
chroma_client.get_collection(name=self.collection_name)
|
||||
chroma_client.delete_collection(name=self.collection_name)
|
||||
logger.info(f"existing collection `{self.collection_name}` deleted")
|
||||
except:
|
||||
pass
|
||||
|
||||
self.collection = chroma_client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
embedding_function=EMB_FN,
|
||||
)
|
||||
# implement engine `DELETE` method later
|
||||
self.engine.retriever._index._vector_store._collection.delete_collection(name=self.collection_name)
|
||||
except Exception as exp:
|
||||
logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}")
|
||||
|
||||
def run(self, experiences: list[RoleExperience]):
|
||||
if not experiences:
|
||||
return
|
||||
for i, exp in enumerate(experiences):
|
||||
exp.id = f"{exp.profile}-{exp.name}-step{i}-round_{exp.round_id}"
|
||||
ids = [exp.id for exp in experiences]
|
||||
documents = [exp.reflection for exp in experiences]
|
||||
metadatas = [exp.model_dump() for exp in experiences]
|
||||
|
||||
AddNewExperiences._record_experiences_local(experiences)
|
||||
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
self.engine.add_objs(experiences)
|
||||
|
||||
def add_from_file(self, file_path):
|
||||
with open(file_path, "r") as fl:
|
||||
lines = fl.readlines()
|
||||
experiences = [RoleExperience.model_validate_json(line) for line in lines]
|
||||
experiences = read_json_file(file_path)
|
||||
experiences = [RoleExperience.model_validate(item) for item in experiences]
|
||||
experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""'
|
||||
|
||||
ids = [exp.id for exp in experiences]
|
||||
documents = [exp.reflection for exp in experiences]
|
||||
metadatas = [exp.model_dump() for exp in experiences]
|
||||
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
self.engine.add(experiences)
|
||||
|
||||
@staticmethod
|
||||
def _record_experiences_local(experiences: list[RoleExperience]):
|
||||
round_id = experiences[0].round_id
|
||||
version = experiences[0].version
|
||||
version = "test" if not version else version
|
||||
experiences = [exp.model_dump_json() for exp in experiences]
|
||||
experience_folder = DEFAULT_WORKSPACE_ROOT / f"werewolf_game/experiences/{version}"
|
||||
if not os.path.exists(experience_folder):
|
||||
os.makedirs(experience_folder)
|
||||
save_path = f"{experience_folder}/{round_id}.json"
|
||||
with open(save_path, "a") as fl:
|
||||
fl.write("\n".join(experiences))
|
||||
fl.write("\n")
|
||||
experiences = [exp.model_dump() for exp in experiences]
|
||||
|
||||
experience_path = DEFAULT_WORKSPACE_ROOT.joinpath(f"werewolf_game/experiences/{version}")
|
||||
experience_path.mkdir(parents=True, exist_ok=True)
|
||||
save_path = f"{experience_path}/{round_id}.json"
|
||||
write_json_file(save_path, experiences)
|
||||
logger.info(f"experiences saved to {save_path}")
|
||||
|
||||
|
||||
|
|
@ -94,56 +84,49 @@ class RetrieveExperiences(Action):
|
|||
name: str = "RetrieveExperiences"
|
||||
collection_name: str = DEFAULT_COLLECTION_NAME
|
||||
has_experiences: bool = True
|
||||
collection: Optional[Collection] = None
|
||||
engine: Optional[SimpleEngine] = None
|
||||
topk: int = 5
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_collection(self):
|
||||
if self.collection:
|
||||
if self.engine:
|
||||
return
|
||||
chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma")
|
||||
try:
|
||||
self.collection = chroma_client.get_collection(
|
||||
name=self.collection_name,
|
||||
embedding_function=EMB_FN,
|
||||
self.engine.from_index(
|
||||
index_config=ChromaIndexConfig(
|
||||
persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
||||
),
|
||||
retriever_configs=ChromaRetrieverConfig(similarity_top_k=self.topk),
|
||||
)
|
||||
except:
|
||||
logger.warning(f"No experience pool {self.collection_name}")
|
||||
self.has_experiences = False
|
||||
except Exception as exp:
|
||||
logger.warning(f"No experience pool: {self.collection_name}, exp: {exp}")
|
||||
|
||||
def run(self, query: str, profile: str, topk: int = 5, excluded_version: str = "", verbose: bool = False) -> str:
|
||||
def run(self, query: str, profile: str, excluded_version: str = "", verbose: bool = False) -> str:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
query (str): 用当前的reflection作为query去检索过去相似的reflection
|
||||
profile (str): _description_
|
||||
topk (int, optional): _description_. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
if not self.has_experiences or len(query) <= 2: # not "" or not '""'
|
||||
if not self.engine or len(query) <= 2: # not "" or not '""'
|
||||
logger.warning("engine is None or query too short")
|
||||
return ""
|
||||
|
||||
filters = {"profile": profile}
|
||||
### 消融实验逻辑 ###
|
||||
if profile == "Werewolf": # 狼人作为基线,不用经验
|
||||
# ablation experiment logic
|
||||
if profile == RoleType.WEREWOLF.value: # role werewolf as baseline, don't use experiences
|
||||
logger.warning("Disable werewolves' experiences")
|
||||
return ""
|
||||
if excluded_version:
|
||||
filters = {"$and": [{"profile": profile}, {"version": {"$ne": excluded_version}}]} # 不用同一版本的经验,只用之前的
|
||||
#################
|
||||
|
||||
results = self.collection.query(
|
||||
query_texts=[query],
|
||||
n_results=topk,
|
||||
where=filters,
|
||||
)
|
||||
results = self.engine.retrieve(query)
|
||||
|
||||
logger.info(f"retrieve {profile}'s experiences")
|
||||
past_experiences = [RoleExperience(**res) for res in results["metadatas"][0]]
|
||||
past_experiences = [res.metadata["obj"] for res in results]
|
||||
if verbose:
|
||||
logger.info("past_experiences: {}".format("\n\n".join(past_experiences)))
|
||||
distances = results["distances"][0]
|
||||
distances = results[0].score
|
||||
logger.info(f"distances: {distances}")
|
||||
|
||||
template = """
|
||||
|
|
@ -170,30 +153,3 @@ class RetrieveExperiences(Action):
|
|||
logger.info("retrieval done")
|
||||
|
||||
return json.dumps(past_experiences)
|
||||
|
||||
|
||||
# FIXME: below are some utility functions, should be moved to appropriate places
|
||||
def delete_collection(name):
|
||||
chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma")
|
||||
chroma_client.delete_collection(name=name)
|
||||
|
||||
|
||||
def add_file_batch(folder, **kwargs):
|
||||
action = AddNewExperiences(**kwargs)
|
||||
file_paths = glob.glob(str(folder) + "/*")
|
||||
for fp in file_paths:
|
||||
logger.info(f"file_path: {fp}")
|
||||
action.add_from_file(fp)
|
||||
|
||||
|
||||
def modify_collection():
|
||||
chroma_client = chromadb.PersistentClient(path=f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/chroma")
|
||||
collection = chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
|
||||
updated_name = DEFAULT_COLLECTION_NAME + "_backup"
|
||||
collection.modify(name=updated_name)
|
||||
try:
|
||||
chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
|
||||
except:
|
||||
logger.info(f"collection {DEFAULT_COLLECTION_NAME} not found")
|
||||
updated_collection = chroma_client.get_collection(name=updated_name)
|
||||
logger.info(f"updated_collection top5 documents {updated_collection.get()['documents'][-5:]}")
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ class InstructSpeak(Action):
|
|||
)
|
||||
content = instruction_info["content"]
|
||||
if "{living_players}" in content and "{werewolf_players}" in content:
|
||||
content = content.format(living_players=living_players, werewolf_players=werewolf_players,
|
||||
werewolf_num=len(werewolf_players))
|
||||
content = content.format(
|
||||
living_players=living_players, werewolf_players=werewolf_players, werewolf_num=len(werewolf_players)
|
||||
)
|
||||
if "{living_players}" in content:
|
||||
content = content.format(living_players=living_players)
|
||||
if "{werewolf_players}" in content:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ class RoleExperience(BaseModel):
|
|||
game_setup: str = ""
|
||||
version: str = ""
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
return self.reflection
|
||||
|
||||
|
||||
class WwMessage(Message):
|
||||
# Werewolf Message
|
||||
|
|
|
|||
|
|
@ -722,7 +722,9 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
|
||||
|
||||
def parse_json_code_block(markdown_text: str) -> List[str]:
|
||||
json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text]
|
||||
json_blocks = (
|
||||
re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text]
|
||||
)
|
||||
|
||||
return [v.strip() for v in json_blocks]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue