diff --git a/expo/MCTS.py b/expo/MCTS.py index 4564cd682..8e685cc0a 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -10,7 +10,7 @@ import pandas as pd from expo.data.dataset import generate_task_requirement, get_split_dataset_path from expo.evaluation.evaluation import evaluate_score from expo.insights.instruction_generator import InstructionGenerator -from expo.research_assistant import ResearchAssistant +from expo.research_assistant import ResearchAssistant, TimeoutException from expo.utils import get_exp_pool_path, load_execute_notebook, mcts_logger from metagpt.tools.tool_recommend import ToolRecommender from metagpt.utils.common import read_json_file @@ -211,10 +211,14 @@ class Node: score_dict = self.evaluate_simulation(score_dict) self.raw_reward = score_dict run_finished = True + except TimeoutException as e: + mcts_logger.log("MCTS", f"Role-level timeout: {e}") + break except Exception as e: print(f"Error: {e}") mcts_logger.log("MCTS", f"Error in running the role: {e}") num_runs += 1 + if not run_finished: mcts_logger.log("MCTS", f"Role {role.node_id} failed to run") if self.state["low_is_better"]: diff --git a/expo/README.md b/expo/README.md index e5da96708..0a807c928 100644 --- a/expo/README.md +++ b/expo/README.md @@ -1,21 +1,21 @@ -# Expo +# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning +![pipeline](resources/MCTS-Experimenter.jpg) ## 1. Data Preparation -- 下载数据集:https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink -- 修改`data.yaml`的`datasets_dir`为数据集合集根目录存储位置 +- Download Datasets:https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink ## 2. 
Configs ### Data Config -`datasets.yaml` 提供数据集对应的指标和基础提示词 +`datasets.yaml` Provides base prompts, metrics, target columns for respective datasets -`data.yaml` 继承了`datasets.yaml`以及一些路径信息,需要将`datasets_dir`指到数据集合集的根目录下 +- Modify `datasets_dir` to the root directory of all the datasets in `data.yaml` ### LLM Config @@ -30,28 +30,64 @@ ### LLM Config ``` ### Budget -实验轮次 k = 10, 20 +Experiment rollouts k = 5, 10, 20 ### Prompt Usage -- 通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词 - - 非DI-based方法设置`is_di=False` - - `data_config`用`utils.DATA_CONFIG` -- 每一个数据集里有`dataset_info.json`,里面的内容需要提供给baselines以保证公平(`generate_task_requirement`已经默认提供) +- Use the function `generate_task_requirement` in `dataset.py` to get task requirement. + - If the method is non-DI-based, set `is_di=False`. + - Use `utils.DATA_CONFIG` as `data_config` -## 3. Evaluation +## 3. SELA -运行各个框架,运行后框架需要提供Dev和Test的`dev_predictions.csv`和`test_predictions.csv`,每个csv文件只需要单个名为target的列 +### Run SELA + +#### Setup +In the root directory, -- 使用`CustomExperimenter` ``` -experimenter = CustomExperimenter(task="titanic") -score_dict = experimenter.evaluate_pred_files(dev_pred_path, test_pred_path) +pip install -e . + +cd expo + +pip install -r requirements.txt ``` -## 4. Baselines +#### Run + +- `python run_experiment.py --exp_mode mcts --task titanic --rollouts 10` + +If the dataset has reg metric, remember to use `--low_is_better`: + +- `python run_experiment.py --exp_mode mcts --task house_prices --rollouts 10 --low_is_better` + + +In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json` +- `--use_fixed_insights` + + + +#### Ablation Study + +**DI RandomSearch** + +- Single insight +`python run_experiment.py --exp_mode aug --task titanic --aug_mode single` + +- Set insight +`python run_experiment.py --exp_mode aug --task titanic --aug_mode set` + + +## 4. Evaluation + +Each baseline needs to produce `dev_predictions.csv` and `test_predictions.csv`. 
Each csv file only needs a `target` column. + +- Use the function `evaluate_score` to evaluate. + + +## 5. Baselines ### DS Agent ``` git clone https://github.com/guosyjlu/DS-Agent.git @@ -257,53 +293,12 @@ #### Run ``` ### Base DI -For setup, check 5. - +For setup, check 4. - `python run_experiment.py --exp_mode base --task titanic --num_experiments 10` -- Ask DI to use AutoGluon: `--special_instruction ag` -- Ask DI to use the stacking ensemble method: `--special_instruction stacking` - - - - -## 5. DI MCTS - -### Run DI MCTS - -#### Setup -In the root directory, - -``` -pip install -e . - -cd expo - -pip install -r requirements.txt -``` - -#### Run - -- `python run_experiment.py --exp_mode mcts --task titanic --rollout 10` - -If the dataset has reg metric, remember to use `--low_is_better`: - -- `python run_experiment.py --exp_mode mcts --task househouse_prices --rollout 10 --low_is_better` - - -In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json` -- `--use_fixed_insights` - - - -#### Ablation Study - -**DI RandomSearch** - -- Single insight -`python run_experiment.py --exp_mode aug --task titanic --aug_mode single` - -- Set insight -`python run_experiment.py --exp_mode aug --task titanic --aug_mode set` +- Specifically instruct DI to use AutoGluon: `--special_instruction ag` +- Specifically instruct DI to use the stacking ensemble method: `--special_instruction stacking` + + diff --git a/expo/data.yaml b/expo/data.yaml index d62e45309..8273fecad 100644 --- a/expo/data.yaml +++ b/expo/data.yaml @@ -1,160 +1,3 @@ datasets_dir: "D:/work/automl/datasets" # path to the datasets directory - -datasets: - titanic: - dataset: 04_titanic - metric: f1 - target_col: Survived - user_requirement: "This is a 04_titanic dataset. Your goal is to predict the target\ - \ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. 
Do not plot\ - \ or make any visualizations.\n" - house-prices: - dataset: 05_house-prices-advanced-regression-techniques - metric: rmse - target_col: SalePrice - user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\ - \ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\ - \ data preprocessing, feature engineering, and modeling to predict the target.\ - \ \nReport rmse on the eval data. Do not plot or make any visualizations.\n" - santander-customer: - dataset: 06_santander-customer-transaction-prediction - metric: f1 - target_col: target - user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\ - \ Your goal is to predict the target column `target`.\nPerform data analysis,\ - \ data preprocessing, feature engineering, and modeling to predict the target.\ - \ \nReport f1 on the eval data. Do not plot or make any visualizations.\n" - icr: - dataset: 07_icr-identify-age-related-conditions - metric: f1 - target_col: Class - user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\ - \ goal is to predict the target column `Class`.\nPerform data analysis, data\ - \ preprocessing, feature engineering, and modeling to predict the target. \n\ - Report f1 on the eval data. Do not plot or make any visualizations.\n" - Click_prediction_small: - dataset: Click_prediction_small - metric: f1 - target_col: click - user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\ - \ the target column `click`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 on the eval data.\ - \ Do not plot or make any visualizations.\n" - GesturePhaseSegmentationProcessed: - dataset: GesturePhaseSegmentationProcessed - metric: f1 weighted - target_col: Phase - user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. 
Your goal\ - \ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\ - \ feature engineering, and modeling to predict the target. \nReport f1 weighted\ - \ on the eval data. Do not plot or make any visualizations.\n" - Moneyball: - dataset: Moneyball - metric: rmse - target_col: RS - user_requirement: "This is a Moneyball dataset. Your goal is to predict the target\ - \ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ - \ plot or make any visualizations.\n" - SAT11-HAND-runtime-regression: - dataset: SAT11-HAND-runtime-regression - metric: rmse - target_col: runtime - user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\ - \ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\ - \ feature engineering, and modeling to predict the target. \nReport rmse on\ - \ the eval data. Do not plot or make any visualizations.\n" - boston: - dataset: boston - metric: rmse - target_col: MEDV - user_requirement: "This is a boston dataset. Your goal is to predict the target\ - \ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ - \ plot or make any visualizations.\n" - colleges: - dataset: colleges - metric: rmse - target_col: percent_pell_grant - user_requirement: "This is a colleges dataset. Your goal is to predict the target\ - \ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport rmse on the eval\ - \ data. Do not plot or make any visualizations.\n" - credit-g: - dataset: credit-g - metric: f1 - target_col: class - user_requirement: "This is a credit-g dataset. 
Your goal is to predict the target\ - \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - diamonds: - dataset: diamonds - metric: rmse - target_col: price - user_requirement: "This is a diamonds dataset. Your goal is to predict the target\ - \ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ - \ plot or make any visualizations.\n" - jasmine: - dataset: jasmine - metric: f1 - target_col: class - user_requirement: "This is a jasmine dataset. Your goal is to predict the target\ - \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - kc1: - dataset: kc1 - metric: f1 - target_col: defects - user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\ - \ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - kick: - dataset: kick - metric: f1 - target_col: IsBadBuy - user_requirement: "This is a kick dataset. Your goal is to predict the target\ - \ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ - \ or make any visualizations.\n" - mfeat-factors: - dataset: mfeat-factors - metric: f1 weighted - target_col: class - user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\ - \ target column `class`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ - \ eval data. 
Do not plot or make any visualizations.\n" - segment: - dataset: segment - metric: f1 weighted - target_col: class - user_requirement: "This is a segment dataset. Your goal is to predict the target\ - \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ - \ and modeling to predict the target. \nReport f1 weighted on the eval data.\ - \ Do not plot or make any visualizations.\n" - steel-plates-fault: - dataset: steel-plates-fault - metric: f1 weighted - target_col: target - user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\ - \ the target column `target`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ - \ eval data. Do not plot or make any visualizations.\n" - wine-quality-white: - dataset: wine-quality-white - metric: f1 weighted - target_col: Class - user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\ - \ the target column `Class`.\nPerform data analysis, data preprocessing, feature\ - \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ - \ eval data. 
Do not plot or make any visualizations.\n" - - work_dir: ../workspace # path to the workspace directory -role_dir: storage/team/environment/roles/ResearchAssistant_David -# analysis_pool_dir: D:/work/MG-open/MetaGPT/examples/MCTS_test/analysis_pool_sample.json \ No newline at end of file +role_dir: storage/SELA # path to the role directory diff --git a/expo/research_assistant.py b/expo/research_assistant.py index 51de188d3..8ee7dc204 100644 --- a/expo/research_assistant.py +++ b/expo/research_assistant.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import os @@ -10,7 +11,7 @@ from metagpt.actions.di.write_analysis_code import WriteAnalysisCode from metagpt.const import SERDESER_PATH from metagpt.roles.di.data_interpreter import DataInterpreter from metagpt.schema import Message, Task, TaskResult -from metagpt.utils.common import CodeParser, write_json_file +from metagpt.utils.common import CodeParser, role_raise_decorator, write_json_file EXTRACT_SCORE_PROMPT = """ # Code: @@ -34,6 +35,27 @@ If you cannot find the scores, please still return a dictionary with the keys 't """ +class TimeoutException(Exception): + pass + + +def async_timeout(seconds): + def decorator(func): + async def wrapper(self, *args, **kwargs): + try: + result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=seconds) + except asyncio.TimeoutError: + text = f"Function timed out after {seconds} seconds" + mcts_logger.error(text) + self.save_state() + raise TimeoutException(text) + return result + + return wrapper + + return decorator + + class ResearchAssistant(DataInterpreter): node_id: str = "0" start_task_id: int = 1 @@ -117,6 +139,12 @@ class ResearchAssistant(DataInterpreter): return task_result def save_state(self, static_save=False): + """ + attribute: + state_saved - the state has been saved + input: + static_save - saving the state without changing the state_saved flag - used when a new role is created + """ if self.state_saved and not 
static_save: return if not static_save: @@ -135,18 +163,15 @@ class ResearchAssistant(DataInterpreter): self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys()) ] + @async_timeout(1000) + @role_raise_decorator async def run(self, with_message=None) -> Message | None: """Observe, and think and act based on the results of the observation""" if with_message == "continue": - # self.set_todo(None) - # working_memory = self.working_memory - # self.remap_tasks() mcts_logger.info("Continue to run") self.rc.working_memory.clear() self.working_memory.clear() - # self.rc.todo = WriteAnalysisCode() rsp = await self.react() - # 发送响应消息给 Environment 对象,以便它将消息传递给订阅者 self.set_todo(None) self.publish_message(rsp) return rsp diff --git a/expo/utils.py b/expo/utils.py index 56f3c21b9..b022879b0 100644 --- a/expo/utils.py +++ b/expo/utils.py @@ -1,6 +1,5 @@ import os import re -import sys from datetime import datetime from pathlib import Path @@ -21,22 +20,19 @@ def load_data_config(file_path="data.yaml"): DATASET_CONFIG = load_data_config("datasets.yaml") DATA_CONFIG = load_data_config() -DATA_CONFIG["datasets"].update(DATASET_CONFIG["datasets"]) +DATA_CONFIG["datasets"] = DATASET_CONFIG["datasets"] def get_mcts_logger(): - print_level = "INFO" - print_level2 = "MCTS" - logfile_level = "MCTS" + logfile_level = "DEBUG" name: str = None current_date = datetime.now() formatted_date = current_date.strftime("%Y%m%d") log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name - _logger.remove() - _logger.level(logfile_level, color="", no=25) - _logger.add(sys.stderr, level=print_level) - _logger.add(sys.stderr, level=print_level2) + # _logger.remove() + _logger.level("MCTS", color="", no=25) + # _logger.add(sys.stderr, level=print_level) _logger.add(Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt", level=logfile_level) _logger.propagate = False return _logger