From 8525ec6d7bad52685fa93889296d3a5831f1233e Mon Sep 17 00:00:00 2001 From: garylin2099 Date: Tue, 17 Oct 2023 21:27:14 +0800 Subject: [PATCH] small bug fixed for reflection record and retrieve --- .../werewolf_game/actions/common_actions.py | 6 +- .../actions/experience_operation.py | 53 +++++++++++--- examples/werewolf_game/roles/base_player.py | 11 +-- examples/werewolf_game/roles/moderator.py | 2 +- examples/werewolf_game/schema.py | 2 +- examples/werewolf_game/start_game.py | 29 ++++---- .../actions/test_experience_operation.py | 72 +++++++++++++++---- 7 files changed, 131 insertions(+), 44 deletions(-) diff --git a/examples/werewolf_game/actions/common_actions.py b/examples/werewolf_game/actions/common_actions.py index cefdf4126..d9b886743 100644 --- a/examples/werewolf_game/actions/common_actions.py +++ b/examples/werewolf_game/actions/common_actions.py @@ -8,7 +8,7 @@ class Speak(Action): PROMPT_TEMPLATE = """ { - "BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__." + "BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night." ,"HISTORY": "You have knowledge to the following conversation: __context__" ,"ATTENTION": "You can NOT VOTE a player who is NOT ALIVE now!" ,"REFLECTION": "__reflection__" @@ -95,7 +95,7 @@ class NighttimeWhispers(Action): PROMPT_TEMPLATE = """ { - "BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__." + "BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night." ,"HISTORY": "You have knowledge to the following conversation: __context__" ,"ACTION": "Choose one living player to __action__." ,"ATTENTION": "1. You can only __action__ a player who is alive this night! And you can not __action__ a player who is dead this night! 2. `HISTORY` is all the information you observed, DONT hallucinate other player actions!" @@ -172,7 +172,7 @@ class Reflect(Action): PROMPT_TEMPLATE = """ { - "BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__." + "BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night." ,"HISTORY": "You have knowledge to the following conversation: __context__" ,"MODERATOR_INSTRUCTION": __latest_instruction__, ,"OUTPUT_FORMAT" (a json): diff --git a/examples/werewolf_game/actions/experience_operation.py b/examples/werewolf_game/actions/experience_operation.py index 2ced3d0c8..ea930d743 100644 --- a/examples/werewolf_game/actions/experience_operation.py +++ b/examples/werewolf_game/actions/experience_operation.py @@ -1,5 +1,6 @@ import json import os +import glob import chromadb from chromadb.utils import embedding_functions @@ -63,6 +64,7 @@ class AddNewExperiences(Action): with open(file_path, "r") as fl: lines = fl.readlines() experiences = [RoleExperience(**json.loads(line)) for line in lines] + experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""' ids = [exp.id for exp in experiences] documents = [exp.reflection for exp in experiences] @@ -77,13 +79,16 @@ class AddNewExperiences(Action): @staticmethod def _record_experiences_local(experiences: list[RoleExperience]): round_id = experiences[0].round_id + version = experiences[0].version + version = "test" if not version else version experiences = [exp.json() for exp in experiences] - experience_folder = WORKSPACE_ROOT / 'werewolf_game/experiences' + experience_folder = WORKSPACE_ROOT / f'werewolf_game/experiences/{version}' if not os.path.exists(experience_folder): os.makedirs(experience_folder) save_path = f"{experience_folder}/{round_id}.json" with open(save_path, "a") as fl: fl.write("\n".join(experiences)) + fl.write("\n") logger.info(f"experiences saved to {save_path}") class RetrieveExperiences(Action): @@ -102,7 +107,7 @@ class RetrieveExperiences(Action): logger.warning(f"No experience pool {collection_name}") self.has_experiences = False - def run(self, query: str, profile: str, topk: int = 5) -> str: + def run(self, query: str, profile: str, topk: int = 5, excluded_version: str = "", verbose: bool = False) -> str: """_summary_ Args: @@ -113,20 +118,30 @@ class RetrieveExperiences(Action): Returns: _type_: _description_ """ - if not self.has_experiences: + if not self.has_experiences or len(query) <= 2: # not "" or not '""' return "" + + filters = {"profile": profile} + ### 消融实验逻辑 ### + if profile == "Werewolf": # 狼人作为基线,不用经验 + logger.warning("Disable werewolves' experiences") + return "" + if excluded_version: + filters = {"$and": [{"profile": profile}, {"version": {"$ne": excluded_version}}]} # 不用同一版本的经验,只用之前的 + ################# results = self.collection.query( query_texts=[query], n_results=topk, - where={"profile": profile}, + where=filters, ) logger.info("retrieved exp") past_experiences = [RoleExperience(**res) for res in results["metadatas"][0]] - # print(*past_experiences, sep="\n\n") - distances = results["distances"][0] - print(distances) + if verbose: + print(*past_experiences, sep="\n\n") + distances = results["distances"][0] + print(distances) template = """ { @@ -148,9 +163,31 @@ class RetrieveExperiences(Action): return json.dumps(past_experiences) +# FIXME: below are some utility functions, should be moved to appropriate places def delete_collection(name): chroma_client = chromadb.PersistentClient(path=f"{WORKSPACE_ROOT}/werewolf_game/chroma") chroma_client.delete_collection(name=name) +def add_file_batch(folder, **kwargs): + action = AddNewExperiences(**kwargs) + file_paths = glob.glob(str(folder) + "/*") + for fp in file_paths: + print(fp) + action.add_from_file(fp) + +def modify_collection(): + chroma_client = chromadb.PersistentClient(path=f"{WORKSPACE_ROOT}/werewolf_game/chroma") + collection = chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME) + updated_name = DEFAULT_COLLECTION_NAME + "_backup" + collection.modify(name=updated_name) + try: + chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME) + except: + logger.info(f"collection {DEFAULT_COLLECTION_NAME} not found") + updated_collection = chroma_client.get_collection(name=updated_name) + print(updated_collection.get()["documents"][-5:]) + # if __name__ == "__main__": -# delete_collection(name="test") + # delete_collection(name="test") + # add_file_batch(WORKSPACE_ROOT / 'werewolf_game/experiences', collection_name=DEFAULT_COLLECTION_NAME, delete_existing=True) + # modify_collection() diff --git a/examples/werewolf_game/roles/base_player.py b/examples/werewolf_game/roles/base_player.py index 88073c559..6aa5c5e08 100644 --- a/examples/werewolf_game/roles/base_player.py +++ b/examples/werewolf_game/roles/base_player.py @@ -16,6 +16,7 @@ class BasePlayer(Role): use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False, + new_experience_version: str = "", **kwargs, ): super().__init__(name, profile, **kwargs) @@ -35,6 +36,7 @@ class BasePlayer(Role): self.use_experience = False else: self.use_experience = use_experience + self.new_experience_version = new_experience_version if self.use_experience else "" self.use_memory_selection = use_memory_selection self.experiences = [] @@ -76,8 +78,9 @@ class BasePlayer(Role): profile=self.profile, name=self.name, context=memories, latest_instruction=latest_instruction ) if self.use_reflection else "" - experiences = RetrieveExperiences().run(query=reflection, profile=self.profile) \ - if self.use_experience else "" + experiences = RetrieveExperiences().run( + query=reflection, profile=self.profile, excluded_version=self.new_experience_version + ) if self.use_experience else "" # 根据自己定义的角色Action,对应地去run,run的入参可能不同 if isinstance(todo, Speak): @@ -99,7 +102,7 @@ class BasePlayer(Role): self.experiences.append( RoleExperience(name=self.name, profile=self.profile, reflection=reflection, - instruction=latest_instruction, response=rsp) + instruction=latest_instruction, response=rsp, version=self.new_experience_version) ) logger.info(f"{self._setting}: {rsp}") @@ -121,7 +124,7 @@ class BasePlayer(Role): self.status = new_status def record_experiences(self, round_id: str, outcome: str, game_setup: str): - experiences = [exp for exp in self.experiences if exp.reflection] + experiences = [exp for exp in self.experiences if len(exp.reflection) > 2] # not "" or not '""' for exp in experiences: exp.round_id = round_id exp.outcome = outcome diff --git a/examples/werewolf_game/roles/moderator.py b/examples/werewolf_game/roles/moderator.py index 8bd1982b8..6bb294e91 100644 --- a/examples/werewolf_game/roles/moderator.py +++ b/examples/werewolf_game/roles/moderator.py @@ -166,7 +166,7 @@ class Moderator(Role): if not voted: continue voted_all.append(voted.group(0)) - self.player_current_dead = [Counter(voted_all).most_common()[0][0]] # 平票时,杀序号小的 + self.player_current_dead = [Counter(voted_all).most_common()[0][0]] # 平票时,杀最先被投的 # print("*" * 10, "dead", self.player_current_dead) self.living_players = [p for p in self.living_players if p not in self.player_current_dead] self.update_player_status(self.player_current_dead) diff --git a/examples/werewolf_game/schema.py b/examples/werewolf_game/schema.py index bfca7ad8e..311dfa30e 100644 --- a/examples/werewolf_game/schema.py +++ b/examples/werewolf_game/schema.py @@ -10,4 +10,4 @@ class RoleExperience(BaseModel): outcome: str = "" round_id: str = "" game_setup: str = "" - version: str = "01-10" + version: str = "" diff --git a/examples/werewolf_game/start_game.py b/examples/werewolf_game/start_game.py index 8d50898e0..18164b65a 100644 --- a/examples/werewolf_game/start_game.py +++ b/examples/werewolf_game/start_game.py @@ -10,7 +10,8 @@ from examples.werewolf_game.roles.human_player import prepare_human_player def init_game_setup( shuffle=True, add_human=False, - use_reflection=True, use_experience=False, use_memory_selection=False + use_reflection=True, use_experience=False, use_memory_selection=False, + new_experience_version="", ): roles = [ Villager, @@ -32,7 +33,8 @@ def init_game_setup( players = [ role( name=f"Player{i+1}", - use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection + use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection, + new_experience_version=new_experience_version ) for i, role in enumerate(roles) ] @@ -46,11 +48,14 @@ def init_game_setup( async def start_game( investment: float = 3.0, n_round: int = 5, shuffle : bool = True, add_human: bool = False, - use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False + use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False, + new_experience_version: str = "", ): game = WerewolfGame() - game_setup, players = init_game_setup(shuffle=shuffle, add_human=add_human, - use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection) + game_setup, players = init_game_setup( + shuffle=shuffle, add_human=add_human, use_reflection=use_reflection, use_experience=use_experience, + use_memory_selection=use_memory_selection, new_experience_version=new_experience_version, + ) players = [Moderator()] + players game.hire(players) game.invest(investment) @@ -58,15 +63,11 @@ async def start_game( await game.run(n_round=n_round) def main(investment: float = 20.0, n_round: int = 100, shuffle : bool = True, add_human: bool = False, - use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False): - """ - :param investment: contribute a certain dollar amount to watch the debate - :param n_round: maximum rounds of the debate - :return: - """ - asyncio.run( - start_game(investment, n_round, shuffle, add_human, use_reflection, use_experience, use_memory_selection) - ) + use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False, + new_experience_version: str = ""): + + asyncio.run(start_game(investment, n_round, shuffle, add_human, + use_reflection, use_experience, use_memory_selection, new_experience_version)) if __name__ == '__main__': diff --git a/examples/werewolf_game/tests/actions/test_experience_operation.py b/examples/werewolf_game/tests/actions/test_experience_operation.py index 54db3fec7..85a63cca4 100644 --- a/examples/werewolf_game/tests/actions/test_experience_operation.py +++ b/examples/werewolf_game/tests/actions/test_experience_operation.py @@ -11,20 +11,25 @@ from examples.werewolf_game.actions.experience_operation import AddNewExperience class TestExperiencesOperation: + collection_name = "test" test_round_id = "test_01" + version = "test" samples_to_add = [ - RoleExperience(profile="Witch", reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. Player4's behavior is suspicious.", response="", outcome="", round_id=test_round_id), - RoleExperience(profile="Witch", reflection="The game is in a critical state with only three players left, and I need to make a wise decision to save Player7 or not.", response="", outcome="", round_id=test_round_id), - RoleExperience(profile="Seer", reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, sided with him. I, as the real Seer, am under suspicion.", response="", outcome="", round_id=test_round_id), + RoleExperience(profile="Witch", reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. Player4's behavior is suspicious.", response="", outcome="", round_id=test_round_id, version=version), + RoleExperience(profile="Witch", reflection="The game is in a critical state with only three players left, and I need to make a wise decision to save Player7 or not.", response="", outcome="", round_id=test_round_id, version=version), + RoleExperience(profile="Seer", reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, sided with him. I, as the real Seer, am under suspicion.", response="", outcome="", round_id=test_round_id, version=version), + RoleExperience(profile="TestRole", reflection="Some test reflection1", response="", outcome="", round_id=test_round_id, version=version+"_01-10"), + RoleExperience(profile="TestRole", reflection="Some test reflection2", response="", outcome="", round_id=test_round_id, version=version+"_11-20"), + RoleExperience(profile="TestRole", reflection="Some test reflection3", response="", outcome="", round_id=test_round_id, version=version+"_21-30"), ] @pytest.mark.asyncio async def test_add(self): - saved_file = f"{WORKSPACE_ROOT}/werewolf_game/experiences/{self.test_round_id}.json" + saved_file = f"{WORKSPACE_ROOT}/werewolf_game/experiences/{self.version}/{self.test_round_id}.json" if os.path.exists(saved_file): os.remove(saved_file) - action = AddNewExperiences(collection_name="test", delete_existing=True) + action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True) action.run(self.samples_to_add) # test insertion @@ -33,32 +38,55 @@ class TestExperiencesOperation: # test if we record the samples correctly to local file # & test if we could recover a embedding db from the file - action = AddNewExperiences(collection_name="test", delete_existing=True) + action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True) action.add_from_file(saved_file) inserted = action.collection.get() assert len(inserted["documents"]) == len(self.samples_to_add) @pytest.mark.asyncio async def test_retrieve(self): - action = RetrieveExperiences(collection_name="test") + action = RetrieveExperiences(collection_name=self.collection_name) query = "one player claimed to be Seer and the other Witch" - results = action.run(query, "Witch") + results = action.run(query, profile="Witch") results = json.loads(results) - assert len(results) == 2 + assert len(results) == 2, "Witch should have 2 experiences" assert "The game is intense with two players" in results[0] + @pytest.mark.asyncio + async def test_retrieve_filtering(self): + action = RetrieveExperiences(collection_name=self.collection_name) + + query = "some test query" + profile = "TestRole" + + excluded_version = "" + results = action.run(query, profile=profile, excluded_version=excluded_version) + results = json.loads(results) + assert len(results) == 3 + + excluded_version = self.version + "_21-30" + results = action.run(query, profile=profile, excluded_version=excluded_version) + results = json.loads(results) + assert len(results) == 2 + +class TestActualRetrieve: + + collection_name = "role_reflection" + @pytest.mark.asyncio async def test_check_experience_pool(self): logger.info("check experience pool") - action = RetrieveExperiences(collection_name="role_reflection") - print(*action.collection.get()["metadatas"][-5:], sep="\n") + action = RetrieveExperiences(collection_name=self.collection_name) + all_experiences = action.collection.get() + logger.info(f"{len(all_experiences['metadatas'])=}") + print(*["metadatas"][-5:], sep="\n") @pytest.mark.asyncio async def test_retrieve_werewolf_experience(self): - action = RetrieveExperiences(collection_name="role_reflection") + action = RetrieveExperiences(collection_name=self.collection_name) query = "there are conflicts" @@ -68,9 +96,27 @@ class TestExperiencesOperation: @pytest.mark.asyncio async def test_retrieve_villager_experience(self): - action = RetrieveExperiences(collection_name="role_reflection") + action = RetrieveExperiences(collection_name=self.collection_name) query = "there are conflicts" logger.info(f"test retrieval with {query=}") results = action.run(query, "Seer") + assert "conflict" in results # 相似局面应该需要包含conflict关键词 + + @pytest.mark.asyncio + async def test_retrieve_villager_experience_filtering(self): + + action = RetrieveExperiences(collection_name=self.collection_name) + + query = "there are conflicts" + + excluded_version = "01-10" + logger.info(f"test retrieval with {excluded_version=}") + results_01_10 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True) + + excluded_version = "11-20" + logger.info(f"test retrieval with {excluded_version=}") + results_11_20 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True) + + assert results_01_10 != results_11_20