From 7b0323e76f11c39ebb027c373abc68f19dd62055 Mon Sep 17 00:00:00 2001 From: better629 Date: Sun, 8 Oct 2023 13:42:55 +0800 Subject: [PATCH] update plan actions --- .../st_game/actions/gen_action_details.py | 6 ++-- .../st_game/actions/gen_hourly_schedule.py | 33 ++++++++++++------ examples/st_game/actions/task_decomp.py | 34 ++++++++++++++++--- examples/st_game/plan/st_plan.py | 21 +++++++++++- 4 files changed, 76 insertions(+), 18 deletions(-) diff --git a/examples/st_game/actions/gen_action_details.py b/examples/st_game/actions/gen_action_details.py index a28116111..a2ebaacd6 100644 --- a/examples/st_game/actions/gen_action_details.py +++ b/examples/st_game/actions/gen_action_details.py @@ -358,9 +358,9 @@ class GenObjEventTriple(STAction): prompt_template = "generate_event_triple_v1.txt" prompt_input = create_prompt_input(act_game_object, act_obj_desp) prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) - self.fail_default_resp = self._func_fail_default_resp(role) + self.fail_default_resp = self._func_fail_default_resp(act_game_object) output = self._run_v1(prompt) - output = (role.name, output[0], output[1]) + output = (act_game_object, output[0], output[1]) return output @@ -404,7 +404,7 @@ class GenActionDetails(STAction): result_dict = { "action_address": new_address, "action_duration": int(act_dura), - "act_desp": act_desp, + "action_description": act_desp, "action_pronunciatio": act_pron, "action_event": act_event, "chatting_with": None, diff --git a/examples/st_game/actions/gen_hourly_schedule.py b/examples/st_game/actions/gen_hourly_schedule.py index 8b3dd358a..f4289d6ef 100644 --- a/examples/st_game/actions/gen_hourly_schedule.py +++ b/examples/st_game/actions/gen_hourly_schedule.py @@ -8,6 +8,7 @@ import string from metagpt.logs import logger from metagpt.schema import Message +from metagpt.config import CONFIG from .st_action import STAction @@ -39,17 +40,19 @@ class GenHourlySchedule(STAction): return False return True - def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: cr = llm_resp.strip() if cr[-1] == ".": cr = cr[:-1] + # to only use the first line of output + cr = cr.split("\n")[0] return cr def _func_fail_default_resp(self) -> int: fs = "asleep" return fs - - def _generate_schedule_for_given_hour(self, role: "STRole", + + def _generate_schedule_for_given_hour(self, role: "STRole", curr_hour_str, p_f_ds_hourly_org, hour_str, @@ -63,13 +66,13 @@ class GenHourlySchedule(STAction): for i in hour_str: schedule_format += f"[{persona.scratch.get_str_curr_date_str()} -- {i}]" schedule_format += f" Activity: [Fill in]\n" - schedule_format = schedule_format[:-1] + schedule_format = schedule_format[:-1] intermission_str = f"Here the originally intended hourly breakdown of" intermission_str += f" {persona.scratch.get_str_firstname()}'s schedule today: " for count, i in enumerate(persona.scratch.daily_req): intermission_str += f"{str(count+1)}) {i}, " - intermission_str = intermission_str[:-2] + intermission_str = intermission_str[:-2] prior_schedule = "" if p_f_ds_hourly_org: @@ -109,21 +112,30 @@ class GenHourlySchedule(STAction): p_f_ds_hourly_org, hour_str, intermission2) - logger.info(f"Role: {role.name} _generate_schedule_for_given_hour prompt_input: {prompt_input}") + prompt_input_str = "\n".join(prompt_input) + raw_max_tokens_rsp = CONFIG.max_tokens_rsp prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) self.fail_default_resp = self._func_fail_default_resp() + + CONFIG.max_tokens_rsp = 50 output = self._run_v1(prompt) + CONFIG.max_tokens_rsp = raw_max_tokens_rsp + + logger.info(f"max_tokens_rsp: {CONFIG.max_tokens_rsp}") + logger.info(f"Role: {role.name} _generate_schedule_for_given_hour prompt_input: {prompt_input_str}, " + f"output: {output}") return output - def run(self, role: "STRole", wake_up_hour: str): + def run(self, role: "STRole", wake_up_hour: int): hour_str = ["00:00 AM", "01:00 AM", "02:00 AM", "03:00 AM", "04:00 AM", "05:00 AM", "06:00 AM", "07:00 AM", "08:00 AM", "09:00 AM", "10:00 AM", "11:00 AM", "12:00 PM", "01:00 PM", "02:00 PM", "03:00 PM", "04:00 PM", "05:00 PM", "06:00 PM", "07:00 PM", "08:00 PM", "09:00 PM", "10:00 PM", "11:00 PM"] n_m1_activity = [] - diversity_repeat_count = 3 - for i in range(diversity_repeat_count): + diversity_repeat_count = 1 # TODO mg 1->3 + for i in range(diversity_repeat_count): + logger.info(f"diversity_repeat_count idx: {i}") n_m1_activity_set = set(n_m1_activity) if len(n_m1_activity_set) < 5: n_m1_activity = [] @@ -131,7 +143,8 @@ class GenHourlySchedule(STAction): if wake_up_hour > 0: n_m1_activity += ["sleeping"] wake_up_hour -= 1 - else: + else: + logger.info(f"_generate_schedule_for_given_hour idx: {count}, n_m1_activity: {n_m1_activity}") n_m1_activity += [self._generate_schedule_for_given_hour( role, curr_hour_str, n_m1_activity, hour_str)] diff --git a/examples/st_game/actions/task_decomp.py b/examples/st_game/actions/task_decomp.py index 3a17f98ac..3a4ffc810 100644 --- a/examples/st_game/actions/task_decomp.py +++ b/examples/st_game/actions/task_decomp.py @@ -72,17 +72,17 @@ class TaskDecomp(STAction): # TODO -- this sometimes generates error try: self._func_cleanup(llm_resp) - except: + except Exception as exp: return False return True def _func_fail_default_resp(self) -> int: - fs = ["asleep"] + fs = [["asleep", 0]] return fs def run(self, role: "STRole", - main_act_dur: int, + task_desc: int, truncated_act_dur: int, *args, **kwargs): @@ -144,10 +144,36 @@ class TaskDecomp(STAction): return prompt_input prompt_input = create_prompt_input(role, - main_act_dur, + task_desc, truncated_act_dur) prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "task_decomp_v3.txt") self.fail_default_resp = self._func_fail_default_resp() output = self._run_v1(prompt) + logger.info(f"Role: {role.name} {self.__class__.__name__} output: {output}") + + fin_output = [] + time_sum = 0 + for i_task, i_duration in output: + time_sum += i_duration + # HM????????? + # if time_sum < duration: + if time_sum <= truncated_act_dur: + fin_output += [[i_task, i_duration]] + else: + break + ftime_sum = 0 + for fi_task, fi_duration in fin_output: + ftime_sum += fi_duration + + # print ("for debugging... line 365", fin_output) + fin_output[-1][1] += (truncated_act_dur - ftime_sum) + output = fin_output + + task_decomp = output + ret = [] + for decomp_task, duration in task_decomp: + ret += [[f"{task_desc} ({decomp_task})", duration]] + output = ret + return output diff --git a/examples/st_game/plan/st_plan.py b/examples/st_game/plan/st_plan.py index 2e96aecc4..39173c600 100644 --- a/examples/st_game/plan/st_plan.py +++ b/examples/st_game/plan/st_plan.py @@ -23,7 +23,7 @@ from ..utils.utils import get_embedding from ..memory.retrieve import new_agent_retrieve -def plan(role: "STRole", maze: Maze, roles: list["STRole"], new_day: bool, retrieved: dict): +def plan(role: "STRole", maze: Maze, roles: list["STRole"], new_day: bool, retrieved: dict) -> str: # PART 1: Generate the hourly schedule. if new_day: _long_term_planning(role, new_day) @@ -61,6 +61,24 @@ def plan(role: "STRole", maze: Maze, roles: list["STRole"], new_day: bool, retri elif reaction_mode[:4] == "wait": _wait_react(role, reaction_mode) + # Step 3: Chat-related state clean up. + # If the persona is not chatting with anyone, we clean up any of the + # chat-related states here. + if role._rc.scratch.act_event[1] != "chat with": + role._rc.scratch.chatting_with = None + role._rc.scratch.chat = None + role._rc.scratch.chatting_end_time = None + # We want to make sure that the persona does not keep conversing with each + # other in an infinite loop. So, chatting_with_buffer maintains a form of + # buffer that makes the persona wait from talking to the same target + # immediately after chatting once. We keep track of the buffer value here. + curr_persona_chat_buffer = role._rc.scratch.chatting_with_buffer + for persona_name, buffer_count in curr_persona_chat_buffer.items(): + if persona_name != role._rc.scratch.chatting_with: + role._rc.scratch.chatting_with_buffer[persona_name] -= 1 + + return role._rc.scratch.act_address + def _choose_retrieved(role_name: str, retrieved: dict) -> Union[None, dict]: """ @@ -534,6 +552,7 @@ def _determine_action(role: "STRole", maze: Maze): curr_index = role.scratch.get_f_daily_schedule_index() curr_index_60 = role.scratch.get_f_daily_schedule_index(advance=60) + logger.info(f"f_daily_schedule: {role.scratch.f_daily_schedule}") # * Decompose * # During the first hour of the day, we need to decompose two hours # sequence. We do that here.