diff --git a/expo/MCTS.py b/expo/MCTS.py index 4564cd682..9d778e4ed 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -10,7 +10,7 @@ import pandas as pd from expo.data.dataset import generate_task_requirement, get_split_dataset_path from expo.evaluation.evaluation import evaluate_score from expo.insights.instruction_generator import InstructionGenerator -from expo.research_assistant import ResearchAssistant +from expo.research_assistant import ResearchAssistant, TimeoutException from expo.utils import get_exp_pool_path, load_execute_notebook, mcts_logger from metagpt.tools.tool_recommend import ToolRecommender from metagpt.utils.common import read_json_file @@ -26,7 +26,9 @@ def initialize_di_root_node(state, reflection: bool = True): return role, Node(parent=None, state=state, action=None, value=0) -def create_initial_state(task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str): +def create_initial_state( + task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str, args +): initial_state = { "task": task, "work_dir": data_config["work_dir"], @@ -40,6 +42,7 @@ def create_initial_state(task, start_task_id, data_config, low_is_better: bool, "has_run": False, "start_task_id": start_task_id, "low_is_better": low_is_better, + "role_timeout": args.role_timeout, } os.makedirs(initial_state["node_dir"], exist_ok=True) return initial_state @@ -152,18 +155,15 @@ class Node: role = role.model_copy() role.save_state(static_save=True) - async def expand(self, max_children, use_fixed_insights): + async def expand(self, max_children: int, instruction_generator: InstructionGenerator): if self.is_fully_expanded(): return - insight_geneartor = InstructionGenerator() role = self.load_role() original_instruction = role.get_next_instruction() - insights = await insight_geneartor.generate_new_instructions( + insights = await instruction_generator.generate_new_instructions( task_id=role.start_task_id + 1, 
original_instruction=original_instruction, max_num=max_children, - file_path=self.state["exp_pool_path"], - use_fixed_insights=use_fixed_insights, ) new_state = self.state.copy() new_state["start_task_id"] += 1 @@ -211,10 +211,14 @@ class Node: score_dict = self.evaluate_simulation(score_dict) self.raw_reward = score_dict run_finished = True + except TimeoutException as e: + mcts_logger.log("MCTS", f"Role-level timeout: {e}") + break except Exception as e: print(f"Error: {e}") mcts_logger.log("MCTS", f"Error in running the role: {e}") num_runs += 1 + if not run_finished: mcts_logger.log("MCTS", f"Role {role.node_id} failed to run") if self.state["low_is_better"]: @@ -242,6 +246,8 @@ class MCTS: c_explore: float = 1.4 c_unvisited: float = 0.8 node_order: list = [] + # insight generator + instruction_generator: InstructionGenerator = None def __init__(self, root_node, max_depth, use_fixed_insights): self.root_node = root_node @@ -265,7 +271,7 @@ class MCTS: return max(all_children, key=uct) async def expand(self, node: Node, max_children=5): - await node.expand(max_children, self.use_fixed_insights) + await node.expand(max_children, self.instruction_generator) if node not in self.children or not self.children[node]: self.children[node] = node.children return node.children @@ -277,6 +283,7 @@ class MCTS: node = random.choice(node.children) reward = await node.run_node(role) mcts_logger.log("MCTS", f"Simulated node's reward: {reward}") + return reward def backpropagate(self, node: Node, reward): @@ -337,6 +344,10 @@ class MCTS: async def search(self, state, rollouts, load_tree=False, reflection=False): role, root = initialize_di_root_node(state, reflection=reflection) self.root_node = root + self.instruction_generator = InstructionGenerator( + file_path=state["exp_pool_path"], use_fixed_insights=self.use_fixed_insights + ) + tree_loaded = False if load_tree: tree_loaded = self.load_tree() diff --git a/expo/README.md b/expo/README.md index e5da96708..598de039d 100644 
--- a/expo/README.md +++ b/expo/README.md @@ -1,21 +1,20 @@ -# Expo +# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning ## 1. Data Preparation -- 下载数据集:https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink -- 修改`data.yaml`的`datasets_dir`为数据集合集根目录存储位置 +- Download the datasets: https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink ## 2. Configs ### Data Config -`datasets.yaml` 提供数据集对应的指标和基础提示词 +`datasets.yaml` provides the base prompts, metrics, and target columns for the respective datasets -`data.yaml` 继承了`datasets.yaml`以及一些路径信息,需要将`datasets_dir`指到数据集合集的根目录下 +- In `data.yaml`, set `datasets_dir` to the root directory that contains all the datasets ### LLM Config @@ -30,28 +29,64 @@ ### LLM Config ``` ### Budget -实验轮次 k = 10, 20 +Experiment rollouts k = 5, 10, 20 ### Prompt Usage -- 通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词 - - 非DI-based方法设置`is_di=False` - - `data_config`用`utils.DATA_CONFIG` -- 每一个数据集里有`dataset_info.json`,里面的内容需要提供给baselines以保证公平(`generate_task_requirement`已经默认提供) +- Use the function `generate_task_requirement` in `dataset.py` to get the task requirement. + - If the method is non-DI-based, set `is_di=False`. + - Use `utils.DATA_CONFIG` as `data_config`. -## 3. Evaluation +## 3. SELA -运行各个框架,运行后框架需要提供Dev和Test的`dev_predictions.csv`和`test_predictions.csv`,每个csv文件只需要单个名为target的列 +### Run SELA + +#### Setup +In the root directory, run: -- 使用`CustomExperimenter` ``` -experimenter = CustomExperimenter(task="titanic") -score_dict = experimenter.evaluate_pred_files(dev_pred_path, test_pred_path) +pip install -e . + +cd expo + +pip install -r requirements.txt ``` -## 4. 
Baselines +#### Run + +- `python run_experiment.py --exp_mode mcts --task titanic --rollouts 10` + +If the dataset uses a regression metric where lower is better (e.g. RMSE), remember to pass `--low_is_better`: + +- `python run_experiment.py --exp_mode mcts --task house-prices --rollouts 10 --low_is_better` + + +To also include the fixed insights saved in `expo/insights/fixed_insights.json` (in addition to the generated insights), add: +- `--use_fixed_insights` + + + +#### Ablation Study + +**DI RandomSearch** + +- Single insight +`python run_experiment.py --exp_mode aug --task titanic --aug_mode single` + +- Set insight +`python run_experiment.py --exp_mode aug --task titanic --aug_mode set` + + +## 4. Evaluation + +Each baseline needs to produce `dev_predictions.csv` and `test_predictions.csv`. Each CSV file only needs a single `target` column. + +- Use the function `evaluate_score` to evaluate. + + +## 5. Baselines ### DS Agent ``` git clone https://github.com/guosyjlu/DS-Agent.git @@ -257,53 +292,12 @@ #### Run ``` ### Base DI -For setup, check 5. - +For setup, see the Setup steps in Section 3 (SELA). - `python run_experiment.py --exp_mode base --task titanic --num_experiments 10` -- Ask DI to use AutoGluon: `--special_instruction ag` -- Ask DI to use the stacking ensemble method: `--special_instruction stacking` - - - - -## 5. DI MCTS - -### Run DI MCTS - -#### Setup -In the root directory, - -``` -pip install -e . 
- -cd expo - -pip install -r requirements.txt -``` - -#### Run - -- `python run_experiment.py --exp_mode mcts --task titanic --rollout 10` - -If the dataset has reg metric, remember to use `--low_is_better`: - -- `python run_experiment.py --exp_mode mcts --task househouse_prices --rollout 10 --low_is_better` - - -In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json` -- `--use_fixed_insights` - - - -#### Ablation Study - -**DI RandomSearch** - -- Single insight -`python run_experiment.py --exp_mode aug --task titanic --aug_mode single` - -- Set insight -`python run_experiment.py --exp_mode aug --task titanic --aug_mode set` +- Specifically instruct DI to use AutoGluon: `--special_instruction ag` +- Specifically instruct DI to use the stacking ensemble method: `--special_instruction stacking` + + diff --git a/expo/data.yaml b/expo/data.yaml index d62e45309..4c6549490 100644 --- a/expo/data.yaml +++ b/expo/data.yaml @@ -1,160 +1,3 @@ datasets_dir: "D:/work/automl/datasets" # path to the datasets directory - -datasets: - titanic: - dataset: 04_titanic - metric: f1 - target_col: Survived - user_requirement: "This is a 04_titanic dataset. Your goal is to predict the target\ - \ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - house-prices: - dataset: 05_house-prices-advanced-regression-techniques - metric: rmse - target_col: SalePrice - user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\ - \ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\ - \ data preprocessing, feature engineering, and modeling to predict the target.\ - \ \nReport rmse on the eval data. 
Do not plot or make any visualizations.\n" - santander-customer: - dataset: 06_santander-customer-transaction-prediction - metric: f1 - target_col: target - user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\ - \ Your goal is to predict the target column `target`.\nPerform data analysis,\ - \ data preprocessing, feature engineering, and modeling to predict the target.\ - \ \nReport f1 on the eval data. Do not plot or make any visualizations.\n" - icr: - dataset: 07_icr-identify-age-related-conditions - metric: f1 - target_col: Class - user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\ - \ goal is to predict the target column `Class`.\nPerform data analysis, data\ - \ preprocessing, feature engineering, and modeling to predict the target. \n\ - Report f1 on the eval data. Do not plot or make any visualizations.\n" - Click_prediction_small: - dataset: Click_prediction_small - metric: f1 - target_col: click - user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\ - \ the target column `click`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 on the eval data.\ - \ Do not plot or make any visualizations.\n" - GesturePhaseSegmentationProcessed: - dataset: GesturePhaseSegmentationProcessed - metric: f1 weighted - target_col: Phase - user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. Your goal\ - \ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\ - \ feature engineering, and modeling to predict the target. \nReport f1 weighted\ - \ on the eval data. Do not plot or make any visualizations.\n" - Moneyball: - dataset: Moneyball - metric: rmse - target_col: RS - user_requirement: "This is a Moneyball dataset. 
Your goal is to predict the target\ - \ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ - \ plot or make any visualizations.\n" - SAT11-HAND-runtime-regression: - dataset: SAT11-HAND-runtime-regression - metric: rmse - target_col: runtime - user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\ - \ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\ - \ feature engineering, and modeling to predict the target. \nReport rmse on\ - \ the eval data. Do not plot or make any visualizations.\n" - boston: - dataset: boston - metric: rmse - target_col: MEDV - user_requirement: "This is a boston dataset. Your goal is to predict the target\ - \ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ - \ plot or make any visualizations.\n" - colleges: - dataset: colleges - metric: rmse - target_col: percent_pell_grant - user_requirement: "This is a colleges dataset. Your goal is to predict the target\ - \ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport rmse on the eval\ - \ data. Do not plot or make any visualizations.\n" - credit-g: - dataset: credit-g - metric: f1 - target_col: class - user_requirement: "This is a credit-g dataset. Your goal is to predict the target\ - \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - diamonds: - dataset: diamonds - metric: rmse - target_col: price - user_requirement: "This is a diamonds dataset. 
Your goal is to predict the target\ - \ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ - \ plot or make any visualizations.\n" - jasmine: - dataset: jasmine - metric: f1 - target_col: class - user_requirement: "This is a jasmine dataset. Your goal is to predict the target\ - \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - kc1: - dataset: kc1 - metric: f1 - target_col: defects - user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\ - \ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - kick: - dataset: kick - metric: f1 - target_col: IsBadBuy - user_requirement: "This is a kick dataset. Your goal is to predict the target\ - \ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - mfeat-factors: - dataset: mfeat-factors - metric: f1 weighted - target_col: class - user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\ - \ target column `class`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ - \ eval data. Do not plot or make any visualizations.\n" - segment: - dataset: segment - metric: f1 weighted - target_col: class - user_requirement: "This is a segment dataset. Your goal is to predict the target\ - \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. 
\nReport f1 weighted on the eval data.\ - \ Do not plot or make any visualizations.\n" - steel-plates-fault: - dataset: steel-plates-fault - metric: f1 weighted - target_col: target - user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\ - \ the target column `target`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ - \ eval data. Do not plot or make any visualizations.\n" - wine-quality-white: - dataset: wine-quality-white - metric: f1 weighted - target_col: Class - user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\ - \ the target column `Class`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ - \ eval data. Do not plot or make any visualizations.\n" - - work_dir: ../workspace # path to the workspace directory -role_dir: storage/team/environment/roles/ResearchAssistant_David -# analysis_pool_dir: D:/work/MG-open/MetaGPT/examples/MCTS_test/analysis_pool_sample.json \ No newline at end of file +role_dir: storage/SELA # path to the role directory \ No newline at end of file diff --git a/expo/data/dataset.py b/expo/data/dataset.py index dd4cb4543..e076284d6 100644 --- a/expo/data/dataset.py +++ b/expo/data/dataset.py @@ -9,6 +9,7 @@ import yaml from sklearn.model_selection import train_test_split from expo.insights.solution_designer import SolutionDesigner +from expo.utils import DATA_CONFIG BASE_USER_REQUIREMENT = """ This is a {datasetname} dataset. Your goal is to predict the target column `{target_col}`. 
@@ -361,7 +362,7 @@ async def process_dataset(dataset, solution_designer: SolutionDesigner, save_ana if __name__ == "__main__": - datasets_dir = "D:/work/automl/datasets" + datasets_dir = DATA_CONFIG["datasets_dir"] force_update = False save_analysis_pool = True datasets_dict = {"datasets": {}} diff --git a/expo/data/hf_data.py b/expo/data/hf_data.py index a43fcd415..133fbdfa6 100644 --- a/expo/data/hf_data.py +++ b/expo/data/hf_data.py @@ -9,6 +9,7 @@ from PIL import Image from expo.data.dataset import ExpDataset, process_dataset, save_datasets_dict_to_yaml from expo.insights.solution_designer import SolutionDesigner +from expo.utils import DATA_CONFIG HFDATSETS = [ {"name": "sms_spam", "dataset_name": "ucirvine/sms_spam", "target_col": "label", "modality": "text"}, @@ -114,7 +115,7 @@ class HFExpDataset(ExpDataset): if __name__ == "__main__": - dataset_dir = "D:/work/automl/datasets" + dataset_dir = DATA_CONFIG["datasets_dir"] save_analysis_pool = True force_update = False datasets_dict = {"datasets": {}} diff --git a/expo/experimenter/aug.py b/expo/experimenter/aug.py index 97b819802..bcfa5d4ad 100644 --- a/expo/experimenter/aug.py +++ b/expo/experimenter/aug.py @@ -34,7 +34,9 @@ class AugExperimenter(Experimenter): results = [] for i in range(self.args.num_experiments): - di = ResearchAssistant(node_id=str(i), use_reflection=self.args.reflection) + di = ResearchAssistant( + node_id=str(i), use_reflection=self.args.reflection, role_timeout=self.args.role_timeout + ) di.role_dir = f"{di.role_dir}_{self.args.task}" requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i]) print(requirement) diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py index c6ead281b..9aa879e24 100644 --- a/expo/experimenter/experimenter.py +++ b/expo/experimenter/experimenter.py @@ -27,6 +27,7 @@ class Experimenter: low_is_better=self.args.low_is_better, name=self.args.name, special_instruction=self.args.special_instruction, + args=self.args, ) 
async def run_di(self, di, user_requirement, run_idx): @@ -82,7 +83,9 @@ class Experimenter: results = [] for i in range(self.args.num_experiments): - di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection) + di = ResearchAssistant( + node_id="0", use_reflection=self.args.reflection, role_timeout=self.args.role_timeout + ) score_dict = await self.run_di(di, user_requirement, run_idx=i) results.append( {"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)} diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index 22b480caf..c063268c8 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -8,9 +8,12 @@ from expo.MCTS import MCTS class MCTSExperimenter(Experimenter): result_path: str = "results/mcts" - start_task_id = 2 def __init__(self, args, tree_mode=None, **kwargs): + if args.special_instruction == "image": + self.start_task_id = 1 # start from datapreprocessing if it is image task + else: + self.start_task_id = args.start_task_id super().__init__(args, **kwargs) self.tree_mode = tree_mode diff --git a/expo/insights/instruction_generator.py b/expo/insights/instruction_generator.py index 07e5fb655..330795730 100644 --- a/expo/insights/instruction_generator.py +++ b/expo/insights/instruction_generator.py @@ -2,6 +2,7 @@ import json import os import random +from expo.insights.solution_designer import SolutionDesigner from expo.utils import clean_json_from_rsp, load_data_config, mcts_logger from metagpt.llm import LLM from metagpt.schema import Message @@ -32,6 +33,12 @@ DATA_CONFIG = load_data_config() class InstructionGenerator: data_config = DATA_CONFIG + def __init__(self, file_path, use_fixed_insights=False): + self.file_path = file_path + self.use_fixed_insights = use_fixed_insights + self.analysis_pool = self.load_insight_pool(file_path, use_fixed_insights) + self.proposer = SolutionDesigner() + @staticmethod def load_json_data(json_dir): with open(json_dir, "r") as 
file: @@ -69,7 +76,7 @@ class InstructionGenerator: return new_data @staticmethod - def load_analysis_pool(file_path, use_fixed_insights, task_id=None): + def load_insight_pool(file_path, use_fixed_insights, task_id=None): data = InstructionGenerator.load_json_data(file_path) if use_fixed_insights: current_directory = os.path.dirname(__file__) @@ -83,13 +90,8 @@ class InstructionGenerator: data = [item for item in data if int(item["task_id"]) == int(task_id)] return data - @staticmethod - async def generate_new_instructions( - task_id, original_instruction, max_num, file_path, ext_info=None, use_fixed_insights=False - ): - data = InstructionGenerator.load_analysis_pool( - file_path, task_id=task_id, use_fixed_insights=use_fixed_insights - ) + async def generate_new_instructions(self, task_id, original_instruction, max_num, ext_info=None): + data = self.analysis_pool new_instructions = [] if len(data) == 0: mcts_logger.log("MCTS", f"No insights available for task {task_id}") diff --git a/expo/insights/solution_designer.py b/expo/insights/solution_designer.py index b1fcf4188..9968131ca 100644 --- a/expo/insights/solution_designer.py +++ b/expo/insights/solution_designer.py @@ -21,6 +21,7 @@ The insights should be proposed based on the dataset description with different Each task type should have at least 5 insights. Make sure each method is diverse enough and can be implemented separately. Be specific about models' choices, ensemble and tuning techniques, and preprocessing & feature engineering techniques. +Your model choices should be advanced enough to be helpful. 
# Format ```json diff --git a/expo/research_assistant.py b/expo/research_assistant.py index 51de188d3..fb34ece38 100644 --- a/expo/research_assistant.py +++ b/expo/research_assistant.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import os @@ -34,11 +35,33 @@ If you cannot find the scores, please still return a dictionary with the keys 't """ +class TimeoutException(Exception): + pass + + +def async_timeout(): + def decorator(func): + async def wrapper(self, *args, **kwargs): + try: + result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=self.role_timeout) + except asyncio.TimeoutError: + text = f"Function timed out after {self.role_timeout} seconds" + mcts_logger.error(text) + self.save_state() + raise TimeoutException(text) + return result + + return wrapper + + return decorator + + class ResearchAssistant(DataInterpreter): node_id: str = "0" start_task_id: int = 1 state_saved: bool = False role_dir: str = SERDESER_PATH.joinpath("team", "environment", "roles", "Experimenter") + role_timeout: int = 1000 def get_node_name(self): return f"Node-{self.node_id}" @@ -117,6 +140,12 @@ class ResearchAssistant(DataInterpreter): return task_result def save_state(self, static_save=False): + """ + attribute: + state_saved - the state has been saved + input: + static_save - saving the state without changing the state_saved flag - used when a new role is created + """ if self.state_saved and not static_save: return if not static_save: @@ -135,18 +164,14 @@ class ResearchAssistant(DataInterpreter): self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys()) ] + @async_timeout() async def run(self, with_message=None) -> Message | None: """Observe, and think and act based on the results of the observation""" if with_message == "continue": - # self.set_todo(None) - # working_memory = self.working_memory - # self.remap_tasks() mcts_logger.info("Continue to run") self.rc.working_memory.clear() 
self.working_memory.clear() - # self.rc.todo = WriteAnalysisCode() rsp = await self.react() - # 发送响应消息给 Environment 对象,以便它将消息传递给订阅者 self.set_todo(None) self.publish_message(rsp) return rsp diff --git a/expo/run_experiment.py b/expo/run_experiment.py index fbd05d776..15be27d60 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -18,6 +18,7 @@ def get_args(): default="mcts", choices=["mcts", "aug", "base", "custom", "greedy", "autogluon", "random", "autosklearn"], ) + parser.add_argument("--role_timeout", type=int, default=1000) get_di_args(parser) get_mcts_args(parser) get_aug_exp_args(parser) @@ -30,6 +31,7 @@ def get_mcts_args(parser): parser.set_defaults(load_tree=False) parser.add_argument("--rollouts", type=int, default=5) parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true") + parser.add_argument("--start_task_id", type=int, default=2) def get_aug_exp_args(parser): diff --git a/expo/scripts/run_cls.sh b/expo/scripts/run_cls.sh new file mode 100644 index 000000000..f0ee5ddcf --- /dev/null +++ b/expo/scripts/run_cls.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +tasks=("smoker-status" "software-defects" "jasmine" "credit-g" "Click_prediction_small" "kick" "kc1" "titanic" "icr" "wine-quality-white" "mfeat-factors" "segment" "GesturePhaseSegmentationProcessed") + + +for i in {1..3} +do + for task in "${tasks[@]}"; do + echo "Running experiment for task: $task" + python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --special_instruction stacking + echo "Experiment for task $task completed." + done +done + +echo "All experiments completed." 
diff --git a/expo/scripts/run_cls_mod.sh b/expo/scripts/run_cls_mod.sh new file mode 100644 index 000000000..ae3622b7a --- /dev/null +++ b/expo/scripts/run_cls_mod.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +tasks=("banking77" "gnad10" "sms_spam" "oxford-iiit-pet" "stanford_cars" "fashion_mnist" ) + +for i in {1..3} +do + for task in "${tasks[@]}"; do + echo "Running experiment for task: $task" + python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 + echo "Experiment for task $task completed." + done +done +echo "All experiments completed." diff --git a/expo/scripts/run_reg.sh b/expo/scripts/run_reg.sh new file mode 100644 index 000000000..f8a742886 --- /dev/null +++ b/expo/scripts/run_reg.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +tasks=("concrete-strength" "Moneyball" "colleges" "SAT11-HAND-runtime-regression" "diamonds" "boston" "house-prices") + +for i in {1..3} +do + for task in "${tasks[@]}"; do + echo "Running experiment for task: $task" + python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --low_is_better --special_instruction stacking + echo "Experiment for task $task completed." + done +done + +echo "All experiments completed." 
diff --git a/expo/utils.py b/expo/utils.py index 56f3c21b9..b022879b0 100644 --- a/expo/utils.py +++ b/expo/utils.py @@ -1,6 +1,5 @@ import os import re -import sys from datetime import datetime from pathlib import Path @@ -21,22 +20,19 @@ def load_data_config(file_path="data.yaml"): DATASET_CONFIG = load_data_config("datasets.yaml") DATA_CONFIG = load_data_config() -DATA_CONFIG["datasets"].update(DATASET_CONFIG["datasets"]) +DATA_CONFIG["datasets"] = DATASET_CONFIG["datasets"] def get_mcts_logger(): - print_level = "INFO" - print_level2 = "MCTS" - logfile_level = "MCTS" + logfile_level = "DEBUG" name: str = None current_date = datetime.now() formatted_date = current_date.strftime("%Y%m%d") log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name - _logger.remove() - _logger.level(logfile_level, color="", no=25) - _logger.add(sys.stderr, level=print_level) - _logger.add(sys.stderr, level=print_level2) + # _logger.remove() + _logger.level("MCTS", color="", no=25) + # _logger.add(sys.stderr, level=print_level) _logger.add(Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt", level=logfile_level) _logger.propagate = False return _logger