update plan actions

This commit is contained in:
better629 2023-10-08 13:42:55 +08:00
parent a150842619
commit 7b0323e76f
4 changed files with 76 additions and 18 deletions

View file

@ -358,9 +358,9 @@ class GenObjEventTriple(STAction):
prompt_template = "generate_event_triple_v1.txt"
prompt_input = create_prompt_input(act_game_object, act_obj_desp)
prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template)
self.fail_default_resp = self._func_fail_default_resp(role)
self.fail_default_resp = self._func_fail_default_resp(act_game_object)
output = self._run_v1(prompt)
output = (role.name, output[0], output[1])
output = (act_game_object, output[0], output[1])
return output
@ -404,7 +404,7 @@ class GenActionDetails(STAction):
result_dict = {
"action_address": new_address,
"action_duration": int(act_dura),
"act_desp": act_desp,
"action_description": act_desp,
"action_pronunciatio": act_pron,
"action_event": act_event,
"chatting_with": None,

View file

@ -8,6 +8,7 @@ import string
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.config import CONFIG
from .st_action import STAction
@ -39,17 +40,19 @@ class GenHourlySchedule(STAction):
return False
return True
def _func_cleanup(self, llm_resp: str, prompt: str) -> list:
def _func_cleanup(self, llm_resp: str, prompt: str) -> list:
cr = llm_resp.strip()
if cr[-1] == ".":
cr = cr[:-1]
# to only use the first line of output
cr = cr.split("\n")[0]
return cr
def _func_fail_default_resp(self) -> int:
fs = "asleep"
return fs
def _generate_schedule_for_given_hour(self, role: "STRole",
def _generate_schedule_for_given_hour(self, role: "STRole",
curr_hour_str,
p_f_ds_hourly_org,
hour_str,
@ -63,13 +66,13 @@ class GenHourlySchedule(STAction):
for i in hour_str:
schedule_format += f"[{persona.scratch.get_str_curr_date_str()} -- {i}]"
schedule_format += f" Activity: [Fill in]\n"
schedule_format = schedule_format[:-1]
schedule_format = schedule_format[:-1]
intermission_str = f"Here the originally intended hourly breakdown of"
intermission_str += f" {persona.scratch.get_str_firstname()}'s schedule today: "
for count, i in enumerate(persona.scratch.daily_req):
intermission_str += f"{str(count+1)}) {i}, "
intermission_str = intermission_str[:-2]
intermission_str = intermission_str[:-2]
prior_schedule = ""
if p_f_ds_hourly_org:
@ -109,21 +112,30 @@ class GenHourlySchedule(STAction):
p_f_ds_hourly_org,
hour_str,
intermission2)
logger.info(f"Role: {role.name} _generate_schedule_for_given_hour prompt_input: {prompt_input}")
prompt_input_str = "\n".join(prompt_input)
raw_max_tokens_rsp = CONFIG.max_tokens_rsp
prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template)
self.fail_default_resp = self._func_fail_default_resp()
CONFIG.max_tokens_rsp = 50
output = self._run_v1(prompt)
CONFIG.max_tokens_rsp = raw_max_tokens_rsp
logger.info(f"max_tokens_rsp: {CONFIG.max_tokens_rsp}")
logger.info(f"Role: {role.name} _generate_schedule_for_given_hour prompt_input: {prompt_input_str}, "
f"output: {output}")
return output
def run(self, role: "STRole", wake_up_hour: str):
def run(self, role: "STRole", wake_up_hour: int):
hour_str = ["00:00 AM", "01:00 AM", "02:00 AM", "03:00 AM", "04:00 AM",
"05:00 AM", "06:00 AM", "07:00 AM", "08:00 AM", "09:00 AM",
"10:00 AM", "11:00 AM", "12:00 PM", "01:00 PM", "02:00 PM",
"03:00 PM", "04:00 PM", "05:00 PM", "06:00 PM", "07:00 PM",
"08:00 PM", "09:00 PM", "10:00 PM", "11:00 PM"]
n_m1_activity = []
diversity_repeat_count = 3
for i in range(diversity_repeat_count):
diversity_repeat_count = 1 # TODO mg 1->3
for i in range(diversity_repeat_count):
logger.info(f"diversity_repeat_count idx: {i}")
n_m1_activity_set = set(n_m1_activity)
if len(n_m1_activity_set) < 5:
n_m1_activity = []
@ -131,7 +143,8 @@ class GenHourlySchedule(STAction):
if wake_up_hour > 0:
n_m1_activity += ["sleeping"]
wake_up_hour -= 1
else:
else:
logger.info(f"_generate_schedule_for_given_hour idx: {count}, n_m1_activity: {n_m1_activity}")
n_m1_activity += [self._generate_schedule_for_given_hour(
role, curr_hour_str, n_m1_activity, hour_str)]

View file

@ -72,17 +72,17 @@ class TaskDecomp(STAction):
# TODO -- this sometimes generates error
try:
self._func_cleanup(llm_resp)
except:
except Exception as exp:
return False
return True
def _func_fail_default_resp(self) -> int:
fs = ["asleep"]
fs = [["asleep", 0]]
return fs
def run(self,
role: "STRole",
main_act_dur: int,
task_desc: int,
truncated_act_dur: int,
*args, **kwargs):
@ -144,10 +144,36 @@ class TaskDecomp(STAction):
return prompt_input
prompt_input = create_prompt_input(role,
main_act_dur,
task_desc,
truncated_act_dur)
prompt = self.generate_prompt_with_tmpl_filename(prompt_input,
"task_decomp_v3.txt")
self.fail_default_resp = self._func_fail_default_resp()
output = self._run_v1(prompt)
logger.info(f"Role: {role.name} {self.__class__.__name__} output: {output}")
fin_output = []
time_sum = 0
for i_task, i_duration in output:
time_sum += i_duration
# HM?????????
# if time_sum < duration:
if time_sum <= truncated_act_dur:
fin_output += [[i_task, i_duration]]
else:
break
ftime_sum = 0
for fi_task, fi_duration in fin_output:
ftime_sum += fi_duration
# print ("for debugging... line 365", fin_output)
fin_output[-1][1] += (truncated_act_dur - ftime_sum)
output = fin_output
task_decomp = output
ret = []
for decomp_task, duration in task_decomp:
ret += [[f"{task_desc} ({decomp_task})", duration]]
output = ret
return output

View file

@ -23,7 +23,7 @@ from ..utils.utils import get_embedding
from ..memory.retrieve import new_agent_retrieve
def plan(role: "STRole", maze: Maze, roles: list["STRole"], new_day: bool, retrieved: dict):
def plan(role: "STRole", maze: Maze, roles: list["STRole"], new_day: bool, retrieved: dict) -> str:
# PART 1: Generate the hourly schedule.
if new_day:
_long_term_planning(role, new_day)
@ -61,6 +61,24 @@ def plan(role: "STRole", maze: Maze, roles: list["STRole"], new_day: bool, retri
elif reaction_mode[:4] == "wait":
_wait_react(role, reaction_mode)
# Step 3: Chat-related state clean up.
# If the persona is not chatting with anyone, we clean up any of the
# chat-related states here.
if role._rc.scratch.act_event[1] != "chat with":
role._rc.scratch.chatting_with = None
role._rc.scratch.chat = None
role._rc.scratch.chatting_end_time = None
# We want to make sure that personas do not keep conversing with each
# other in an infinite loop. So, chatting_with_buffer maintains a form of
# buffer that makes the persona wait before talking to the same target
# again right after chatting once. We keep track of the buffer value here.
curr_persona_chat_buffer = role._rc.scratch.chatting_with_buffer
for persona_name, buffer_count in curr_persona_chat_buffer.items():
if persona_name != role._rc.scratch.chatting_with:
role._rc.scratch.chatting_with_buffer[persona_name] -= 1
return role._rc.scratch.act_address
def _choose_retrieved(role_name: str, retrieved: dict) -> Union[None, dict]:
"""
@ -534,6 +552,7 @@ def _determine_action(role: "STRole", maze: Maze):
curr_index = role.scratch.get_f_daily_schedule_index()
curr_index_60 = role.scratch.get_f_daily_schedule_index(advance=60)
logger.info(f"f_daily_schedule: {role.scratch.f_daily_schedule}")
# * Decompose *
# During the first hour of the day, we need to decompose a two-hour
# sequence. We do that here.