update werewolf_ext_env to add role permission protection and remove useless fields

This commit is contained in:
better629 2024-02-01 16:15:39 +08:00
parent c0a4d7c4c9
commit 3a4acb8f22
2 changed files with 162 additions and 66 deletions

View file

@ -3,7 +3,6 @@
# @Desc : The werewolf game external environment to integrate with
import random
import re
from collections import Counter
from enum import Enum
from typing import Callable, Optional
@ -101,17 +100,17 @@ STEP_INSTRUCTIONS = {
class WerewolfExtEnv(ExtEnv):
model_config = ConfigDict(arbitrary_types_allowed=True)
roles_state: dict[str, RoleState] = Field(default=dict(), description="the role's current state by role_name")
players_state: dict[str, tuple[str, RoleState]] = Field(
default=dict(), description="the player's role type and state by player_name"
)
round_idx: int = Field(default=0) # the current round
step_idx: int = Field(default=0) # the current step of current round
eval_step_idx: int = Field(default=0)
per_round_steps: int = Field(default=len(STEP_INSTRUCTIONS))
# game global states
game_setup: str = Field(default="", description="game setup including role and its num")
living_players: list[str] = Field(default=[])
werewolf_players: list[str] = Field(default=[])
villager_players: list[str] = Field(default=[])
special_role_players: list[str] = Field(default=[])
winner: Optional[str] = Field(default=None)
win_reason: Optional[str] = Field(default=None)
@ -119,27 +118,50 @@ class WerewolfExtEnv(ExtEnv):
witch_antidote_left: int = Field(default=1)
# game current round states, a round is from closing your eyes to the next time you close your eyes
round_hunts: dict[str, str] = Field(default=dict(), description="nighttime wolf hunt result")
round_votes: dict[str, str] = Field(
default=dict(), description="daytime all players vote result, key=voteer, value=voted one"
)
player_hunted: Optional[str] = Field(default=None)
player_protected: Optional[str] = Field(default=None)
is_hunted_player_saved: bool = Field(default=False)
player_poisoned: Optional[str] = Field(default=None)
player_current_dead: list[str] = Field(default=[])
def parse_game_setup(self, game_setup: str):
self.game_setup = game_setup
self.living_players = re.findall(r"Player[0-9]+", game_setup)
self.werewolf_players = re.findall(r"Player[0-9]+: Werewolf", game_setup)
self.werewolf_players = [p.replace(": Werewolf", "") for p in self.werewolf_players]
self.villager_players = re.findall(r"Player[0-9]+: Villager", game_setup)
self.villager_players = [p.replace(": Villager", "") for p in self.villager_players]
@property
def living_players(self) -> list[str]:
player_names = []
for name, roletype_state in self.players_state.items():
if roletype_state[1] in [RoleState.ALIVE, RoleState.SAVED]:
player_names.append(name)
return player_names
def _role_type_players(self, role_type: str) -> list[str]:
"""return player name of particular role type"""
player_names = []
for name, roletype_state in self.players_state.items():
if role_type in roletype_state[0]:
player_names.append(name)
return player_names
@property
def werewolf_players(self) -> list[str]:
player_names = self._role_type_players(role_type="Werewolf")
return player_names
@property
def villager_players(self) -> list[str]:
player_names = self._role_type_players(role_type="Villager")
return player_names
def _init_players_state(self, players: list["Role"]):
for play in players:
self.players_state[play.name] = (play.profile, RoleState.ALIVE)
self.special_role_players = [
p for p in self.living_players if p not in self.werewolf_players + self.villager_players
]
# init role state
self.roles_state = {player_name: RoleState.ALIVE for player_name in self.living_players}
@mark_as_readable
def init_game_setup(
self,
role_uniq_objs: list[object],
@ -153,6 +175,7 @@ class WerewolfExtEnv(ExtEnv):
new_experience_version="",
prepare_human_player=Callable,
) -> tuple[str, list]:
"""init players using different roles' num"""
role_objs = []
for role_obj in role_uniq_objs:
if str(role_obj) == "Villager":
@ -183,9 +206,30 @@ class WerewolfExtEnv(ExtEnv):
logger.info(f"You are assigned {players[assigned_role_idx].name}({players[assigned_role_idx].profile})")
game_setup = ["Game setup:"] + [f"{player.name}: {player.profile}," for player in players]
game_setup = "\n".join(game_setup)
self.game_setup = "\n".join(game_setup)
return game_setup, players
self._init_players_state(players) # init players state
return self.game_setup, players
def _update_players_state(self, player_names: list[str], state: RoleState = RoleState.KILLED):
for player_name in player_names:
if player_name in self.players_state:
roletype_state = self.players_state[player_name]
self.players_state[player_name] = (roletype_state[0], state)
def _check_valid_role(self, player: "Role", role_type: str) -> bool:
return True if role_type in str(player) else False
def _check_player_continue(self, player_name: str, particular_step: int = -1) -> bool:
step_idx = self.step_idx % self.per_round_steps
if particular_step > 0 and step_idx != particular_step: # step no
# particular_step = 18, not daytime vote time, ignore
# particular_step = 15, not nighttime hunt time, ignore
return False
if player_name not in self.living_players:
return False
return True
@mark_as_readable
def curr_step_instruction(self) -> dict:
@ -194,32 +238,64 @@ class WerewolfExtEnv(ExtEnv):
self.step_idx += 1
return instruction
@mark_as_writeable
def update_players_state(self, player_names: list[str], state: RoleState = RoleState.KILLED):
for player_name in player_names:
if player_name in self.roles_state:
self.roles_state[player_name] = state
@mark_as_readable
def get_players_status(self, player_names: list[str]) -> dict[str, RoleState]:
roles_state = {
player_name: self.roles_state[player_name]
def get_players_state(self, player_names: list[str]) -> dict[str, RoleState]:
players_state = {
player_name: self.players_state[player_name][1] # only return role state
for player_name in player_names
if player_name in self.roles_state
if player_name in self.players_state
}
return roles_state
return players_state
@mark_as_writeable
def wolf_kill_someone(self, player_name: str):
self.update_players_state([player_name], RoleState.KILLED)
def vote_kill_someone(self, voteer: "Role", player_name: str = None):
"""player vote result at daytime
player_name: if it's None, regard as abstaining from voting
"""
if not self._check_player_continue(voteer.name, particular_step=18): # 18=step no
return
self.round_votes[voteer.name] = player_name
# check if all living players finish voting, then get the dead one
if list(self.round_votes.keys()) == self.living_players:
voted_all = list(self.round_votes.values()) # TODO in case of tie vote, check who was voted first
voted_all = [item for item in voted_all if item]
self.player_current_dead = Counter(voted_all).most_common()[0][0]
self._update_players_state([self.player_current_dead])
@mark_as_writeable
def witch_poison_someone(self, player_name: str = None):
self.update_players_state([player_name], RoleState.POISONED)
def wolf_kill_someone(self, wolf: "Role", player_name: str):
if not self._check_valid_role(wolf, "Werewolf"):
return
if not self._check_player_continue(wolf.name, particular_step=5): # 5=step no
return
self.round_hunts[wolf.name] = player_name
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
# check if all living wolfs finish hunting, then get the hunted one
if list(self.round_hunts.keys()) == living_werewolf:
hunted_all = list(self.round_hunts.values())
self.player_hunted = Counter(hunted_all).most_common()[0][0]
@mark_as_writeable
def witch_save_someone(self, player_name: str = None):
self.update_players_state([player_name], RoleState.SAVED)
def witch_poison_someone(self, witch: "Role", player_name: str = None):
if not self._check_valid_role(witch, "Witch"):
return
if not self._check_player_continue(player_name):
return
self._update_players_state([player_name], RoleState.POISONED)
self.player_poisoned = player_name
@mark_as_writeable
def witch_save_someone(self, witch: "Role", player_name: str = None):
if not self._check_valid_role(witch, "Witch"):
return
if not self._check_player_continue(player_name):
return
self._update_players_state([player_name], RoleState.SAVED)
self.player_protected = player_name
@mark_as_writeable
def update_game_states(self, memories: list):
@ -238,28 +314,13 @@ class WerewolfExtEnv(ExtEnv):
if self.player_poisoned:
self.player_current_dead.append(self.player_poisoned)
self.living_players = [p for p in self.living_players if p not in self.player_current_dead]
self.update_player_status(self.player_current_dead)
self._update_players_state([self.player_current_dead])
# reset
self.player_hunted = None
self.player_protected = None
self.is_hunted_player_saved = False
self.player_poisoned = None
elif step_idx == 18: # step no
# day ends: after all roles voted, process all votings
voting_msgs = memories[-len(self.living_players) :]
voted_all = []
for msg in voting_msgs:
voted = re.search(r"Player[0-9]+", msg.content[-10:])
if not voted:
continue
voted_all.append(voted.group(0))
self.player_current_dead = [Counter(voted_all).most_common()[0][0]] # 平票时,杀最先被投的
# print("*" * 10, "dead", self.player_current_dead)
self.living_players = [p for p in self.living_players if p not in self.player_current_dead]
self.update_player_status(self.player_current_dead)
# game's termination condition
living_werewolf = [p for p in self.werewolf_players if p in self.living_players]
living_villagers = [p for p in self.villager_players if p in self.living_players]

View file

@ -2,28 +2,63 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of WerewolfExtEnv
from metagpt.environment.werewolf_env.werewolf_ext_env import RoleState, WerewolfExtEnv
from metagpt.roles.role import Role
class Werewolf(Role):
profile: str = "Werewolf"
class Villager(Role):
profile: str = "Villager"
class Witch(Role):
profile: str = "Witch"
class Guard(Role):
profile: str = "Guard"
def test_werewolf_ext_env():
ext_env = WerewolfExtEnv()
players_state = {
"Player0": ("Werewolf", RoleState.ALIVE),
"Player1": ("Werewolf", RoleState.ALIVE),
"Player2": ("Villager", RoleState.ALIVE),
"Player3": ("Witch", RoleState.ALIVE),
"Player4": ("Guard", RoleState.ALIVE),
}
ext_env = WerewolfExtEnv(players_state=players_state, step_idx=4, special_role_players=["Player3", "Player4"])
game_setup = """Game setup:
Player0: Werewolf,
Player1: Werewolf,
Player2: Villager,
Player3: Guard,
"""
ext_env.parse_game_setup(game_setup)
assert len(ext_env.living_players) == 4
assert len(ext_env.special_role_players) == 1
assert len(ext_env.living_players) == 5
assert len(ext_env.special_role_players) == 2
assert len(ext_env.werewolf_players) == 2
curr_instr = ext_env.curr_step_instruction()
assert ext_env.step_idx == 1
assert "close your eyes" in curr_instr["content"]
assert ext_env.step_idx == 5
assert "Werewolves, please open your eyes" in curr_instr["content"]
# current step_idx = 5
ext_env.wolf_kill_someone(wolf=Role(name="Player10"), player_name="Player4")
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player0"), player_name="Player4")
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player1"), player_name="Player4")
assert ext_env.player_hunted == "Player4"
assert len(ext_env.living_players) == 5 # hunted but can be saved by witch
for idx in range(13):
_ = ext_env.curr_step_instruction()
# current step_idx = 18
assert ext_env.step_idx == 18
ext_env.vote_kill_someone(voteer=Werewolf(name="Player0"), player_name="Player2")
ext_env.vote_kill_someone(voteer=Werewolf(name="Player1"), player_name="Player3")
ext_env.vote_kill_someone(voteer=Villager(name="Player2"), player_name="Player3")
ext_env.vote_kill_someone(voteer=Witch(name="Player3"), player_name="Player4")
ext_env.vote_kill_someone(voteer=Guard(name="Player4"), player_name="Player2")
assert ext_env.player_current_dead == "Player2"
assert len(ext_env.living_players) == 4
player_names = ["Player0", "Player2"]
ext_env.update_players_state(player_names, RoleState.KILLED)
assert ext_env.get_players_status(player_names) == dict(zip(player_names, [RoleState.KILLED] * 2))
assert ext_env.get_players_state(player_names) == dict(zip(player_names, [RoleState.ALIVE, RoleState.KILLED]))