diff --git a/examples/werewolf_game/evals/eval.py b/examples/werewolf_game/evals/eval.py index 3093d80f2..c890773de 100644 --- a/examples/werewolf_game/evals/eval.py +++ b/examples/werewolf_game/evals/eval.py @@ -17,6 +17,7 @@ from tqdm import tqdm from utils import Utils from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT +from metagpt.environment.werewolf.const import RoleType class Vote: @@ -92,20 +93,20 @@ class Vote: # find all werewolves werewolves = [] for match in pattern.finditer(text): - if match.group(2) == "Werewolf": + if match.group(2) == RoleType.WEREWOLF.value: werewolves.append(match.group(1)) # find all non_werewolves non_werewolves = [] for match in pattern.finditer(text): - if match.group(2) != "Werewolf": + if match.group(2) != RoleType.WEREWOLF.value: non_werewolves.append(match.group(1)) num_non_werewolves = len(non_werewolves) # count players other than werewolves made the correct votes correct_votes = 0 for match in pattern.finditer(text): - if match.group(2) != "Werewolf" and match.group(3) in werewolves: + if match.group(2) != RoleType.WEREWOLF.value and match.group(3) in werewolves: correct_votes += 1 # cal the rateability of non_werewolves diff --git a/metagpt/environment/werewolf/werewolf_ext_env.py b/metagpt/environment/werewolf/werewolf_ext_env.py index 588fc0b9b..a8636536b 100644 --- a/metagpt/environment/werewolf/werewolf_ext_env.py +++ b/metagpt/environment/werewolf/werewolf_ext_env.py @@ -166,9 +166,9 @@ class WerewolfExtEnv(ExtEnv): """init players using different roles' num""" role_objs = [] for role_obj in role_uniq_objs: - if "Villager" in str(role_obj): + if RoleType.VILLAGER.value in str(role_obj): role_objs.extend([role_obj] * num_villager) - elif "Werewolf" in str(role_obj): + elif RoleType.WEREWOLF.value in str(role_obj): role_objs.extend([role_obj] * num_werewolf) else: role_objs.append(role_obj) diff --git a/tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py b/tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py index d515c3159..986d55e1a 100644 --- a/tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py +++ b/tests/metagpt/environment/werewolf_env/test_werewolf_ext_env.py @@ -2,34 +2,34 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of WerewolfExtEnv -from metagpt.environment.werewolf.const import RoleState +from metagpt.environment.werewolf.const import RoleState, RoleType from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv from metagpt.roles.role import Role class Werewolf(Role): - profile: str = "Werewolf" + profile: str = RoleType.WEREWOLF.value class Villager(Role): - profile: str = "Villager" + profile: str = RoleType.VILLAGER.value class Witch(Role): - profile: str = "Witch" + profile: str = RoleType.WITCH.value class Guard(Role): - profile: str = "Guard" + profile: str = RoleType.GUARD.value def test_werewolf_ext_env(): players_state = { - "Player0": ("Werewolf", RoleState.ALIVE), - "Player1": ("Werewolf", RoleState.ALIVE), - "Player2": ("Villager", RoleState.ALIVE), - "Player3": ("Witch", RoleState.ALIVE), - "Player4": ("Guard", RoleState.ALIVE), + "Player0": (RoleType.WEREWOLF.value, RoleState.ALIVE), + "Player1": (RoleType.WEREWOLF.value, RoleState.ALIVE), + "Player2": (RoleType.VILLAGER.value, RoleState.ALIVE), + "Player3": (RoleType.WITCH.value, RoleState.ALIVE), + "Player4": (RoleType.GUARD.value, RoleState.ALIVE), } ext_env = WerewolfExtEnv(players_state=players_state, step_idx=4, special_role_players=["Player3", "Player4"])