mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 11:26:23 +02:00
small bug fixed for reflection record and retrieve
This commit is contained in:
parent
205d1c9843
commit
8525ec6d7b
7 changed files with 131 additions and 44 deletions
|
|
@ -8,7 +8,7 @@ class Speak(Action):
|
|||
|
||||
PROMPT_TEMPLATE = """
|
||||
{
|
||||
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__."
|
||||
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
|
||||
,"HISTORY": "You have knowledge to the following conversation: __context__"
|
||||
,"ATTENTION": "You can NOT VOTE a player who is NOT ALIVE now!"
|
||||
,"REFLECTION": "__reflection__"
|
||||
|
|
@ -95,7 +95,7 @@ class NighttimeWhispers(Action):
|
|||
|
||||
PROMPT_TEMPLATE = """
|
||||
{
|
||||
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__."
|
||||
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
|
||||
,"HISTORY": "You have knowledge to the following conversation: __context__"
|
||||
,"ACTION": "Choose one living player to __action__."
|
||||
,"ATTENTION": "1. You can only __action__ a player who is alive this night! And you can not __action__ a player who is dead this night! 2. `HISTORY` is all the information you observed, DONT hallucinate other player actions!"
|
||||
|
|
@ -172,7 +172,7 @@ class Reflect(Action):
|
|||
|
||||
PROMPT_TEMPLATE = """
|
||||
{
|
||||
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__."
|
||||
"BACKGROUND": "It's a Werewolf game, in this game, we have 2 werewolves, 2 villagers, 1 guard, 1 witch, 1 seer. You are __profile__. Note that villager, seer, guard and witch are all in villager side, they have the same objective. Werewolves can collectively hunt ONE player at night."
|
||||
,"HISTORY": "You have knowledge to the following conversation: __context__"
|
||||
,"MODERATOR_INSTRUCTION": __latest_instruction__,
|
||||
,"OUTPUT_FORMAT" (a json):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
import glob
|
||||
|
||||
import chromadb
|
||||
from chromadb.utils import embedding_functions
|
||||
|
|
@ -63,6 +64,7 @@ class AddNewExperiences(Action):
|
|||
with open(file_path, "r") as fl:
|
||||
lines = fl.readlines()
|
||||
experiences = [RoleExperience(**json.loads(line)) for line in lines]
|
||||
experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""'
|
||||
|
||||
ids = [exp.id for exp in experiences]
|
||||
documents = [exp.reflection for exp in experiences]
|
||||
|
|
@ -77,13 +79,16 @@ class AddNewExperiences(Action):
|
|||
@staticmethod
|
||||
def _record_experiences_local(experiences: list[RoleExperience]):
|
||||
round_id = experiences[0].round_id
|
||||
version = experiences[0].version
|
||||
version = "test" if not version else version
|
||||
experiences = [exp.json() for exp in experiences]
|
||||
experience_folder = WORKSPACE_ROOT / 'werewolf_game/experiences'
|
||||
experience_folder = WORKSPACE_ROOT / f'werewolf_game/experiences/{version}'
|
||||
if not os.path.exists(experience_folder):
|
||||
os.makedirs(experience_folder)
|
||||
save_path = f"{experience_folder}/{round_id}.json"
|
||||
with open(save_path, "a") as fl:
|
||||
fl.write("\n".join(experiences))
|
||||
fl.write("\n")
|
||||
logger.info(f"experiences saved to {save_path}")
|
||||
|
||||
class RetrieveExperiences(Action):
|
||||
|
|
@ -102,7 +107,7 @@ class RetrieveExperiences(Action):
|
|||
logger.warning(f"No experience pool {collection_name}")
|
||||
self.has_experiences = False
|
||||
|
||||
def run(self, query: str, profile: str, topk: int = 5) -> str:
|
||||
def run(self, query: str, profile: str, topk: int = 5, excluded_version: str = "", verbose: bool = False) -> str:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
|
|
@ -113,20 +118,30 @@ class RetrieveExperiences(Action):
|
|||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
if not self.has_experiences:
|
||||
if not self.has_experiences or len(query) <= 2: # not "" or not '""'
|
||||
return ""
|
||||
|
||||
filters = {"profile": profile}
|
||||
### 消融实验逻辑 ###
|
||||
if profile == "Werewolf": # 狼人作为基线,不用经验
|
||||
logger.warning("Disable werewolves' experiences")
|
||||
return ""
|
||||
if excluded_version:
|
||||
filters = {"$and": [{"profile": profile}, {"version": {"$ne": excluded_version}}]} # 不用同一版本的经验,只用之前的
|
||||
#################
|
||||
|
||||
results = self.collection.query(
|
||||
query_texts=[query],
|
||||
n_results=topk,
|
||||
where={"profile": profile},
|
||||
where=filters,
|
||||
)
|
||||
|
||||
logger.info("retrieved exp")
|
||||
past_experiences = [RoleExperience(**res) for res in results["metadatas"][0]]
|
||||
# print(*past_experiences, sep="\n\n")
|
||||
distances = results["distances"][0]
|
||||
print(distances)
|
||||
if verbose:
|
||||
print(*past_experiences, sep="\n\n")
|
||||
distances = results["distances"][0]
|
||||
print(distances)
|
||||
|
||||
template = """
|
||||
{
|
||||
|
|
@ -148,9 +163,31 @@ class RetrieveExperiences(Action):
|
|||
|
||||
return json.dumps(past_experiences)
|
||||
|
||||
# FIXME: below are some utility functions, should be moved to appropriate places
|
||||
def delete_collection(name):
|
||||
chroma_client = chromadb.PersistentClient(path=f"{WORKSPACE_ROOT}/werewolf_game/chroma")
|
||||
chroma_client.delete_collection(name=name)
|
||||
|
||||
def add_file_batch(folder, **kwargs):
|
||||
action = AddNewExperiences(**kwargs)
|
||||
file_paths = glob.glob(str(folder) + "/*")
|
||||
for fp in file_paths:
|
||||
print(fp)
|
||||
action.add_from_file(fp)
|
||||
|
||||
def modify_collection():
|
||||
chroma_client = chromadb.PersistentClient(path=f"{WORKSPACE_ROOT}/werewolf_game/chroma")
|
||||
collection = chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
|
||||
updated_name = DEFAULT_COLLECTION_NAME + "_backup"
|
||||
collection.modify(name=updated_name)
|
||||
try:
|
||||
chroma_client.get_collection(name=DEFAULT_COLLECTION_NAME)
|
||||
except:
|
||||
logger.info(f"collection {DEFAULT_COLLECTION_NAME} not found")
|
||||
updated_collection = chroma_client.get_collection(name=updated_name)
|
||||
print(updated_collection.get()["documents"][-5:])
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# delete_collection(name="test")
|
||||
# delete_collection(name="test")
|
||||
# add_file_batch(WORKSPACE_ROOT / 'werewolf_game/experiences', collection_name=DEFAULT_COLLECTION_NAME, delete_existing=True)
|
||||
# modify_collection()
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class BasePlayer(Role):
|
|||
use_reflection: bool = True,
|
||||
use_experience: bool = False,
|
||||
use_memory_selection: bool = False,
|
||||
new_experience_version: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(name, profile, **kwargs)
|
||||
|
|
@ -35,6 +36,7 @@ class BasePlayer(Role):
|
|||
self.use_experience = False
|
||||
else:
|
||||
self.use_experience = use_experience
|
||||
self.new_experience_version = new_experience_version if self.use_experience else ""
|
||||
self.use_memory_selection = use_memory_selection
|
||||
|
||||
self.experiences = []
|
||||
|
|
@ -76,8 +78,9 @@ class BasePlayer(Role):
|
|||
profile=self.profile, name=self.name, context=memories, latest_instruction=latest_instruction
|
||||
) if self.use_reflection else ""
|
||||
|
||||
experiences = RetrieveExperiences().run(query=reflection, profile=self.profile) \
|
||||
if self.use_experience else ""
|
||||
experiences = RetrieveExperiences().run(
|
||||
query=reflection, profile=self.profile, excluded_version=self.new_experience_version
|
||||
) if self.use_experience else ""
|
||||
|
||||
# 根据自己定义的角色Action,对应地去run,run的入参可能不同
|
||||
if isinstance(todo, Speak):
|
||||
|
|
@ -99,7 +102,7 @@ class BasePlayer(Role):
|
|||
|
||||
self.experiences.append(
|
||||
RoleExperience(name=self.name, profile=self.profile, reflection=reflection,
|
||||
instruction=latest_instruction, response=rsp)
|
||||
instruction=latest_instruction, response=rsp, version=self.new_experience_version)
|
||||
)
|
||||
|
||||
logger.info(f"{self._setting}: {rsp}")
|
||||
|
|
@ -121,7 +124,7 @@ class BasePlayer(Role):
|
|||
self.status = new_status
|
||||
|
||||
def record_experiences(self, round_id: str, outcome: str, game_setup: str):
|
||||
experiences = [exp for exp in self.experiences if exp.reflection]
|
||||
experiences = [exp for exp in self.experiences if len(exp.reflection) > 2] # not "" or not '""'
|
||||
for exp in experiences:
|
||||
exp.round_id = round_id
|
||||
exp.outcome = outcome
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class Moderator(Role):
|
|||
if not voted:
|
||||
continue
|
||||
voted_all.append(voted.group(0))
|
||||
self.player_current_dead = [Counter(voted_all).most_common()[0][0]] # 平票时,杀序号小的
|
||||
self.player_current_dead = [Counter(voted_all).most_common()[0][0]] # 平票时,杀最先被投的
|
||||
# print("*" * 10, "dead", self.player_current_dead)
|
||||
self.living_players = [p for p in self.living_players if p not in self.player_current_dead]
|
||||
self.update_player_status(self.player_current_dead)
|
||||
|
|
|
|||
|
|
@ -10,4 +10,4 @@ class RoleExperience(BaseModel):
|
|||
outcome: str = ""
|
||||
round_id: str = ""
|
||||
game_setup: str = ""
|
||||
version: str = "01-10"
|
||||
version: str = ""
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from examples.werewolf_game.roles.human_player import prepare_human_player
|
|||
|
||||
def init_game_setup(
|
||||
shuffle=True, add_human=False,
|
||||
use_reflection=True, use_experience=False, use_memory_selection=False
|
||||
use_reflection=True, use_experience=False, use_memory_selection=False,
|
||||
new_experience_version="",
|
||||
):
|
||||
roles = [
|
||||
Villager,
|
||||
|
|
@ -32,7 +33,8 @@ def init_game_setup(
|
|||
players = [
|
||||
role(
|
||||
name=f"Player{i+1}",
|
||||
use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection
|
||||
use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection,
|
||||
new_experience_version=new_experience_version
|
||||
) for i, role in enumerate(roles)
|
||||
]
|
||||
|
||||
|
|
@ -46,11 +48,14 @@ def init_game_setup(
|
|||
|
||||
async def start_game(
|
||||
investment: float = 3.0, n_round: int = 5, shuffle : bool = True, add_human: bool = False,
|
||||
use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False
|
||||
use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False,
|
||||
new_experience_version: str = "",
|
||||
):
|
||||
game = WerewolfGame()
|
||||
game_setup, players = init_game_setup(shuffle=shuffle, add_human=add_human,
|
||||
use_reflection=use_reflection, use_experience=use_experience, use_memory_selection=use_memory_selection)
|
||||
game_setup, players = init_game_setup(
|
||||
shuffle=shuffle, add_human=add_human, use_reflection=use_reflection, use_experience=use_experience,
|
||||
use_memory_selection=use_memory_selection, new_experience_version=new_experience_version,
|
||||
)
|
||||
players = [Moderator()] + players
|
||||
game.hire(players)
|
||||
game.invest(investment)
|
||||
|
|
@ -58,15 +63,11 @@ async def start_game(
|
|||
await game.run(n_round=n_round)
|
||||
|
||||
def main(investment: float = 20.0, n_round: int = 100, shuffle : bool = True, add_human: bool = False,
|
||||
use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False):
|
||||
"""
|
||||
:param investment: contribute a certain dollar amount to watch the debate
|
||||
:param n_round: maximum rounds of the debate
|
||||
:return:
|
||||
"""
|
||||
asyncio.run(
|
||||
start_game(investment, n_round, shuffle, add_human, use_reflection, use_experience, use_memory_selection)
|
||||
)
|
||||
use_reflection: bool = True, use_experience: bool = False, use_memory_selection: bool = False,
|
||||
new_experience_version: str = ""):
|
||||
|
||||
asyncio.run(start_game(investment, n_round, shuffle, add_human,
|
||||
use_reflection, use_experience, use_memory_selection, new_experience_version))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -11,20 +11,25 @@ from examples.werewolf_game.actions.experience_operation import AddNewExperience
|
|||
|
||||
class TestExperiencesOperation:
|
||||
|
||||
collection_name = "test"
|
||||
test_round_id = "test_01"
|
||||
version = "test"
|
||||
samples_to_add = [
|
||||
RoleExperience(profile="Witch", reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. Player4's behavior is suspicious.", response="", outcome="", round_id=test_round_id),
|
||||
RoleExperience(profile="Witch", reflection="The game is in a critical state with only three players left, and I need to make a wise decision to save Player7 or not.", response="", outcome="", round_id=test_round_id),
|
||||
RoleExperience(profile="Seer", reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, sided with him. I, as the real Seer, am under suspicion.", response="", outcome="", round_id=test_round_id),
|
||||
RoleExperience(profile="Witch", reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. Player4's behavior is suspicious.", response="", outcome="", round_id=test_round_id, version=version),
|
||||
RoleExperience(profile="Witch", reflection="The game is in a critical state with only three players left, and I need to make a wise decision to save Player7 or not.", response="", outcome="", round_id=test_round_id, version=version),
|
||||
RoleExperience(profile="Seer", reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, sided with him. I, as the real Seer, am under suspicion.", response="", outcome="", round_id=test_round_id, version=version),
|
||||
RoleExperience(profile="TestRole", reflection="Some test reflection1", response="", outcome="", round_id=test_round_id, version=version+"_01-10"),
|
||||
RoleExperience(profile="TestRole", reflection="Some test reflection2", response="", outcome="", round_id=test_round_id, version=version+"_11-20"),
|
||||
RoleExperience(profile="TestRole", reflection="Some test reflection3", response="", outcome="", round_id=test_round_id, version=version+"_21-30"),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add(self):
|
||||
saved_file = f"{WORKSPACE_ROOT}/werewolf_game/experiences/{self.test_round_id}.json"
|
||||
saved_file = f"{WORKSPACE_ROOT}/werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
|
||||
if os.path.exists(saved_file):
|
||||
os.remove(saved_file)
|
||||
|
||||
action = AddNewExperiences(collection_name="test", delete_existing=True)
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.run(self.samples_to_add)
|
||||
|
||||
# test insertion
|
||||
|
|
@ -33,32 +38,55 @@ class TestExperiencesOperation:
|
|||
|
||||
# test if we record the samples correctly to local file
|
||||
# & test if we could recover a embedding db from the file
|
||||
action = AddNewExperiences(collection_name="test", delete_existing=True)
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.add_from_file(saved_file)
|
||||
inserted = action.collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve(self):
|
||||
action = RetrieveExperiences(collection_name="test")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "one player claimed to be Seer and the other Witch"
|
||||
results = action.run(query, "Witch")
|
||||
results = action.run(query, profile="Witch")
|
||||
results = json.loads(results)
|
||||
|
||||
assert len(results) == 2
|
||||
assert len(results) == 2, "Witch should have 2 experiences"
|
||||
assert "The game is intense with two players" in results[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_filtering(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "some test query"
|
||||
profile = "TestRole"
|
||||
|
||||
excluded_version = ""
|
||||
results = action.run(query, profile=profile, excluded_version=excluded_version)
|
||||
results = json.loads(results)
|
||||
assert len(results) == 3
|
||||
|
||||
excluded_version = self.version + "_21-30"
|
||||
results = action.run(query, profile=profile, excluded_version=excluded_version)
|
||||
results = json.loads(results)
|
||||
assert len(results) == 2
|
||||
|
||||
class TestActualRetrieve:
|
||||
|
||||
collection_name = "role_reflection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_experience_pool(self):
|
||||
logger.info("check experience pool")
|
||||
action = RetrieveExperiences(collection_name="role_reflection")
|
||||
print(*action.collection.get()["metadatas"][-5:], sep="\n")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
all_experiences = action.collection.get()
|
||||
logger.info(f"{len(all_experiences['metadatas'])=}")
|
||||
print(*["metadatas"][-5:], sep="\n")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_werewolf_experience(self):
|
||||
|
||||
action = RetrieveExperiences(collection_name="role_reflection")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
|
|
@ -68,9 +96,27 @@ class TestExperiencesOperation:
|
|||
@pytest.mark.asyncio
|
||||
async def test_retrieve_villager_experience(self):
|
||||
|
||||
action = RetrieveExperiences(collection_name="role_reflection")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
logger.info(f"test retrieval with {query=}")
|
||||
results = action.run(query, "Seer")
|
||||
assert "conflict" in results # 相似局面应该需要包含conflict关键词
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_villager_experience_filtering(self):
|
||||
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
excluded_version = "01-10"
|
||||
logger.info(f"test retrieval with {excluded_version=}")
|
||||
results_01_10 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
|
||||
|
||||
excluded_version = "11-20"
|
||||
logger.info(f"test retrieval with {excluded_version=}")
|
||||
results_11_20 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
|
||||
|
||||
assert results_01_10 != results_11_20
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue