diff --git a/expo/MCTS.py b/expo/MCTS.py index 8e685cc0a..7de123572 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -26,7 +26,9 @@ def initialize_di_root_node(state, reflection: bool = True): return role, Node(parent=None, state=state, action=None, value=0) -def create_initial_state(task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str): +def create_initial_state( + task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str, args +): initial_state = { "task": task, "work_dir": data_config["work_dir"], @@ -40,6 +42,7 @@ def create_initial_state(task, start_task_id, data_config, low_is_better: bool, "has_run": False, "start_task_id": start_task_id, "low_is_better": low_is_better, + "role_timeout": args.role_timeout, } os.makedirs(initial_state["node_dir"], exist_ok=True) return initial_state diff --git a/expo/data.yaml b/expo/data.yaml index f1556c519..4c6549490 100644 --- a/expo/data.yaml +++ b/expo/data.yaml @@ -1,4 +1,3 @@ datasets_dir: "D:/work/automl/datasets" # path to the datasets directory work_dir: ../workspace # path to the workspace directory -role_dir: storage/SELA # path to the role directory -role_timeout: 1000 # timeout for each node/role in seconds \ No newline at end of file +role_dir: storage/SELA # path to the role directory \ No newline at end of file diff --git a/expo/experimenter/aug.py b/expo/experimenter/aug.py index 97b819802..bcfa5d4ad 100644 --- a/expo/experimenter/aug.py +++ b/expo/experimenter/aug.py @@ -34,7 +34,9 @@ class AugExperimenter(Experimenter): results = [] for i in range(self.args.num_experiments): - di = ResearchAssistant(node_id=str(i), use_reflection=self.args.reflection) + di = ResearchAssistant( + node_id=str(i), use_reflection=self.args.reflection, role_timeout=self.args.role_timeout + ) di.role_dir = f"{di.role_dir}_{self.args.task}" requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i]) print(requirement) diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py index c6ead281b..9aa879e24 100644 --- a/expo/experimenter/experimenter.py +++ b/expo/experimenter/experimenter.py @@ -27,6 +27,7 @@ class Experimenter: low_is_better=self.args.low_is_better, name=self.args.name, special_instruction=self.args.special_instruction, + args=self.args, ) async def run_di(self, di, user_requirement, run_idx): @@ -82,7 +83,9 @@ class Experimenter: results = [] for i in range(self.args.num_experiments): - di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection) + di = ResearchAssistant( + node_id="0", use_reflection=self.args.reflection, role_timeout=self.args.role_timeout + ) score_dict = await self.run_di(di, user_requirement, run_idx=i) results.append( {"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)} diff --git a/expo/research_assistant.py b/expo/research_assistant.py index 8fadeb7fb..c574d5b18 100644 --- a/expo/research_assistant.py +++ b/expo/research_assistant.py @@ -6,7 +6,7 @@ import os from pydantic import model_validator -from expo.utils import DATA_CONFIG, mcts_logger, save_notebook +from expo.utils import mcts_logger, save_notebook from metagpt.actions.di.write_analysis_code import WriteAnalysisCode from metagpt.const import SERDESER_PATH from metagpt.roles.di.data_interpreter import DataInterpreter @@ -39,13 +39,13 @@ class TimeoutException(Exception): pass -def async_timeout(seconds): +def async_timeout(): def decorator(func): async def wrapper(self, *args, **kwargs): try: - result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=seconds) + result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=self.role_timeout) except asyncio.TimeoutError: - text = f"Function timed out after {seconds} seconds" + text = f"Function timed out after {self.role_timeout} seconds" mcts_logger.error(text) self.save_state() raise TimeoutException(text) @@ -61,6 +61,7 @@ class ResearchAssistant(DataInterpreter): start_task_id: int = 1 state_saved: bool = False role_dir: str = SERDESER_PATH.joinpath("team", "environment", "roles", "Experimenter") + role_timeout: int = 1000 def get_node_name(self): return f"Node-{self.node_id}" @@ -163,7 +164,7 @@ class ResearchAssistant(DataInterpreter): self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys()) ] - @async_timeout(DATA_CONFIG["role_timeout"]) + @async_timeout() @role_raise_decorator async def run(self, with_message=None) -> Message | None: """Observe, and think and act based on the results of the observation""" diff --git a/expo/resources/MCTS-Experimenter.jpg b/expo/resources/MCTS-Experimenter.jpg new file mode 100644 index 000000000..bbae98ee3 Binary files /dev/null and b/expo/resources/MCTS-Experimenter.jpg differ diff --git a/expo/run_experiment.py b/expo/run_experiment.py index fbd05d776..49d058f13 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -18,6 +18,7 @@ def get_args(): default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "autogluon", "random", "autosklearn"], ) + parser.add_argument("--role_timeout", type=int, default=1000) get_di_args(parser) get_mcts_args(parser) get_aug_exp_args(parser)