Fix small bugs in reflection recording and retrieval

This commit is contained in:
garylin2099 2023-10-17 21:27:14 +08:00
parent 205d1c9843
commit 8525ec6d7b
7 changed files with 131 additions and 44 deletions

View file

@ -8,7 +8,7 @@ class Speak(Action):
PROMPT_TEMPLATE = """
{
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__."
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
,"HISTORY": "You have knowledge to the following conversation: __context__"
,"ATTENTION": "You can NOT VOTE a player who is NOT ALIVE now!"
,"REFLECTION": "__reflection__"
@ -95,7 +95,7 @@ class NighttimeWhispers(Action):
PROMPT_TEMPLATE = """
{
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__."
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
,"HISTORY": "You have knowledge to the following conversation: __context__"
,"ACTION": "Choose one living player to __action__."
,"ATTENTION": "1. You can only __action__ a player who is alive this night! And you can not __action__ a player who is dead this night! 2. `HISTORY` is all the information you observed, DONT hallucinate other player actions!"
@ -172,7 +172,7 @@ class Reflect(Action):
PROMPT_TEMPLATE = """
{
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__."
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
,"HISTORY": "You have knowledge to the following conversation: __context__"
,"MODERATOR_INSTRUCTION": __latest_instruction__,
,"OUTPUT_FORMAT" (a json):

View file

@ -1,5 +1,6 @@
import json
import os
import glob
import chromadb
from chromadb.utils import embedding_functions
@ -63,6 +64,7 @@ class AddNewExperiences(Action):
with open(file_path, "r") as fl:
lines = fl.readlines()
experiences = [RoleExperience(**json.loads(line)) for line in lines]
experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""'
ids = [exp.id for exp in experiences]
documents = [exp.reflection for exp in experiences]
@ -77,13 +79,16 @@ class AddNewExperiences(Action):
@staticmethod
def _record_experiences_local(experiences: list[RoleExperience]):
    """Append the given experiences to a per-round file on local disk.

    The target file is
    ``WORKSPACE_ROOT/werewolf_game/experiences/<version>/<round_id>.json``,
    where ``round_id`` and ``version`` are taken from the first experience
    in the list. An empty ``version`` falls back to ``"test"``. Every
    experience is serialized with its ``json()`` method and written on its
    own line; the file is opened in append mode, so repeated calls for the
    same round accumulate.

    Args:
        experiences: Experiences to persist. All of them go to the file
            selected by the first element. An empty list is a no-op.
    """
    if not experiences:
        # Nothing to write; indexing experiences[0] below would raise IndexError.
        return

    first = experiences[0]
    round_id = first.round_id
    version = first.version or "test"

    # Keep the serialized lines in their own name instead of rebinding the parameter.
    serialized = [exp.json() for exp in experiences]

    experience_folder = WORKSPACE_ROOT / f'werewolf_game/experiences/{version}'
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(experience_folder, exist_ok=True)

    save_path = f"{experience_folder}/{round_id}.json"
    with open(save_path, "a") as fl:
        fl.write("\n".join(serialized))
        fl.write("\n")
    logger.info(f"experiences saved to {save_path}")
class RetrieveExperiences(Action):
@ -102,7 +107,7 @@ class RetrieveExperiences(Action):
logger.warning(f"No experience pool {collection_name}")
self.has_experiences = False
def run(self, query: str, profile: str, topk: int = 5) -> str:
def run(self, query: str, profile: str, topk: int = 5, excluded_version: str = "", verbose: bool = False) -> str:
"""_summary_
Args:
@ -113,20 +118,30 @@ class RetrieveExperiences(Action):
Returns:
_type_: _description_
"""
if not self.has_experiences:
if not self.has_experiences or len(query) <= 2: # not "" or not '""'
return ""
filters = {"profile": profile}
### 消融实验逻辑 ###
if profile == "Werewolf": # 狼人作为基线,不用经验
logger.warning("Disable werewolves' experiences")
return ""
if excluded_version:
filters = {"$and": [{"profile": profile}, {"version": {"$ne": excluded_version}}]} # 不用同一版本的经验,只用之前的
#################
results = self.collection.query(
query_texts=[query],
n_results=topk,
where={"profile": profile},
where=filters,
)
logger.info("retrieved exp")
past_experiences = [RoleExperience(**res) for res in results["metadatas"][0]]
# print(*past_experiences, sep="\n\n")
distances = results["distances"][0]
print(distances)
if verbose:
print(*past_experiences, sep="\n\n")
distances = results["distances"][0]
print(distances)
template = """
{
@ -148,9 +163,31 @@ class RetrieveExperiences(Action):
return json.dumps(past_experiences)
# FIXME: below are some utility functions, should be moved to appropriate places
def delete_collection(name):
    """Delete the collection called ``name`` from the persistent Chroma store.

    The store is opened at ``{WORKSPACE_ROOT}/werewolf_game/chroma``.

    Args:
        name: Name of the collection to delete.
    """
    chroma_path = f"{WORKSPACE_ROOT}/werewolf_game/chroma"
    chromadb.PersistentClient(path=chroma_path).delete_collection(name=name)
def add_file_batch(folder, **kwargs):
    """Feed every experience file found under ``folder`` to one action.

    A single ``AddNewExperiences`` action is built from ``kwargs`` and each
    file path is printed and then passed to its ``add_from_file`` method.
    The search is recursive and keeps regular files only, so sub-directories
    of ``folder`` are descended into rather than handed to ``add_from_file``
    as if they were files.

    Args:
        folder: Directory (``str`` or path-like) to scan.
        **kwargs: Forwarded unchanged to ``AddNewExperiences``.
    """
    action = AddNewExperiences(**kwargs)
    # "**" with recursive=True also matches zero directories, so files that
    # sit directly in ``folder`` are still found.
    pattern = os.path.join(str(folder), "**", "*")
    # Sorted for a reproducible insertion order; glob's own order is arbitrary.
    file_paths = sorted(
        fp for fp in glob.glob(pattern, recursive=True) if os.path.isfile(fp)
    )
    for fp in file_paths:
        print(fp)
        action.add_from_file(fp)
def modify_collection():
    """Rename the default collection to ``<name>_backup`` and print its tail.

    Opens the persistent Chroma store at
    ``{WORKSPACE_ROOT}/werewolf_game/chroma``, renames the collection called
    ``DEFAULT_COLLECTION_NAME`` by appending ``"_backup"``, checks whether
    the old name can still be fetched, and finally prints the last five
    documents of the renamed collection.
    """
    chroma_client = chromadb.PersistentClient(path=f"{WORKSPACE_ROOT}/werewolf_game/chroma")
    collection = chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
    updated_name = DEFAULT_COLLECTION_NAME + "_backup"
    collection.modify(name=updated_name)

    # Sanity check: after the rename the old name is expected to be gone.
    try:
        chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
    except Exception:
        # Not a bare ``except:`` so KeyboardInterrupt/SystemExit still propagate.
        # NOTE(review): the concrete "collection missing" exception type is
        # presumably chromadb-version dependent; narrow this once confirmed.
        logger.info(f"collection {DEFAULT_COLLECTION_NAME} not found")

    updated_collection = chroma_client.get_collection(name=updated_name)
    print(updated_collection.get()["documents"][-5:])
# if __name__ == "__main__":
# delete_collection(name="test")
# delete_collection(name="test")
# add_file_batch(WORKSPACE_ROOT / 'werewolf_game/experiences', collection_name=DEFAULT_COLLECTION_NAME, delete_existing=True)
# modify_collection()

View file

@ -16,6 +16,7 @@ class BasePlayer(Role):
use_reflection: bool = True,
use_experience: bool = False,
use_memory_selection: bool = False,
new_experience_version: str = "",
**kwargs,
):
super().__init__(name, profile, **kwargs)
@ -35,6 +36,7 @@ class BasePlayer(Role):
self.use_experience = False
else:
self.use_experience = use_experience
self.new_experience_version = new_experience_version if self.use_experience else ""
self.use_memory_selection = use_memory_selection
self.experiences = []
@ -76,8 +78,9 @@ class BasePlayer(Role):
profile=self.profile, name=self.name, context=memories, latest_instruction=latest_instruction
) if self.use_reflection else ""
experiences = RetrieveExperiences().run(query=reflection, profile=self.profile) \
if self.use_experience else ""
experiences = RetrieveExperiences().run(
query=reflection, profile=self.profile, excluded_version=self.new_experience_version
) if self.use_experience else ""
# 根据自己定义的角色Action对应地去runrun的入参可能不同
if isinstance(todo, Speak):
@ -99,7 +102,7 @@ class BasePlayer(Role):
self.experiences.append(
RoleExperience(name=self.name, profile=self.profile, reflection=reflection,
instruction=latest_instruction, response=rsp)
instruction=latest_instruction, response=rsp, version=self.new_experience_version)
)
logger.info(f"{self._setting}: {rsp}")
@ -121,7 +124,7 @@ class BasePlayer(Role):
self.status = new_status
def record_experiences(self, round_id: str, outcome: str, game_setup: str):
experiences = [exp for exp in self.experiences if exp.reflection]
experiences = [exp for exp in self.experiences if len(exp.reflection) > 2] # not "" or not '""'
for exp in experiences:
exp.round_id = round_id
exp.outcome = outcome

View file

@ -166,7 +166,7 @@ class Moderator(Role):
if not voted:
continue
voted_all.append(voted.group(0))
self.player_current_dead = [Counter(voted_all).most_common()[0][0]] # 平票时,杀序号小
self.player_current_dead = [Counter(voted_all).most_common()[0][0]] # 平票时,杀最先被投
# print("*" * 10, "dead", self.player_current_dead)
self.living_players = [p for p in self.living_players if p not in self.player_current_dead]
self.update_player_status(self.player_current_dead)

View file

@ -10,4 +10,4 @@ class RoleExperience(BaseModel):
outcome: str = ""
round_id: str = ""
game_setup: str = ""
version: str = "01-10"
version: str = ""

View file

@ -10,7 +10,8 @@ from examples.werewolf_game.roles.human_player import prepare_human_player
def init_game_setup(
shuffle=True, add_human=False,
use_reflection=True, use_experience=False, use_memory_selection=False
use_reflection=True, use_experience=False, use_memory_selection=False,
new_experience_version="",
):
roles = [
Villager,
@ -32,7 +33,8 @@ def init_game_setup(
players = [
role(
name=f"Player{i+1}",
use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection
use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection,
new_experience_version=new_experience_version
) for i, role in enumerate(roles)
]
@ -46,11 +48,14 @@ def init_game_setup(
async def start_game(
investment: float = 3.0, n_round: int = 5, shuffle : bool = True, add_human: bool = False,
use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False
use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False,
new_experience_version: str = "",
):
game = WerewolfGame()
game_setup, players = init_game_setup(shuffle=shuffle, add_human=add_human,
use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection)
game_setup, players = init_game_setup(
shuffle=shuffle, add_human=add_human, use_reflection=use_reflection, use_experience=use_experience,
use_memory_selection=use_memory_selection, new_experience_version=new_experience_version,
)
players = [Moderator()] + players
game.hire(players)
game.invest(investment)
@ -58,15 +63,11 @@ async def start_game(
await game.run(n_round=n_round)
def main(investment: float = 20.0, n_round: int = 100, shuffle: bool = True, add_human: bool = False,
         use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False,
         new_experience_version: str = ""):
    """
    Synchronous entry point: run ``start_game`` to completion with ``asyncio.run``.

    :param investment: amount handed to ``start_game``, which invests it in the game
    :param n_round: maximum number of rounds the game is run for
    :param shuffle: forwarded to the game setup
    :param add_human: forwarded to the game setup
    :param use_reflection: forwarded to every player
    :param use_experience: forwarded to every player
    :param use_memory_selection: forwarded to every player
    :param new_experience_version: version label forwarded to every player
    :return: None
    """
    # Forward by keyword so the eight arguments cannot silently drift out of
    # step with the parameter order of start_game.
    asyncio.run(
        start_game(
            investment=investment,
            n_round=n_round,
            shuffle=shuffle,
            add_human=add_human,
            use_reflection=use_reflection,
            use_experience=use_experience,
            use_memory_selection=use_memory_selection,
            new_experience_version=new_experience_version,
        )
    )
if __name__ == '__main__':

View file

@ -11,20 +11,25 @@ from examples.werewolf_game.actions.experience_operation import AddNewExperience
class TestExperiencesOperation:
collection_name = "test"
test_round_id = "test_01"
version = "test"
samples_to_add = [
RoleExperience(profile="Witch", reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. Player4's behavior is suspicious.", response="", outcome="", round_id=test_round_id),
RoleExperience(profile="Witch", reflection="The game is in a critical state with only three players left, and I need to make a wise decision to save Player7 or not.", response="", outcome="", round_id=test_round_id),
RoleExperience(profile="Seer", reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, sided with him. I, as the real Seer, am under suspicion.", response="", outcome="", round_id=test_round_id),
RoleExperience(profile="Witch", reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. Player4's behavior is suspicious.", response="", outcome="", round_id=test_round_id, version=version),
RoleExperience(profile="Witch", reflection="The game is in a critical state with only three players left, and I need to make a wise decision to save Player7 or not.", response="", outcome="", round_id=test_round_id, version=version),
RoleExperience(profile="Seer", reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, sided with him. I, as the real Seer, am under suspicion.", response="", outcome="", round_id=test_round_id, version=version),
RoleExperience(profile="TestRole", reflection="Some test reflection1", response="", outcome="", round_id=test_round_id, version=version+"_01-10"),
RoleExperience(profile="TestRole", reflection="Some test reflection2", response="", outcome="", round_id=test_round_id, version=version+"_11-20"),
RoleExperience(profile="TestRole", reflection="Some test reflection3", response="", outcome="", round_id=test_round_id, version=version+"_21-30"),
]
@pytest.mark.asyncio
async def test_add(self):
saved_file = f"{WORKSPACE_ROOT}/werewolf_game/experiences/{self.test_round_id}.json"
saved_file = f"{WORKSPACE_ROOT}/werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
if os.path.exists(saved_file):
os.remove(saved_file)
action = AddNewExperiences(collection_name="test", delete_existing=True)
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
action.run(self.samples_to_add)
# test insertion
@ -33,32 +38,55 @@ class TestExperiencesOperation:
# test if we record the samples correctly to local file
# & test if we could recover a embedding db from the file
action = AddNewExperiences(collection_name="test", delete_existing=True)
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
action.add_from_file(saved_file)
inserted = action.collection.get()
assert len(inserted["documents"]) == len(self.samples_to_add)
@pytest.mark.asyncio
async def test_retrieve(self):
    """Retrieving as ``Witch`` returns two experiences, best match first."""
    # Use the class-level collection name rather than a hard-coded "test".
    action = RetrieveExperiences(collection_name=self.collection_name)

    query = "one player claimed to be Seer and the other Witch"
    results = json.loads(action.run(query, profile="Witch"))

    assert len(results) == 2, "Witch should have 2 experiences"
    assert "The game is intense with two players" in results[0]
@pytest.mark.asyncio
async def test_retrieve_filtering(self):
    """Check the result count with and without an excluded version."""
    action = RetrieveExperiences(collection_name=self.collection_name)
    query = "some test query"
    profile = "TestRole"

    # (excluded_version, expected number of retrieved experiences)
    cases = (
        ("", 3),
        (self.version + "_21-30", 2),
    )
    for excluded_version, expected_count in cases:
        raw = action.run(query, profile=profile, excluded_version=excluded_version)
        results = json.loads(raw)
        assert len(results) == expected_count
class TestActualRetrieve:
collection_name = "role_reflection"
@pytest.mark.asyncio
async def test_check_experience_pool(self):
    """Log the size of the experience pool and print its last five metadata records."""
    logger.info("check experience pool")
    action = RetrieveExperiences(collection_name=self.collection_name)
    all_experiences = action.collection.get()
    logger.info(f"{len(all_experiences['metadatas'])=}")
    # Index into the fetched result: the bare list literal ["metadatas"]
    # would only print the string "metadatas".
    print(*all_experiences["metadatas"][-5:], sep="\n")
@pytest.mark.asyncio
async def test_retrieve_werewolf_experience(self):
action = RetrieveExperiences(collection_name="role_reflection")
action = RetrieveExperiences(collection_name=self.collection_name)
query = "there are conflicts"
@ -68,9 +96,27 @@ class TestExperiencesOperation:
@pytest.mark.asyncio
async def test_retrieve_villager_experience(self):
    """Retrieving as ``Seer`` with a conflict query yields conflict-related text."""
    # Use the class-level collection name rather than a hard-coded one.
    action = RetrieveExperiences(collection_name=self.collection_name)

    query = "there are conflicts"
    logger.info(f"test retrieval with {query=}")

    results = action.run(query, profile="Seer")
    # Similar situations should contain the keyword "conflict".
    assert "conflict" in results
@pytest.mark.asyncio
async def test_retrieve_villager_experience_filtering(self):
    """Excluding different versions must change what is retrieved for ``Seer``."""
    action = RetrieveExperiences(collection_name=self.collection_name)
    query = "there are conflicts"

    # Retrieval output keyed by the version that was excluded.
    retrieved = {}
    for excluded_version in ("01-10", "11-20"):
        logger.info(f"test retrieval with {excluded_version=}")
        retrieved[excluded_version] = action.run(
            query, profile="Seer", excluded_version=excluded_version, verbose=True
        )

    assert retrieved["01-10"] != retrieved["11-20"]