From 24db19fa13c70f43297de44330d8b7fe3f702ef7 Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Sat, 14 Sep 2024 20:33:40 +0800 Subject: [PATCH] fix start task id consistency --- expo/MCTS.py | 7 +++++-- expo/experimenter/experimenter.py | 3 ++- expo/experimenter/mcts.py | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/expo/MCTS.py b/expo/MCTS.py index ef408b2dd..c96c57b47 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -16,12 +16,11 @@ from metagpt.utils.common import read_json_file def initialize_di_root_node(state, reflection: bool = True): - start_task_id = 2 # state = create_initial_state( # task, start_task_id=start_task_id, data_config=data_config, low_is_better=low_is_better, name=name # ) role = ResearchAssistant( - node_id="0", start_task_id=start_task_id, use_reflection=reflection, role_dir=state["node_dir"] + node_id="0", start_task_id=state["start_task_id"], use_reflection=reflection, role_dir=state["node_dir"] ) return role, Node(parent=None, state=state, action=None, value=0) @@ -208,6 +207,10 @@ class Node: self.raw_reward = score_dict run_finished = True except Exception as e: + print(f"Error: {e}") + import pdb + + pdb.set_trace() mcts_logger.log("MCTS", f"Error in running the role: {e}") num_runs += 1 if not run_finished: diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py index b7a0e0b2f..155108f8d 100644 --- a/expo/experimenter/experimenter.py +++ b/expo/experimenter/experimenter.py @@ -13,6 +13,7 @@ from expo.utils import DATA_CONFIG, save_notebook class Experimenter: result_path: str = "results/base" data_config = DATA_CONFIG + start_task_id = 1 def __init__(self, args, **kwargs): self.args = args @@ -20,7 +21,7 @@ class Experimenter: self.start_time = self.start_time_raw.strftime("%Y%m%d%H%M") self.state = create_initial_state( self.args.task, - start_task_id=1, + start_task_id=self.start_task_id, data_config=self.data_config, low_is_better=self.args.low_is_better, name=self.args.name, diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index f0db72841..89f362b6b 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -6,6 +6,7 @@ from expo.MCTS import MCTS class MCTSExperimenter(Experimenter): result_path: str = "results/mcts" + start_task_id = 2 def __init__(self, args, tree_mode=None, **kwargs): super().__init__(args, **kwargs)