update werewolf experience and add rag retrieve

This commit is contained in:
better629 2024-04-10 20:30:09 +08:00
parent 5caaea2aeb
commit a466bc9243
3 changed files with 44 additions and 29 deletions

View file

@ -1,11 +1,10 @@
import json
from typing import Optional
from chromadb.utils import embedding_functions
import chromadb
from pydantic import model_validator
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.environment.werewolf.const import RoleType
from metagpt.ext.werewolf.schema import RoleExperience
@ -15,13 +14,7 @@ from metagpt.rag.schema import ChromaIndexConfig, ChromaRetrieverConfig
from metagpt.utils.common import read_json_file, write_json_file
DEFAULT_COLLECTION_NAME = "role_reflection" # FIXME: some hard code for now
EMB_FN = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.llm.api_key,
api_base=config.llm.base_url,
api_type=config.llm.api_type,
model_name="text-embedding-ada-002",
)
PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("/werewolf_game/chroma")
PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("werewolf_game/chroma")
PERSIST_PATH.mkdir(parents=True, exist_ok=True)
@ -35,6 +28,13 @@ class AddNewExperiences(Action):
def validate_collection(self):
if self.engine:
return
if self.delete_existing:
try:
# implement engine `DELETE` method later
chromadb.PersistentClient(PERSIST_PATH.as_posix()).delete_collection(self.collection_name)
except Exception as exp:
logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}")
self.engine = SimpleEngine.from_objs(
retriever_configs=[
ChromaRetrieverConfig(
@ -42,12 +42,6 @@ class AddNewExperiences(Action):
)
]
)
if self.delete_existing:
try:
# implement engine `DELETE` method later
self.engine.retriever._index._vector_store._collection.delete_collection(name=self.collection_name)
except Exception as exp:
logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}")
def run(self, experiences: list[RoleExperience]):
if not experiences:
@ -64,7 +58,7 @@ class AddNewExperiences(Action):
experiences = [RoleExperience.model_validate(item) for item in experiences]
experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""'
self.engine.add(experiences)
self.engine.add_objs(experiences)
@staticmethod
def _record_experiences_local(experiences: list[RoleExperience]):
@ -85,18 +79,25 @@ class RetrieveExperiences(Action):
collection_name: str = DEFAULT_COLLECTION_NAME
has_experiences: bool = True
engine: Optional[SimpleEngine] = None
topk: int = 5
topk: int = 10
@model_validator(mode="after")
def validate_collection(self):
if self.engine:
return
try:
self.engine.from_index(
self.engine = SimpleEngine.from_index(
index_config=ChromaIndexConfig(
persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
),
retriever_configs=ChromaRetrieverConfig(similarity_top_k=self.topk),
retriever_configs=[
ChromaRetrieverConfig(
similarity_top_k=self.topk,
persist_path=PERSIST_PATH,
collection_name=self.collection_name,
metadata={"hnsw:space": "cosine"},
)
],
)
except Exception as exp:
logger.warning(f"No experience pool: {self.collection_name}, exp: {exp}")
@ -123,8 +124,14 @@ class RetrieveExperiences(Action):
results = self.engine.retrieve(query)
logger.info(f"retrieve {profile}'s experiences")
past_experiences = [res.metadata["obj"] for res in results]
if verbose:
experiences = [res.metadata["obj"] for res in results]
past_experiences = [] # currently use post-process to filter, and later add `filters` in rag
for exp in experiences:
if exp.profile == profile and exp.version != excluded_version:
past_experiences.append(exp)
if verbose and results:
logger.info("past_experiences: {}".format("\n\n".join(past_experiences)))
distances = results[0].score
logger.info(f"distances: {distances}")