make timeout as argument

This commit is contained in:
Yizhou Chi 2024-10-10 18:54:40 +08:00
parent 2fc8f20de6
commit eb460d3e19
7 changed files with 19 additions and 10 deletions

View file

@ -26,7 +26,9 @@ def initialize_di_root_node(state, reflection: bool = True):
return role, Node(parent=None, state=state, action=None, value=0)
def create_initial_state(task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str):
def create_initial_state(
task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str, args
):
initial_state = {
"task": task,
"work_dir": data_config["work_dir"],
@ -40,6 +42,7 @@ def create_initial_state(task, start_task_id, data_config, low_is_better: bool,
"has_run": False,
"start_task_id": start_task_id,
"low_is_better": low_is_better,
"role_timeout": args.role_timeout,
}
os.makedirs(initial_state["node_dir"], exist_ok=True)
return initial_state

View file

@ -1,4 +1,3 @@
datasets_dir: "D:/work/automl/datasets" # path to the datasets directory
work_dir: ../workspace # path to the workspace directory
role_dir: storage/SELA # path to the role directory
role_timeout: 1000 # timeout for each node/role in seconds
role_dir: storage/SELA # path to the role directory

View file

@ -34,7 +34,9 @@ class AugExperimenter(Experimenter):
results = []
for i in range(self.args.num_experiments):
di = ResearchAssistant(node_id=str(i), use_reflection=self.args.reflection)
di = ResearchAssistant(
node_id=str(i), use_reflection=self.args.reflection, role_timeout=self.args.role_timeout
)
di.role_dir = f"{di.role_dir}_{self.args.task}"
requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
print(requirement)

View file

@ -27,6 +27,7 @@ class Experimenter:
low_is_better=self.args.low_is_better,
name=self.args.name,
special_instruction=self.args.special_instruction,
args=self.args,
)
async def run_di(self, di, user_requirement, run_idx):
@ -82,7 +83,9 @@ class Experimenter:
results = []
for i in range(self.args.num_experiments):
di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
di = ResearchAssistant(
node_id="0", use_reflection=self.args.reflection, role_timeout=self.args.role_timeout
)
score_dict = await self.run_di(di, user_requirement, run_idx=i)
results.append(
{"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}

View file

@ -6,7 +6,7 @@ import os
from pydantic import model_validator
from expo.utils import DATA_CONFIG, mcts_logger, save_notebook
from expo.utils import mcts_logger, save_notebook
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.const import SERDESER_PATH
from metagpt.roles.di.data_interpreter import DataInterpreter
@ -39,13 +39,13 @@ class TimeoutException(Exception):
pass
def async_timeout(seconds):
def async_timeout():
def decorator(func):
async def wrapper(self, *args, **kwargs):
try:
result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=seconds)
result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=self.role_timeout)
except asyncio.TimeoutError:
text = f"Function timed out after {seconds} seconds"
text = f"Function timed out after {self.role_timeout} seconds"
mcts_logger.error(text)
self.save_state()
raise TimeoutException(text)
@ -61,6 +61,7 @@ class ResearchAssistant(DataInterpreter):
start_task_id: int = 1
state_saved: bool = False
role_dir: str = SERDESER_PATH.joinpath("team", "environment", "roles", "Experimenter")
role_timeout: int = 1000
def get_node_name(self):
return f"Node-{self.node_id}"
@ -163,7 +164,7 @@ class ResearchAssistant(DataInterpreter):
self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys())
]
@async_timeout(DATA_CONFIG["role_timeout"])
@async_timeout()
@role_raise_decorator
async def run(self, with_message=None) -> Message | None:
"""Observe, and think and act based on the results of the observation"""

Binary file not shown.

After

Width:  |  Height:  |  Size: 644 KiB

View file

@ -18,6 +18,7 @@ def get_args():
default="mcts",
choices=["mcts", "aug", "base", "custom", "greedy", "autogluon", "random", "autosklearn"],
)
parser.add_argument("--role_timeout", type=int, default=1000)
get_di_args(parser)
get_mcts_args(parser)
get_aug_exp_args(parser)