mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
Merge branch 'prepare-for-opensource' into 'expo'
开源准备 See merge request agents/exp_optimizer!24
This commit is contained in:
commit
7cb307aef5
16 changed files with 181 additions and 255 deletions
27
expo/MCTS.py
27
expo/MCTS.py
|
|
@ -10,7 +10,7 @@ import pandas as pd
|
|||
from expo.data.dataset import generate_task_requirement, get_split_dataset_path
|
||||
from expo.evaluation.evaluation import evaluate_score
|
||||
from expo.insights.instruction_generator import InstructionGenerator
|
||||
from expo.research_assistant import ResearchAssistant
|
||||
from expo.research_assistant import ResearchAssistant, TimeoutException
|
||||
from expo.utils import get_exp_pool_path, load_execute_notebook, mcts_logger
|
||||
from metagpt.tools.tool_recommend import ToolRecommender
|
||||
from metagpt.utils.common import read_json_file
|
||||
|
|
@ -26,7 +26,9 @@ def initialize_di_root_node(state, reflection: bool = True):
|
|||
return role, Node(parent=None, state=state, action=None, value=0)
|
||||
|
||||
|
||||
def create_initial_state(task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str):
|
||||
def create_initial_state(
|
||||
task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str, args
|
||||
):
|
||||
initial_state = {
|
||||
"task": task,
|
||||
"work_dir": data_config["work_dir"],
|
||||
|
|
@ -40,6 +42,7 @@ def create_initial_state(task, start_task_id, data_config, low_is_better: bool,
|
|||
"has_run": False,
|
||||
"start_task_id": start_task_id,
|
||||
"low_is_better": low_is_better,
|
||||
"role_timeout": args.role_timeout,
|
||||
}
|
||||
os.makedirs(initial_state["node_dir"], exist_ok=True)
|
||||
return initial_state
|
||||
|
|
@ -152,18 +155,15 @@ class Node:
|
|||
role = role.model_copy()
|
||||
role.save_state(static_save=True)
|
||||
|
||||
async def expand(self, max_children, use_fixed_insights):
|
||||
async def expand(self, max_children: int, instruction_generator: InstructionGenerator):
|
||||
if self.is_fully_expanded():
|
||||
return
|
||||
insight_geneartor = InstructionGenerator()
|
||||
role = self.load_role()
|
||||
original_instruction = role.get_next_instruction()
|
||||
insights = await insight_geneartor.generate_new_instructions(
|
||||
insights = await instruction_generator.generate_new_instructions(
|
||||
task_id=role.start_task_id + 1,
|
||||
original_instruction=original_instruction,
|
||||
max_num=max_children,
|
||||
file_path=self.state["exp_pool_path"],
|
||||
use_fixed_insights=use_fixed_insights,
|
||||
)
|
||||
new_state = self.state.copy()
|
||||
new_state["start_task_id"] += 1
|
||||
|
|
@ -211,10 +211,14 @@ class Node:
|
|||
score_dict = self.evaluate_simulation(score_dict)
|
||||
self.raw_reward = score_dict
|
||||
run_finished = True
|
||||
except TimeoutException as e:
|
||||
mcts_logger.log("MCTS", f"Role-level timeout: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
mcts_logger.log("MCTS", f"Error in running the role: {e}")
|
||||
num_runs += 1
|
||||
|
||||
if not run_finished:
|
||||
mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
|
||||
if self.state["low_is_better"]:
|
||||
|
|
@ -242,6 +246,8 @@ class MCTS:
|
|||
c_explore: float = 1.4
|
||||
c_unvisited: float = 0.8
|
||||
node_order: list = []
|
||||
# instruction generator shared by all node expansions; instantiated in `search()`
|
||||
instruction_generator: InstructionGenerator = None
|
||||
|
||||
def __init__(self, root_node, max_depth, use_fixed_insights):
|
||||
self.root_node = root_node
|
||||
|
|
@ -265,7 +271,7 @@ class MCTS:
|
|||
return max(all_children, key=uct)
|
||||
|
||||
async def expand(self, node: Node, max_children=5):
|
||||
await node.expand(max_children, self.use_fixed_insights)
|
||||
await node.expand(max_children, self.instruction_generator)
|
||||
if node not in self.children or not self.children[node]:
|
||||
self.children[node] = node.children
|
||||
return node.children
|
||||
|
|
@ -277,6 +283,7 @@ class MCTS:
|
|||
node = random.choice(node.children)
|
||||
reward = await node.run_node(role)
|
||||
mcts_logger.log("MCTS", f"Simulated node's reward: {reward}")
|
||||
|
||||
return reward
|
||||
|
||||
def backpropagate(self, node: Node, reward):
|
||||
|
|
@ -337,6 +344,10 @@ class MCTS:
|
|||
async def search(self, state, rollouts, load_tree=False, reflection=False):
|
||||
role, root = initialize_di_root_node(state, reflection=reflection)
|
||||
self.root_node = root
|
||||
self.instruction_generator = InstructionGenerator(
|
||||
file_path=state["exp_pool_path"], use_fixed_insights=self.use_fixed_insights
|
||||
)
|
||||
|
||||
tree_loaded = False
|
||||
if load_tree:
|
||||
tree_loaded = self.load_tree()
|
||||
|
|
|
|||
118
expo/README.md
118
expo/README.md
|
|
@ -1,21 +1,20 @@
|
|||
# Expo
|
||||
# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning
|
||||
|
||||
|
||||
|
||||
|
||||
## 1. Data Preparation
|
||||
|
||||
- 下载数据集:https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
|
||||
- 修改`data.yaml`的`datasets_dir`为数据集合集根目录存储位置
|
||||
- Download datasets: https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
|
||||
|
||||
|
||||
## 2. Configs
|
||||
|
||||
### Data Config
|
||||
|
||||
`datasets.yaml` 提供数据集对应的指标和基础提示词
|
||||
`datasets.yaml` provides the base prompts, metrics, and target columns for the respective datasets.
|
||||
|
||||
`data.yaml` 继承了`datasets.yaml`以及一些路径信息,需要将`datasets_dir`指到数据集合集的根目录下
|
||||
- In `data.yaml`, set `datasets_dir` to the root directory that contains all the datasets.
|
||||
|
||||
|
||||
### LLM Config
|
||||
|
|
@ -30,28 +29,64 @@ ### LLM Config
|
|||
```
|
||||
|
||||
### Budget
|
||||
实验轮次 k = 10, 20
|
||||
Experiment rollouts k = 5, 10, 20
|
||||
|
||||
|
||||
### Prompt Usage
|
||||
|
||||
- 通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词
|
||||
- 非DI-based方法设置`is_di=False`
|
||||
- `data_config`用`utils.DATA_CONFIG`
|
||||
- 每一个数据集里有`dataset_info.json`,里面的内容需要提供给baselines以保证公平(`generate_task_requirement`已经默认提供)
|
||||
- Use the function `generate_task_requirement` in `dataset.py` to get task requirement.
|
||||
- If the method is non-DI-based, set `is_di=False`.
|
||||
- Use `utils.DATA_CONFIG` as `data_config`
|
||||
|
||||
|
||||
## 3. Evaluation
|
||||
## 3. SELA
|
||||
|
||||
运行各个框架,运行后框架需要提供Dev和Test的`dev_predictions.csv`和`test_predictions.csv`,每个csv文件只需要单个名为target的列
|
||||
### Run SELA
|
||||
|
||||
#### Setup
|
||||
In the root directory,
|
||||
|
||||
- 使用`CustomExperimenter`
|
||||
```
|
||||
experimenter = CustomExperimenter(task="titanic")
|
||||
score_dict = experimenter.evaluate_pred_files(dev_pred_path, test_pred_path)
|
||||
pip install -e .
|
||||
|
||||
cd expo
|
||||
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 4. Baselines
|
||||
#### Run
|
||||
|
||||
- `python run_experiment.py --exp_mode mcts --task titanic --rollouts 10`
|
||||
|
||||
If the dataset uses a regression metric (lower is better, e.g. rmse), remember to pass `--low_is_better`:
|
||||
|
||||
- `python run_experiment.py --exp_mode mcts --task house-prices --rollouts 10 --low_is_better`
|
||||
|
||||
|
||||
In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json`
|
||||
- `--use_fixed_insights`
|
||||
|
||||
|
||||
|
||||
#### Ablation Study
|
||||
|
||||
**DI RandomSearch**
|
||||
|
||||
- Single insight
|
||||
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
|
||||
|
||||
- Set insight
|
||||
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
|
||||
|
||||
|
||||
## 4. Evaluation
|
||||
|
||||
Each baseline needs to produce `dev_predictions.csv` and `test_predictions.csv`. Each CSV file only needs a single `target` column.
|
||||
|
||||
- Use the function `evaluate_score` to evaluate.
|
||||
|
||||
|
||||
## 5. Baselines
|
||||
### DS Agent
|
||||
```
|
||||
git clone https://github.com/guosyjlu/DS-Agent.git
|
||||
|
|
@ -257,53 +292,12 @@ #### Run
|
|||
```
|
||||
|
||||
### Base DI
|
||||
For setup, check 5.
|
||||
|
||||
For setup, see the Setup instructions in Section 3 (SELA).
|
||||
- `python run_experiment.py --exp_mode base --task titanic --num_experiments 10`
|
||||
- Ask DI to use AutoGluon: `--special_instruction ag`
|
||||
- Ask DI to use the stacking ensemble method: `--special_instruction stacking`
|
||||
|
||||
|
||||
|
||||
|
||||
## 5. DI MCTS
|
||||
|
||||
### Run DI MCTS
|
||||
|
||||
#### Setup
|
||||
In the root directory,
|
||||
|
||||
```
|
||||
pip install -e .
|
||||
|
||||
cd expo
|
||||
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
#### Run
|
||||
|
||||
- `python run_experiment.py --exp_mode mcts --task titanic --rollout 10`
|
||||
|
||||
If the dataset has reg metric, remember to use `--low_is_better`:
|
||||
|
||||
- `python run_experiment.py --exp_mode mcts --task househouse_prices --rollout 10 --low_is_better`
|
||||
|
||||
|
||||
In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json`
|
||||
- `--use_fixed_insights`
|
||||
|
||||
|
||||
|
||||
#### Ablation Study
|
||||
|
||||
**DI RandomSearch**
|
||||
|
||||
- Single insight
|
||||
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
|
||||
|
||||
- Set insight
|
||||
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
|
||||
- Specifically instruct DI to use AutoGluon: `--special_instruction ag`
|
||||
- Specifically instruct DI to use the stacking ensemble method: `--special_instruction stacking`
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
159
expo/data.yaml
159
expo/data.yaml
|
|
@ -1,160 +1,3 @@
|
|||
datasets_dir: "D:/work/automl/datasets" # path to the datasets directory
|
||||
|
||||
datasets:
|
||||
titanic:
|
||||
dataset: 04_titanic
|
||||
metric: f1
|
||||
target_col: Survived
|
||||
user_requirement: "This is a 04_titanic dataset. Your goal is to predict the target\
|
||||
\ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
|
||||
\ or make any visualizations.\n"
|
||||
house-prices:
|
||||
dataset: 05_house-prices-advanced-regression-techniques
|
||||
metric: rmse
|
||||
target_col: SalePrice
|
||||
user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\
|
||||
\ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\
|
||||
\ data preprocessing, feature engineering, and modeling to predict the target.\
|
||||
\ \nReport rmse on the eval data. Do not plot or make any visualizations.\n"
|
||||
santander-customer:
|
||||
dataset: 06_santander-customer-transaction-prediction
|
||||
metric: f1
|
||||
target_col: target
|
||||
user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\
|
||||
\ Your goal is to predict the target column `target`.\nPerform data analysis,\
|
||||
\ data preprocessing, feature engineering, and modeling to predict the target.\
|
||||
\ \nReport f1 on the eval data. Do not plot or make any visualizations.\n"
|
||||
icr:
|
||||
dataset: 07_icr-identify-age-related-conditions
|
||||
metric: f1
|
||||
target_col: Class
|
||||
user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\
|
||||
\ goal is to predict the target column `Class`.\nPerform data analysis, data\
|
||||
\ preprocessing, feature engineering, and modeling to predict the target. \n\
|
||||
Report f1 on the eval data. Do not plot or make any visualizations.\n"
|
||||
Click_prediction_small:
|
||||
dataset: Click_prediction_small
|
||||
metric: f1
|
||||
target_col: click
|
||||
user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\
|
||||
\ the target column `click`.\nPerform data analysis, data preprocessing, feature\
|
||||
\ engineering, and modeling to predict the target. \nReport f1 on the eval data.\
|
||||
\ Do not plot or make any visualizations.\n"
|
||||
GesturePhaseSegmentationProcessed:
|
||||
dataset: GesturePhaseSegmentationProcessed
|
||||
metric: f1 weighted
|
||||
target_col: Phase
|
||||
user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. Your goal\
|
||||
\ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\
|
||||
\ feature engineering, and modeling to predict the target. \nReport f1 weighted\
|
||||
\ on the eval data. Do not plot or make any visualizations.\n"
|
||||
Moneyball:
|
||||
dataset: Moneyball
|
||||
metric: rmse
|
||||
target_col: RS
|
||||
user_requirement: "This is a Moneyball dataset. Your goal is to predict the target\
|
||||
\ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
|
||||
\ plot or make any visualizations.\n"
|
||||
SAT11-HAND-runtime-regression:
|
||||
dataset: SAT11-HAND-runtime-regression
|
||||
metric: rmse
|
||||
target_col: runtime
|
||||
user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\
|
||||
\ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\
|
||||
\ feature engineering, and modeling to predict the target. \nReport rmse on\
|
||||
\ the eval data. Do not plot or make any visualizations.\n"
|
||||
boston:
|
||||
dataset: boston
|
||||
metric: rmse
|
||||
target_col: MEDV
|
||||
user_requirement: "This is a boston dataset. Your goal is to predict the target\
|
||||
\ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
|
||||
\ plot or make any visualizations.\n"
|
||||
colleges:
|
||||
dataset: colleges
|
||||
metric: rmse
|
||||
target_col: percent_pell_grant
|
||||
user_requirement: "This is a colleges dataset. Your goal is to predict the target\
|
||||
\ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\
|
||||
\ engineering, and modeling to predict the target. \nReport rmse on the eval\
|
||||
\ data. Do not plot or make any visualizations.\n"
|
||||
credit-g:
|
||||
dataset: credit-g
|
||||
metric: f1
|
||||
target_col: class
|
||||
user_requirement: "This is a credit-g dataset. Your goal is to predict the target\
|
||||
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
|
||||
\ or make any visualizations.\n"
|
||||
diamonds:
|
||||
dataset: diamonds
|
||||
metric: rmse
|
||||
target_col: price
|
||||
user_requirement: "This is a diamonds dataset. Your goal is to predict the target\
|
||||
\ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
|
||||
\ plot or make any visualizations.\n"
|
||||
jasmine:
|
||||
dataset: jasmine
|
||||
metric: f1
|
||||
target_col: class
|
||||
user_requirement: "This is a jasmine dataset. Your goal is to predict the target\
|
||||
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
|
||||
\ or make any visualizations.\n"
|
||||
kc1:
|
||||
dataset: kc1
|
||||
metric: f1
|
||||
target_col: defects
|
||||
user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\
|
||||
\ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
|
||||
\ or make any visualizations.\n"
|
||||
kick:
|
||||
dataset: kick
|
||||
metric: f1
|
||||
target_col: IsBadBuy
|
||||
user_requirement: "This is a kick dataset. Your goal is to predict the target\
|
||||
\ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
|
||||
\ or make any visualizations.\n"
|
||||
mfeat-factors:
|
||||
dataset: mfeat-factors
|
||||
metric: f1 weighted
|
||||
target_col: class
|
||||
user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\
|
||||
\ target column `class`.\nPerform data analysis, data preprocessing, feature\
|
||||
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
|
||||
\ eval data. Do not plot or make any visualizations.\n"
|
||||
segment:
|
||||
dataset: segment
|
||||
metric: f1 weighted
|
||||
target_col: class
|
||||
user_requirement: "This is a segment dataset. Your goal is to predict the target\
|
||||
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
|
||||
\ and modeling to predict the target. \nReport f1 weighted on the eval data.\
|
||||
\ Do not plot or make any visualizations.\n"
|
||||
steel-plates-fault:
|
||||
dataset: steel-plates-fault
|
||||
metric: f1 weighted
|
||||
target_col: target
|
||||
user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\
|
||||
\ the target column `target`.\nPerform data analysis, data preprocessing, feature\
|
||||
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
|
||||
\ eval data. Do not plot or make any visualizations.\n"
|
||||
wine-quality-white:
|
||||
dataset: wine-quality-white
|
||||
metric: f1 weighted
|
||||
target_col: Class
|
||||
user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\
|
||||
\ the target column `Class`.\nPerform data analysis, data preprocessing, feature\
|
||||
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
|
||||
\ eval data. Do not plot or make any visualizations.\n"
|
||||
|
||||
|
||||
work_dir: ../workspace # path to the workspace directory
|
||||
role_dir: storage/team/environment/roles/ResearchAssistant_David
|
||||
# analysis_pool_dir: D:/work/MG-open/MetaGPT/examples/MCTS_test/analysis_pool_sample.json
|
||||
role_dir: storage/SELA # path to the role directory
|
||||
|
|
@ -9,6 +9,7 @@ import yaml
|
|||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from expo.insights.solution_designer import SolutionDesigner
|
||||
from expo.utils import DATA_CONFIG
|
||||
|
||||
BASE_USER_REQUIREMENT = """
|
||||
This is a {datasetname} dataset. Your goal is to predict the target column `{target_col}`.
|
||||
|
|
@ -361,7 +362,7 @@ async def process_dataset(dataset, solution_designer: SolutionDesigner, save_ana
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
datasets_dir = "D:/work/automl/datasets"
|
||||
datasets_dir = DATA_CONFIG["datasets_dir"]
|
||||
force_update = False
|
||||
save_analysis_pool = True
|
||||
datasets_dict = {"datasets": {}}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from PIL import Image
|
|||
|
||||
from expo.data.dataset import ExpDataset, process_dataset, save_datasets_dict_to_yaml
|
||||
from expo.insights.solution_designer import SolutionDesigner
|
||||
from expo.utils import DATA_CONFIG
|
||||
|
||||
HFDATSETS = [
|
||||
{"name": "sms_spam", "dataset_name": "ucirvine/sms_spam", "target_col": "label", "modality": "text"},
|
||||
|
|
@ -114,7 +115,7 @@ class HFExpDataset(ExpDataset):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset_dir = "D:/work/automl/datasets"
|
||||
dataset_dir = DATA_CONFIG["datasets_dir"]
|
||||
save_analysis_pool = True
|
||||
force_update = False
|
||||
datasets_dict = {"datasets": {}}
|
||||
|
|
|
|||
|
|
@ -34,7 +34,9 @@ class AugExperimenter(Experimenter):
|
|||
|
||||
results = []
|
||||
for i in range(self.args.num_experiments):
|
||||
di = ResearchAssistant(node_id=str(i), use_reflection=self.args.reflection)
|
||||
di = ResearchAssistant(
|
||||
node_id=str(i), use_reflection=self.args.reflection, role_timeout=self.args.role_timeout
|
||||
)
|
||||
di.role_dir = f"{di.role_dir}_{self.args.task}"
|
||||
requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
|
||||
print(requirement)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class Experimenter:
|
|||
low_is_better=self.args.low_is_better,
|
||||
name=self.args.name,
|
||||
special_instruction=self.args.special_instruction,
|
||||
args=self.args,
|
||||
)
|
||||
|
||||
async def run_di(self, di, user_requirement, run_idx):
|
||||
|
|
@ -82,7 +83,9 @@ class Experimenter:
|
|||
results = []
|
||||
|
||||
for i in range(self.args.num_experiments):
|
||||
di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
|
||||
di = ResearchAssistant(
|
||||
node_id="0", use_reflection=self.args.reflection, role_timeout=self.args.role_timeout
|
||||
)
|
||||
score_dict = await self.run_di(di, user_requirement, run_idx=i)
|
||||
results.append(
|
||||
{"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}
|
||||
|
|
|
|||
|
|
@ -8,9 +8,12 @@ from expo.MCTS import MCTS
|
|||
|
||||
class MCTSExperimenter(Experimenter):
|
||||
result_path: str = "results/mcts"
|
||||
start_task_id = 2
|
||||
|
||||
def __init__(self, args, tree_mode=None, **kwargs):
|
||||
if args.special_instruction == "image":
|
||||
self.start_task_id = 1 # start from datapreprocessing if it is image task
|
||||
else:
|
||||
self.start_task_id = args.start_task_id
|
||||
super().__init__(args, **kwargs)
|
||||
self.tree_mode = tree_mode
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import os
|
||||
import random
|
||||
|
||||
from expo.insights.solution_designer import SolutionDesigner
|
||||
from expo.utils import clean_json_from_rsp, load_data_config, mcts_logger
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -32,6 +33,12 @@ DATA_CONFIG = load_data_config()
|
|||
class InstructionGenerator:
|
||||
data_config = DATA_CONFIG
|
||||
|
||||
def __init__(self, file_path, use_fixed_insights=False):
|
||||
self.file_path = file_path
|
||||
self.use_fixed_insights = use_fixed_insights
|
||||
self.analysis_pool = self.load_insight_pool(file_path, use_fixed_insights)
|
||||
self.proposer = SolutionDesigner()
|
||||
|
||||
@staticmethod
|
||||
def load_json_data(json_dir):
|
||||
with open(json_dir, "r") as file:
|
||||
|
|
@ -69,7 +76,7 @@ class InstructionGenerator:
|
|||
return new_data
|
||||
|
||||
@staticmethod
|
||||
def load_analysis_pool(file_path, use_fixed_insights, task_id=None):
|
||||
def load_insight_pool(file_path, use_fixed_insights, task_id=None):
|
||||
data = InstructionGenerator.load_json_data(file_path)
|
||||
if use_fixed_insights:
|
||||
current_directory = os.path.dirname(__file__)
|
||||
|
|
@ -83,13 +90,8 @@ class InstructionGenerator:
|
|||
data = [item for item in data if int(item["task_id"]) == int(task_id)]
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def generate_new_instructions(
|
||||
task_id, original_instruction, max_num, file_path, ext_info=None, use_fixed_insights=False
|
||||
):
|
||||
data = InstructionGenerator.load_analysis_pool(
|
||||
file_path, task_id=task_id, use_fixed_insights=use_fixed_insights
|
||||
)
|
||||
async def generate_new_instructions(self, task_id, original_instruction, max_num, ext_info=None):
|
||||
data = self.analysis_pool
|
||||
new_instructions = []
|
||||
if len(data) == 0:
|
||||
mcts_logger.log("MCTS", f"No insights available for task {task_id}")
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ The insights should be proposed based on the dataset description with different
|
|||
Each task type should have at least 5 insights.
|
||||
Make sure each method is diverse enough and can be implemented separately.
|
||||
Be specific about models' choices, ensemble and tuning techniques, and preprocessing & feature engineering techniques.
|
||||
Your model choices should be advanced enough to be helpful.
|
||||
|
||||
# Format
|
||||
```json
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
|
||||
|
|
@ -34,11 +35,33 @@ If you cannot find the scores, please still return a dictionary with the keys 't
|
|||
"""
|
||||
|
||||
|
||||
class TimeoutException(Exception):
    """Raised when a decorated role coroutine exceeds its ``role_timeout``."""


def async_timeout():
    """Build a decorator that bounds an async method by ``self.role_timeout``.

    The decorated coroutine is awaited through ``asyncio.wait_for`` with the
    ``role_timeout`` attribute (in seconds) of the instance it is called on;
    the attribute is read at call time, so each instance can carry its own
    limit. When the limit is hit, the timeout is logged through
    ``mcts_logger.error``, ``self.save_state()`` is called, and a
    ``TimeoutException`` is raised in place of ``asyncio.TimeoutError``.

    Returns:
        A decorator for ``async def`` methods whose instances provide a
        ``role_timeout`` attribute and a ``save_state()`` method.

    Raises:
        TimeoutException: from the wrapped call, when it does not finish
            within ``self.role_timeout`` seconds.
    """
    # Local import so this block does not depend on the file's import header.
    import functools

    def decorator(func):
        # functools.wraps keeps the wrapped method's __name__/__doc__ so that
        # logs, tracebacks and introspection still refer to the real method.
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            try:
                result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=self.role_timeout)
            except asyncio.TimeoutError as exc:
                text = f"Function timed out after {self.role_timeout} seconds"
                mcts_logger.error(text)
                # Save the role's state before surfacing the timeout to the caller.
                self.save_state()
                # Chain explicitly so the original timeout stays attached as the cause.
                raise TimeoutException(text) from exc
            return result

        return wrapper

    return decorator
|
||||
|
||||
|
||||
class ResearchAssistant(DataInterpreter):
|
||||
node_id: str = "0"
|
||||
start_task_id: int = 1
|
||||
state_saved: bool = False
|
||||
role_dir: str = SERDESER_PATH.joinpath("team", "environment", "roles", "Experimenter")
|
||||
role_timeout: int = 1000
|
||||
|
||||
def get_node_name(self):
|
||||
return f"Node-{self.node_id}"
|
||||
|
|
@ -117,6 +140,12 @@ class ResearchAssistant(DataInterpreter):
|
|||
return task_result
|
||||
|
||||
def save_state(self, static_save=False):
|
||||
"""
|
||||
attribute:
|
||||
state_saved - the state has been saved
|
||||
input:
|
||||
static_save - saving the state without changing the state_saved flag - used when a new role is created
|
||||
"""
|
||||
if self.state_saved and not static_save:
|
||||
return
|
||||
if not static_save:
|
||||
|
|
@ -135,18 +164,14 @@ class ResearchAssistant(DataInterpreter):
|
|||
self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys())
|
||||
]
|
||||
|
||||
@async_timeout()
|
||||
async def run(self, with_message=None) -> Message | None:
|
||||
"""Observe, and think and act based on the results of the observation"""
|
||||
if with_message == "continue":
|
||||
# self.set_todo(None)
|
||||
# working_memory = self.working_memory
|
||||
# self.remap_tasks()
|
||||
mcts_logger.info("Continue to run")
|
||||
self.rc.working_memory.clear()
|
||||
self.working_memory.clear()
|
||||
# self.rc.todo = WriteAnalysisCode()
|
||||
rsp = await self.react()
|
||||
# 发送响应消息给 Environment 对象,以便它将消息传递给订阅者
|
||||
self.set_todo(None)
|
||||
self.publish_message(rsp)
|
||||
return rsp
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ def get_args():
|
|||
default="mcts",
|
||||
choices=["mcts", "aug", "base", "custom", "greedy", "autogluon", "random", "autosklearn"],
|
||||
)
|
||||
parser.add_argument("--role_timeout", type=int, default=1000)
|
||||
get_di_args(parser)
|
||||
get_mcts_args(parser)
|
||||
get_aug_exp_args(parser)
|
||||
|
|
@ -30,6 +31,7 @@ def get_mcts_args(parser):
|
|||
parser.set_defaults(load_tree=False)
|
||||
parser.add_argument("--rollouts", type=int, default=5)
|
||||
parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true")
|
||||
parser.add_argument("--start_task_id", type=int, default=2)
|
||||
|
||||
|
||||
def get_aug_exp_args(parser):
|
||||
|
|
|
|||
15
expo/scripts/run_cls.sh
Normal file
15
expo/scripts/run_cls.sh
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
# Runs the MCTS experiment (10 rollouts, "stacking" special instruction) for
# every task in the list below, repeating the whole list three times.
# run_experiment.py is invoked by relative path, so launch this script from
# the directory that contains it.

tasks=("smoker-status" "software-defects" "jasmine" "credit-g" "Click_prediction_small" "kick" "kc1" "titanic" "icr" "wine-quality-white" "mfeat-factors" "segment" "GesturePhaseSegmentationProcessed")

for i in {1..3}
do
    for task in "${tasks[@]}"; do
        echo "Running experiment for task: $task"
        # Report the real outcome instead of unconditionally claiming success;
        # a failed task is reported on stderr and the remaining tasks still run.
        if python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --special_instruction stacking; then
            echo "Experiment for task $task completed."
        else
            echo "Experiment for task $task failed (exit code $?)." >&2
        fi
    done
done

echo "All experiments completed."
|
||||
13
expo/scripts/run_cls_mod.sh
Normal file
13
expo/scripts/run_cls_mod.sh
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
#!/bin/bash
# Runs the MCTS experiment (10 rollouts, no special instruction) for every
# task in the list below, repeating the whole list three times.
# run_experiment.py is invoked by relative path, so launch this script from
# the directory that contains it.

tasks=("banking77" "gnad10" "sms_spam" "oxford-iiit-pet" "stanford_cars" "fashion_mnist")

for i in {1..3}
do
    for task in "${tasks[@]}"; do
        echo "Running experiment for task: $task"
        # Report the real outcome instead of unconditionally claiming success;
        # a failed task is reported on stderr and the remaining tasks still run.
        if python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10; then
            echo "Experiment for task $task completed."
        else
            echo "Experiment for task $task failed (exit code $?)." >&2
        fi
    done
done
echo "All experiments completed."
|
||||
14
expo/scripts/run_reg.sh
Normal file
14
expo/scripts/run_reg.sh
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
# Runs the MCTS experiment (10 rollouts, "stacking" special instruction) for
# every task in the list below, repeating the whole list three times.
# Every run passes --low_is_better.
# run_experiment.py is invoked by relative path, so launch this script from
# the directory that contains it.

tasks=("concrete-strength" "Moneyball" "colleges" "SAT11-HAND-runtime-regression" "diamonds" "boston" "house-prices")

for i in {1..3}
do
    for task in "${tasks[@]}"; do
        echo "Running experiment for task: $task"
        # Report the real outcome instead of unconditionally claiming success;
        # a failed task is reported on stderr and the remaining tasks still run.
        if python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --low_is_better --special_instruction stacking; then
            echo "Experiment for task $task completed."
        else
            echo "Experiment for task $task failed (exit code $?)." >&2
        fi
    done
done

echo "All experiments completed."
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -21,22 +20,19 @@ def load_data_config(file_path="data.yaml"):
|
|||
|
||||
DATASET_CONFIG = load_data_config("datasets.yaml")
|
||||
DATA_CONFIG = load_data_config()
|
||||
DATA_CONFIG["datasets"].update(DATASET_CONFIG["datasets"])
|
||||
DATA_CONFIG["datasets"] = DATASET_CONFIG["datasets"]
|
||||
|
||||
|
||||
def get_mcts_logger():
|
||||
print_level = "INFO"
|
||||
print_level2 = "MCTS"
|
||||
logfile_level = "MCTS"
|
||||
logfile_level = "DEBUG"
|
||||
name: str = None
|
||||
current_date = datetime.now()
|
||||
formatted_date = current_date.strftime("%Y%m%d")
|
||||
log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name
|
||||
|
||||
_logger.remove()
|
||||
_logger.level(logfile_level, color="<green>", no=25)
|
||||
_logger.add(sys.stderr, level=print_level)
|
||||
_logger.add(sys.stderr, level=print_level2)
|
||||
# _logger.remove()
|
||||
_logger.level("MCTS", color="<green>", no=25)
|
||||
# _logger.add(sys.stderr, level=print_level)
|
||||
_logger.add(Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt", level=logfile_level)
|
||||
_logger.propagate = False
|
||||
return _logger
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue