diff --git a/metagpt/ext/werewolf/actions/experience_operation.py b/metagpt/ext/werewolf/actions/experience_operation.py index a32549a2a..1f2e491d6 100644 --- a/metagpt/ext/werewolf/actions/experience_operation.py +++ b/metagpt/ext/werewolf/actions/experience_operation.py @@ -1,11 +1,10 @@ import json from typing import Optional -from chromadb.utils import embedding_functions +import chromadb from pydantic import model_validator from metagpt.actions import Action -from metagpt.config2 import config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.environment.werewolf.const import RoleType from metagpt.ext.werewolf.schema import RoleExperience @@ -15,13 +14,7 @@ from metagpt.rag.schema import ChromaIndexConfig, ChromaRetrieverConfig from metagpt.utils.common import read_json_file, write_json_file DEFAULT_COLLECTION_NAME = "role_reflection" # FIXME: some hard code for now -EMB_FN = embedding_functions.OpenAIEmbeddingFunction( - api_key=config.llm.api_key, - api_base=config.llm.base_url, - api_type=config.llm.api_type, - model_name="text-embedding-ada-002", -) -PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("/werewolf_game/chroma") +PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("werewolf_game/chroma") PERSIST_PATH.mkdir(parents=True, exist_ok=True) @@ -35,6 +28,13 @@ class AddNewExperiences(Action): def validate_collection(self): if self.engine: return + if self.delete_existing: + try: + # implement engine `DELETE` method later + chromadb.PersistentClient(PERSIST_PATH.as_posix()).delete_collection(self.collection_name) + except Exception as exp: + logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}") + self.engine = SimpleEngine.from_objs( retriever_configs=[ ChromaRetrieverConfig( @@ -42,12 +42,6 @@ class AddNewExperiences(Action): ) ] ) - if self.delete_existing: - try: - # implement engine `DELETE` method later - self.engine.retriever._index._vector_store._collection.delete_collection(name=self.collection_name) - except Exception as exp: - logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}") def run(self, experiences: list[RoleExperience]): if not experiences: @@ -64,7 +58,7 @@ class AddNewExperiences(Action): experiences = [RoleExperience.model_validate(item) for item in experiences] experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""' - self.engine.add(experiences) + self.engine.add_objs(experiences) @staticmethod def _record_experiences_local(experiences: list[RoleExperience]): @@ -85,18 +79,25 @@ class RetrieveExperiences(Action): collection_name: str = DEFAULT_COLLECTION_NAME has_experiences: bool = True engine: Optional[SimpleEngine] = None - topk: int = 5 + topk: int = 10 @model_validator(mode="after") def validate_collection(self): if self.engine: return try: - self.engine.from_index( + self.engine = SimpleEngine.from_index( index_config=ChromaIndexConfig( persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"} ), - retriever_configs=ChromaRetrieverConfig(similarity_top_k=self.topk), + retriever_configs=[ + ChromaRetrieverConfig( + similarity_top_k=self.topk, + persist_path=PERSIST_PATH, + collection_name=self.collection_name, + metadata={"hnsw:space": "cosine"}, + ) + ], ) except Exception as exp: logger.warning(f"No experience pool: {self.collection_name}, exp: {exp}") @@ -123,8 +124,14 @@ class RetrieveExperiences(Action): results = self.engine.retrieve(query) logger.info(f"retrieve {profile}'s experiences") - past_experiences = [res.metadata["obj"] for res in results] - if verbose: + experiences = [res.metadata["obj"] for res in results] + + past_experiences = [] # currently use post-process to filter, and later add `filters` in rag + for exp in experiences: + if exp.profile == profile and exp.version != excluded_version: + past_experiences.append(exp) + + if verbose and results: logger.info("past_experiences: {}".format("\n\n".join(past_experiences))) distances = results[0].score logger.info(f"distances: {distances}") diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 5c5810308..34f925249 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -161,6 +161,13 @@ class SimpleEngine(RetrieverQueryEngine): """Inplement tools.SearchInterface""" return await self.aquery(content) + def retrieve(self, query: QueryType) -> list[NodeWithScore]: + query_bundle = QueryBundle(query) if isinstance(query, str) else query + + nodes = super().retrieve(query_bundle) + self._try_reconstruct_obj(nodes) + return nodes + async def aretrieve(self, query: QueryType) -> list[NodeWithScore]: """Allow query to be str.""" query_bundle = QueryBundle(query) if isinstance(query, str) else query diff --git a/tests/metagpt/ext/werewolf/actions/test_experience_operation.py b/tests/metagpt/ext/werewolf/actions/test_experience_operation.py index af7dfe807..a31abc49a 100644 --- a/tests/metagpt/ext/werewolf/actions/test_experience_operation.py +++ b/tests/metagpt/ext/werewolf/actions/test_experience_operation.py @@ -1,5 +1,4 @@ import json -import os import pytest @@ -69,22 +68,24 @@ class TestExperiencesOperation: @pytest.mark.asyncio async def test_add(self): - saved_file = f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/experiences/{self.version}/{self.test_round_id}.json" - if os.path.exists(saved_file): - os.remove(saved_file) + saved_file = DEFAULT_WORKSPACE_ROOT.joinpath( + f"werewolf_game/experiences/{self.version}/{self.test_round_id}.json" + ) + if saved_file.exists(): + saved_file.unlink() action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True) action.run(self.samples_to_add) # test insertion - inserted = action.collection.get() + inserted = action.engine.retriever._index._vector_store._collection.get() assert len(inserted["documents"]) == len(self.samples_to_add) # test if we record the samples correctly to local file # & test if we could recover a embedding db from the file action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True) action.add_from_file(saved_file) - inserted = action.collection.get() + inserted = action.engine.retriever._index._vector_store._collection.get() assert len(inserted["documents"]) == len(self.samples_to_add) @pytest.mark.asyncio @@ -123,8 +124,8 @@ class TestActualRetrieve: async def test_check_experience_pool(self): logger.info("check experience pool") action = RetrieveExperiences(collection_name=self.collection_name) - if action.collection: - all_experiences = action.collection.get() + if action.engine: + all_experiences = action.engine.retriever._index._vector_store._collection.get() logger.info(f"{len(all_experiences['metadatas'])=}") @pytest.mark.asyncio