1. add role level timeout 限制是1000s

2. 修改log的层级逻辑
3. data.yaml 只用于存路径
This commit is contained in:
Yizhou Chi 2024-10-10 16:30:07 +08:00
parent 573e9b6d9e
commit f80ebc4d67
5 changed files with 99 additions and 236 deletions

View file

@ -10,7 +10,7 @@ import pandas as pd
from expo.data.dataset import generate_task_requirement, get_split_dataset_path
from expo.evaluation.evaluation import evaluate_score
from expo.insights.instruction_generator import InstructionGenerator
from expo.research_assistant import ResearchAssistant
from expo.research_assistant import ResearchAssistant, TimeoutException
from expo.utils import get_exp_pool_path, load_execute_notebook, mcts_logger
from metagpt.tools.tool_recommend import ToolRecommender
from metagpt.utils.common import read_json_file
@ -211,10 +211,14 @@ class Node:
score_dict = self.evaluate_simulation(score_dict)
self.raw_reward = score_dict
run_finished = True
except TimeoutException as e:
mcts_logger.log("MCTS", f"Role-level timeout: {e}")
break
except Exception as e:
print(f"Error: {e}")
mcts_logger.log("MCTS", f"Error in running the role: {e}")
num_runs += 1
if not run_finished:
mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
if self.state["low_is_better"]:

View file

@ -1,21 +1,21 @@
# Expo
# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning
![pipeline](resources/MCTS-Experimenter.jpg)
## 1. Data Preparation
- 下载数据集https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
- 修改`data.yaml``datasets_dir`为数据集合集根目录存储位置
- Download Datasetshttps://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
## 2. Configs
### Data Config
`datasets.yaml` 提供数据集对应的指标和基础提示词
`datasets.yaml` Provide base prompts, metrics, target columns for respective datasets
`data.yaml` 继承了`datasets.yaml`以及一些路径信息,需要将`datasets_dir`指到数据集合集的根目录下
- Modify `datasets_dir` to the root directory of all the datasets in `data.yaml`
### LLM Config
@ -30,28 +30,64 @@ ### LLM Config
```
### Budget
实验轮次 k = 10, 20
Experiment rollouts k = 5, 10, 20
### Prompt Usage
- 通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词
- 非DI-based方法设置`is_di=False`
- `data_config``utils.DATA_CONFIG`
- 每一个数据集里有`dataset_info.json`里面的内容需要提供给baselines以保证公平`generate_task_requirement`已经默认提供)
- Use the function `generate_task_requirement` in `dataset.py` to get task requirement.
- If the method is non-DI-based, set `is_di=False`.
- Use `utils.DATA_CONFIG` as `data_config`
## 3. Evaluation
## 3. SELA
运行各个框架运行后框架需要提供Dev和Test的`dev_predictions.csv``test_predictions.csv`每个csv文件只需要单个名为target的列
### Run SELA
#### Setup
In the root directory,
- 使用`CustomExperimenter`
```
experimenter = CustomExperimenter(task="titanic")
score_dict = experimenter.evaluate_pred_files(dev_pred_path, test_pred_path)
pip install -e .
cd expo
pip install -r requirements.txt
```
## 4. Baselines
#### Run
- `python run_experiment.py --exp_mode mcts --task titanic --rollouts 10`
If the dataset has reg metric, remember to use `--low_is_better`:
- `python run_experiment.py --exp_mode mcts --task house_prices --rollouts 10 --low_is_better`
In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json`
- `--use_fixed_insights`
#### Ablation Study
**DI RandomSearch**
- Single insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
- Set insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
## 4. Evaluation
Each baseline needs to produce `dev_predictions.csv``test_predictions.csv`. Each csv file only needs a `target` column.
- Use the function `evaluate_score` to evaluate.
## 5. Baselines
### DS Agent
```
git clone https://github.com/guosyjlu/DS-Agent.git
@ -257,53 +293,12 @@ #### Run
```
### Base DI
For setup, check 5.
For setup, check 4.
- `python run_experiment.py --exp_mode base --task titanic --num_experiments 10`
- Ask DI to use AutoGluon: `--special_instruction ag`
- Ask DI to use the stacking ensemble method: `--special_instruction stacking`
## 5. DI MCTS
### Run DI MCTS
#### Setup
In the root directory,
```
pip install -e .
cd expo
pip install -r requirements.txt
```
#### Run
- `python run_experiment.py --exp_mode mcts --task titanic --rollout 10`
If the dataset has reg metric, remember to use `--low_is_better`:
- `python run_experiment.py --exp_mode mcts --task househouse_prices --rollout 10 --low_is_better`
In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json`
- `--use_fixed_insights`
#### Ablation Study
**DI RandomSearch**
- Single insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
- Set insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
- Specifically instruct DI to use AutoGluon: `--special_instruction ag`
- Specifically instruct DI to use the stacking ensemble method: `--special_instruction stacking`

View file

@ -1,160 +1,3 @@
datasets_dir: "D:/work/automl/datasets" # path to the datasets directory
datasets:
titanic:
dataset: 04_titanic
metric: f1
target_col: Survived
user_requirement: "This is a 04_titanic dataset. Your goal is to predict the target\
\ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
house-prices:
dataset: 05_house-prices-advanced-regression-techniques
metric: rmse
target_col: SalePrice
user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\
\ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\
\ data preprocessing, feature engineering, and modeling to predict the target.\
\ \nReport rmse on the eval data. Do not plot or make any visualizations.\n"
santander-customer:
dataset: 06_santander-customer-transaction-prediction
metric: f1
target_col: target
user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\
\ Your goal is to predict the target column `target`.\nPerform data analysis,\
\ data preprocessing, feature engineering, and modeling to predict the target.\
\ \nReport f1 on the eval data. Do not plot or make any visualizations.\n"
icr:
dataset: 07_icr-identify-age-related-conditions
metric: f1
target_col: Class
user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\
\ goal is to predict the target column `Class`.\nPerform data analysis, data\
\ preprocessing, feature engineering, and modeling to predict the target. \n\
Report f1 on the eval data. Do not plot or make any visualizations.\n"
Click_prediction_small:
dataset: Click_prediction_small
metric: f1
target_col: click
user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\
\ the target column `click`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 on the eval data.\
\ Do not plot or make any visualizations.\n"
GesturePhaseSegmentationProcessed:
dataset: GesturePhaseSegmentationProcessed
metric: f1 weighted
target_col: Phase
user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. Your goal\
\ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\
\ feature engineering, and modeling to predict the target. \nReport f1 weighted\
\ on the eval data. Do not plot or make any visualizations.\n"
Moneyball:
dataset: Moneyball
metric: rmse
target_col: RS
user_requirement: "This is a Moneyball dataset. Your goal is to predict the target\
\ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
\ plot or make any visualizations.\n"
SAT11-HAND-runtime-regression:
dataset: SAT11-HAND-runtime-regression
metric: rmse
target_col: runtime
user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\
\ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\
\ feature engineering, and modeling to predict the target. \nReport rmse on\
\ the eval data. Do not plot or make any visualizations.\n"
boston:
dataset: boston
metric: rmse
target_col: MEDV
user_requirement: "This is a boston dataset. Your goal is to predict the target\
\ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
\ plot or make any visualizations.\n"
colleges:
dataset: colleges
metric: rmse
target_col: percent_pell_grant
user_requirement: "This is a colleges dataset. Your goal is to predict the target\
\ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport rmse on the eval\
\ data. Do not plot or make any visualizations.\n"
credit-g:
dataset: credit-g
metric: f1
target_col: class
user_requirement: "This is a credit-g dataset. Your goal is to predict the target\
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
diamonds:
dataset: diamonds
metric: rmse
target_col: price
user_requirement: "This is a diamonds dataset. Your goal is to predict the target\
\ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
\ plot or make any visualizations.\n"
jasmine:
dataset: jasmine
metric: f1
target_col: class
user_requirement: "This is a jasmine dataset. Your goal is to predict the target\
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
kc1:
dataset: kc1
metric: f1
target_col: defects
user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\
\ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
kick:
dataset: kick
metric: f1
target_col: IsBadBuy
user_requirement: "This is a kick dataset. Your goal is to predict the target\
\ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
mfeat-factors:
dataset: mfeat-factors
metric: f1 weighted
target_col: class
user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\
\ target column `class`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
\ eval data. Do not plot or make any visualizations.\n"
segment:
dataset: segment
metric: f1 weighted
target_col: class
user_requirement: "This is a segment dataset. Your goal is to predict the target\
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 weighted on the eval data.\
\ Do not plot or make any visualizations.\n"
steel-plates-fault:
dataset: steel-plates-fault
metric: f1 weighted
target_col: target
user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\
\ the target column `target`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
\ eval data. Do not plot or make any visualizations.\n"
wine-quality-white:
dataset: wine-quality-white
metric: f1 weighted
target_col: Class
user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\
\ the target column `Class`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
\ eval data. Do not plot or make any visualizations.\n"
work_dir: ../workspace # path to the workspace directory
role_dir: storage/team/environment/roles/ResearchAssistant_David
# analysis_pool_dir: D:/work/MG-open/MetaGPT/examples/MCTS_test/analysis_pool_sample.json
role_dir: storage/SELA # path to the role directory

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import json
import os
@ -10,7 +11,7 @@ from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.const import SERDESER_PATH
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.schema import Message, Task, TaskResult
from metagpt.utils.common import CodeParser, write_json_file
from metagpt.utils.common import CodeParser, role_raise_decorator, write_json_file
EXTRACT_SCORE_PROMPT = """
# Code:
@ -34,6 +35,27 @@ If you cannot find the scores, please still return a dictionary with the keys 't
"""
class TimeoutException(Exception):
pass
def async_timeout(seconds):
def decorator(func):
async def wrapper(self, *args, **kwargs):
try:
result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=seconds)
except asyncio.TimeoutError:
text = f"Function timed out after {seconds} seconds"
mcts_logger.error(text)
self.save_state()
raise TimeoutException(text)
return result
return wrapper
return decorator
class ResearchAssistant(DataInterpreter):
node_id: str = "0"
start_task_id: int = 1
@ -117,6 +139,12 @@ class ResearchAssistant(DataInterpreter):
return task_result
def save_state(self, static_save=False):
"""
attribute:
state_saved - the state has been saved
input:
static_save - saving the state without changing the state_saved flag - used when a new role is created
"""
if self.state_saved and not static_save:
return
if not static_save:
@ -135,18 +163,15 @@ class ResearchAssistant(DataInterpreter):
self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys())
]
@async_timeout(1000)
@role_raise_decorator
async def run(self, with_message=None) -> Message | None:
"""Observe, and think and act based on the results of the observation"""
if with_message == "continue":
# self.set_todo(None)
# working_memory = self.working_memory
# self.remap_tasks()
mcts_logger.info("Continue to run")
self.rc.working_memory.clear()
self.working_memory.clear()
# self.rc.todo = WriteAnalysisCode()
rsp = await self.react()
# 发送响应消息给 Environment 对象,以便它将消息传递给订阅者
self.set_todo(None)
self.publish_message(rsp)
return rsp

View file

@ -1,6 +1,5 @@
import os
import re
import sys
from datetime import datetime
from pathlib import Path
@ -21,22 +20,19 @@ def load_data_config(file_path="data.yaml"):
DATASET_CONFIG = load_data_config("datasets.yaml")
DATA_CONFIG = load_data_config()
DATA_CONFIG["datasets"].update(DATASET_CONFIG["datasets"])
DATA_CONFIG["datasets"] = DATASET_CONFIG["datasets"]
def get_mcts_logger():
print_level = "INFO"
print_level2 = "MCTS"
logfile_level = "MCTS"
logfile_level = "DEBUG"
name: str = None
current_date = datetime.now()
formatted_date = current_date.strftime("%Y%m%d")
log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name
_logger.remove()
_logger.level(logfile_level, color="<green>", no=25)
_logger.add(sys.stderr, level=print_level)
_logger.add(sys.stderr, level=print_level2)
# _logger.remove()
_logger.level("MCTS", color="<green>", no=25)
# _logger.add(sys.stderr, level=print_level)
_logger.add(Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt", level=logfile_level)
_logger.propagate = False
return _logger