diff --git a/metagpt/environment/werewolf/const.py b/metagpt/environment/werewolf/const.py index 32af660ad..7f810389d 100644 --- a/metagpt/environment/werewolf/const.py +++ b/metagpt/environment/werewolf/const.py @@ -59,7 +59,7 @@ Or you can pass. For example: Protect ...""", }, 5: { "content": """Werewolves, I secretly tell you that {werewolf_players} are -all of the 2 werewolves! Keep in mind you are teammates. The rest players are not werewolves. +all of the {werewolf_num} werewolves! Keep in mind you are teammates. The rest players are not werewolves. choose one from the following living options please: {living_players}. For example: Kill ...""", "send_to": {RoleType.WEREWOLF.value}, diff --git a/metagpt/environment/werewolf/werewolf_ext_env.py b/metagpt/environment/werewolf/werewolf_ext_env.py index 835981481..588fc0b9b 100644 --- a/metagpt/environment/werewolf/werewolf_ext_env.py +++ b/metagpt/environment/werewolf/werewolf_ext_env.py @@ -52,9 +52,11 @@ class WerewolfExtEnv(ExtEnv): seed: Optional[int] = None, options: Optional[dict[str, Any]] = None, ) -> tuple[dict[str, Any], dict[str, Any]]: + """currently unused""" pass def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + """currently unused""" pass def _get_obs(self): @@ -62,7 +64,7 @@ class WerewolfExtEnv(ExtEnv): "game_setup": self.game_setup, "step_idx": self.step_idx, "living_players": self.living_players, - "werewolf_players": self.werewolf_players, + "werewolf_players": self.werewolf_players, # currently, lack observation isolation "player_hunted": self.player_hunted, "player_current_dead": self.player_current_dead, "witch_poison_left": self.witch_poison_left, diff --git a/metagpt/ext/werewolf/actions/common_actions.py b/metagpt/ext/werewolf/actions/common_actions.py index 0f1b3b74c..63afeede0 100644 --- a/metagpt/ext/werewolf/actions/common_actions.py +++ b/metagpt/ext/werewolf/actions/common_actions.py @@ -11,6 +11,14 @@ from metagpt.logs import logger from metagpt.utils.common import parse_json_code_block +def log_and_parse_json(name: str, rsp: str) -> dict: + rsp = rsp.replace("\n", " ") + logger.debug(f"{name} result: {rsp}") + json_blocks = parse_json_code_block(rsp) + rsp_json = json.loads(json_blocks[0]) + return rsp_json + + class Speak(Action): """Action: Any speak action in a game""" @@ -66,8 +74,7 @@ class Speak(Action): ) rsp = await self._aask(prompt) - rsp = rsp.replace("\n", " ") - rsp_json = json.loads(rsp) + rsp_json = log_and_parse_json(self.name, rsp) return rsp_json["RESPONSE"] @@ -183,8 +190,7 @@ class NighttimeWhispers(Action): ) rsp = await self._aask(prompt) - rsp = rsp.replace("\n", " ") - rsp_json = json.loads(rsp) + rsp_json = log_and_parse_json(self.name, rsp) return f"{self.name} " + rsp_json["RESPONSE"] @@ -229,9 +235,6 @@ class Reflect(Action): ) rsp = await self._aask(prompt) - rsp = rsp.replace("\n", " ") - logger.debug(f"{self.name} result: {rsp}") - json_blocks = parse_json_code_block(rsp) - rsp_json = json.loads(json_blocks[0]) + rsp_json = log_and_parse_json(self.name, rsp) return json.dumps(rsp_json["REFLECTION"]) diff --git a/metagpt/ext/werewolf/actions/moderator_actions.py b/metagpt/ext/werewolf/actions/moderator_actions.py index c61397892..8f37e3bc9 100644 --- a/metagpt/ext/werewolf/actions/moderator_actions.py +++ b/metagpt/ext/werewolf/actions/moderator_actions.py @@ -11,7 +11,8 @@ class InstructSpeak(Action): ) content = instruction_info["content"] if "{living_players}" in content and "{werewolf_players}" in content: - content = content.format(living_players=living_players, werewolf_players=werewolf_players) + content = content.format(living_players=living_players, werewolf_players=werewolf_players, + werewolf_num=len(werewolf_players)) if "{living_players}" in content: content = content.format(living_players=living_players) if "{werewolf_players}" in content: diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e443c3466..bd8d25013 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -722,7 +722,8 @@ def list_files(root: str | Path) -> List[Path]: def parse_json_code_block(markdown_text: str) -> List[str]: - json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) + json_blocks = re.findall(r"```json(.*?)```", markdown_text, re.DOTALL) if "```json" in markdown_text else [markdown_text] + return [v.strip() for v in json_blocks]