From fcd1ba66a6ae70f93e7e575f5a9395ebfea5d6ff Mon Sep 17 00:00:00 2001
From: Yizhou Chi
Date: Wed, 4 Sep 2024 16:38:33 +0800
Subject: [PATCH] Add try catch
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 expo/MCTS.py                      | 26 +++++++++++++++++++-------
 expo/README.md                    |  5 +++--
 expo/experimenter/aug.py          |  8 +++-----
 expo/experimenter/experimenter.py | 31 +++++++++++++++++++++++++++----
 4 files changed, 52 insertions(+), 18 deletions(-)

diff --git a/expo/MCTS.py b/expo/MCTS.py
index 9787ea5e9..ab9957a7a 100644
--- a/expo/MCTS.py
+++ b/expo/MCTS.py
@@ -194,13 +194,25 @@ class Node():
         if self.is_terminal() and role is not None:
             if role.state_saved:
                 return self.raw_reward
-
-        if not role:
-            role = self.load_role()
-            await load_execute_notebook(role)  # execute previous notebook's code
-            await role.run(with_message='continue')
-        else:
-            await role.run(with_message=self.state['requirement'])
+
+        max_retries = 3
+        num_runs = 1
+        run_finished = False
+        while num_runs <= max_retries and not run_finished:
+            try:
+                if not role:
+                    role = self.load_role()
+                    await load_execute_notebook(role)  # execute previous notebook's code
+                    await role.run(with_message='continue')
+                else:
+                    await role.run(with_message=self.state['requirement'])
+                run_finished = True
+            except Exception as e:
+                mcts_logger.log("MCTS", f"Error in running the role: {e}")
+                num_runs += 1
+        if not run_finished:
+            mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
+            return {"test_score": 0, "dev_score": 0, "score": 0}
         score_dict = await role.get_score()
         score_dict = self.evaluate_simulation(score_dict)
         self.raw_reward = score_dict
diff --git a/expo/README.md b/expo/README.md
index dfaf1ab0a..4cc4daf25 100644
--- a/expo/README.md
+++ b/expo/README.md
@@ -35,7 +35,8 @@
 
 ### Budget
 ### 提示词使用
-通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词
+- 通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词
+- 每一个数据集里有`dataset_info.json`,里面的内容需要提供给baselines以保证公平
 
 ## 3. Evaluation
 
@@ -74,7 +75,7 @@
 #### Setup
 ### Base DI
 For setup, check 5. 
-- `python run_experiment.py --exp_mode base --task titanic`
+- `python run_experiment.py --exp_mode base --task titanic --num_experiments 10`
 
 ### DI RandomSearch
 
diff --git a/expo/experimenter/aug.py b/expo/experimenter/aug.py
index 956849717..86c98fd42 100644
--- a/expo/experimenter/aug.py
+++ b/expo/experimenter/aug.py
@@ -18,8 +18,8 @@ class AugExperimenter(Experimenter):
     result_path : str = "results/aug"
 
     async def run_experiment(self):
-        state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
-        user_requirement = state["requirement"]
+        # state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
+        user_requirement = self.state["requirement"]
         exp_pool_path = get_exp_pool_path(self.args.task, self.data_config, pool_name="ds_analysis_pool")
         exp_pool = InstructionGenerator.load_analysis_pool(exp_pool_path)
         if self.args.aug_mode == "single":
@@ -38,9 +38,7 @@ class AugExperimenter(Experimenter):
             di.role_dir = f"{di.role_dir}_{self.args.task}"
             requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
             print(requirement)
-            await di.run(requirement)
-            score_dict = await di.get_score()
-            score_dict = self.evaluate(score_dict, state)
+            score_dict = await self.run_di(di, requirement)
             results.append({
                 "idx": i,
                 "score_dict": score_dict,
diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py
index e53bae972..709eefdfc 100644
--- a/expo/experimenter/experimenter.py
+++ b/expo/experimenter/experimenter.py
@@ -16,17 +16,40 @@ class Experimenter:
     def __init__(self, args, **kwargs):
         self.args = args
         self.start_time = datetime.datetime.now().strftime("%Y%m%d%H%M")
+        self.state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
+
+
+    async def run_di(self, di, user_requirement):
+        max_retries = 3
+        num_runs = 1
+        run_finished = False
+        while num_runs <= max_retries and not run_finished:
+            try:
+                await di.run(user_requirement)
+                score_dict = await di.get_score()
+                score_dict = self.evaluate(score_dict, self.state)
+                run_finished = True
+            except Exception as e:
+                print(f"Error: {e}")
+                num_runs += 1
+        if not run_finished:
+            score_dict = {
+                "train_score": -1,
+                "dev_score": -1,
+                "test_score": -1,
+                "score": -1
+            }
+        return score_dict
+
     async def run_experiment(self):
-        state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
+        state = self.state
         user_requirement = state["requirement"]
         results = []
         for i in range(self.args.num_experiments):
             di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
-            await di.run(user_requirement)
-            score_dict = await di.get_score()
-            score_dict = self.evaluate(score_dict, state)
+            score_dict = await self.run_di(di, user_requirement)
             results.append({
                 "idx": i,
                 "score_dict": score_dict,
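
The two retry loops added above, `Node.run` in expo/MCTS.py and `Experimenter.run_di` in expo/experimenter/experimenter.py, share one shape: attempt the awaited `run()` up to `max_retries` times, log the exception on each failure, and return sentinel scores once every attempt has failed. The sketch below is only a minimal self-contained illustration of that shape, not code from the repository; `FlakyRunner`, `run_with_retries`, and the concrete score values are hypothetical names invented for the example.

import asyncio


class FlakyRunner:
    """Hypothetical stand-in for the DI / role object driven by the patch's retry loops."""

    def __init__(self, fail_times: int = 2):
        self._remaining_failures = fail_times

    async def run(self, requirement: str) -> None:
        # Simulate a transient failure on the first `fail_times` attempts.
        if self._remaining_failures > 0:
            self._remaining_failures -= 1
            raise RuntimeError("transient failure")

    async def get_score(self) -> dict:
        # Illustrative score dictionary; values are made up.
        return {"train_score": 0.9, "dev_score": 0.8, "test_score": 0.7, "score": 0.8}


async def run_with_retries(runner: FlakyRunner, requirement: str, max_retries: int = 3) -> dict:
    # Attempt the awaited run up to max_retries times, logging each failure,
    # then fall back to sentinel scores, mirroring run_di's -1 fallback.
    for attempt in range(1, max_retries + 1):
        try:
            await runner.run(requirement)
            return await runner.get_score()
        except Exception as e:
            print(f"Attempt {attempt} failed: {e}")
    return {"train_score": -1, "dev_score": -1, "test_score": -1, "score": -1}


if __name__ == "__main__":
    scores = asyncio.run(run_with_retries(FlakyRunner(fail_times=2), "train a model on titanic"))
    print(scores)  # the third attempt succeeds within the three allowed tries, so real scores print

A bounded counter loop keeps the number of retries explicit, and returning a sentinel score dictionary instead of re-raising lets the experiment loop keep aggregating results across runs without special-casing failed attempts.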