Merge branch 'prepare-for-opensource' into 'expo'

开源准备

See merge request agents/exp_optimizer!24
This commit is contained in:
林义章 2024-10-15 03:53:30 +00:00
commit 7cb307aef5
16 changed files with 181 additions and 255 deletions

View file

@ -10,7 +10,7 @@ import pandas as pd
from expo.data.dataset import generate_task_requirement, get_split_dataset_path
from expo.evaluation.evaluation import evaluate_score
from expo.insights.instruction_generator import InstructionGenerator
from expo.research_assistant import ResearchAssistant
from expo.research_assistant import ResearchAssistant, TimeoutException
from expo.utils import get_exp_pool_path, load_execute_notebook, mcts_logger
from metagpt.tools.tool_recommend import ToolRecommender
from metagpt.utils.common import read_json_file
@ -26,7 +26,9 @@ def initialize_di_root_node(state, reflection: bool = True):
return role, Node(parent=None, state=state, action=None, value=0)
def create_initial_state(task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str):
def create_initial_state(
task, start_task_id, data_config, low_is_better: bool, name: str, special_instruction: str, args
):
initial_state = {
"task": task,
"work_dir": data_config["work_dir"],
@ -40,6 +42,7 @@ def create_initial_state(task, start_task_id, data_config, low_is_better: bool,
"has_run": False,
"start_task_id": start_task_id,
"low_is_better": low_is_better,
"role_timeout": args.role_timeout,
}
os.makedirs(initial_state["node_dir"], exist_ok=True)
return initial_state
@ -152,18 +155,15 @@ class Node:
role = role.model_copy()
role.save_state(static_save=True)
async def expand(self, max_children, use_fixed_insights):
async def expand(self, max_children: int, instruction_generator: InstructionGenerator):
if self.is_fully_expanded():
return
insight_geneartor = InstructionGenerator()
role = self.load_role()
original_instruction = role.get_next_instruction()
insights = await insight_geneartor.generate_new_instructions(
insights = await instruction_generator.generate_new_instructions(
task_id=role.start_task_id + 1,
original_instruction=original_instruction,
max_num=max_children,
file_path=self.state["exp_pool_path"],
use_fixed_insights=use_fixed_insights,
)
new_state = self.state.copy()
new_state["start_task_id"] += 1
@ -211,10 +211,14 @@ class Node:
score_dict = self.evaluate_simulation(score_dict)
self.raw_reward = score_dict
run_finished = True
except TimeoutException as e:
mcts_logger.log("MCTS", f"Role-level timeout: {e}")
break
except Exception as e:
print(f"Error: {e}")
mcts_logger.log("MCTS", f"Error in running the role: {e}")
num_runs += 1
if not run_finished:
mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
if self.state["low_is_better"]:
@ -242,6 +246,8 @@ class MCTS:
c_explore: float = 1.4
c_unvisited: float = 0.8
node_order: list = []
# insight generator
instruction_generator: InstructionGenerator = None
def __init__(self, root_node, max_depth, use_fixed_insights):
self.root_node = root_node
@ -265,7 +271,7 @@ class MCTS:
return max(all_children, key=uct)
async def expand(self, node: Node, max_children=5):
await node.expand(max_children, self.use_fixed_insights)
await node.expand(max_children, self.instruction_generator)
if node not in self.children or not self.children[node]:
self.children[node] = node.children
return node.children
@ -277,6 +283,7 @@ class MCTS:
node = random.choice(node.children)
reward = await node.run_node(role)
mcts_logger.log("MCTS", f"Simulated node's reward: {reward}")
return reward
def backpropagate(self, node: Node, reward):
@ -337,6 +344,10 @@ class MCTS:
async def search(self, state, rollouts, load_tree=False, reflection=False):
role, root = initialize_di_root_node(state, reflection=reflection)
self.root_node = root
self.instruction_generator = InstructionGenerator(
file_path=state["exp_pool_path"], use_fixed_insights=self.use_fixed_insights
)
tree_loaded = False
if load_tree:
tree_loaded = self.load_tree()

View file

@ -1,21 +1,20 @@
# Expo
# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning
## 1. Data Preparation
- 下载数据集https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
- 修改`data.yaml``datasets_dir`为数据集合集根目录存储位置
- Download Datasetshttps://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink
## 2. Configs
### Data Config
`datasets.yaml` 提供数据集对应的指标和基础提示词
`datasets.yaml` Provide base prompts, metrics, target columns for respective datasets
`data.yaml` 继承了`datasets.yaml`以及一些路径信息,需要将`datasets_dir`指到数据集合集的根目录下
- Modify `datasets_dir` to the root directory of all the datasets in `data.yaml`
### LLM Config
@ -30,28 +29,64 @@ ### LLM Config
```
### Budget
实验轮次 k = 10, 20
Experiment rollouts k = 5, 10, 20
### Prompt Usage
- 通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词
- 非DI-based方法设置`is_di=False`
- `data_config``utils.DATA_CONFIG`
- 每一个数据集里有`dataset_info.json`里面的内容需要提供给baselines以保证公平`generate_task_requirement`已经默认提供)
- Use the function `generate_task_requirement` in `dataset.py` to get task requirement.
- If the method is non-DI-based, set `is_di=False`.
- Use `utils.DATA_CONFIG` as `data_config`
## 3. Evaluation
## 3. SELA
运行各个框架运行后框架需要提供Dev和Test的`dev_predictions.csv``test_predictions.csv`每个csv文件只需要单个名为target的列
### Run SELA
#### Setup
In the root directory,
- 使用`CustomExperimenter`
```
experimenter = CustomExperimenter(task="titanic")
score_dict = experimenter.evaluate_pred_files(dev_pred_path, test_pred_path)
pip install -e .
cd expo
pip install -r requirements.txt
```
## 4. Baselines
#### Run
- `python run_experiment.py --exp_mode mcts --task titanic --rollouts 10`
If the dataset has reg metric, remember to use `--low_is_better`:
- `python run_experiment.py --exp_mode mcts --task house_prices --rollouts 10 --low_is_better`
In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json`
- `--use_fixed_insights`
#### Ablation Study
**DI RandomSearch**
- Single insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
- Set insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
## 4. Evaluation
Each baseline needs to produce `dev_predictions.csv``test_predictions.csv`. Each csv file only needs a `target` column.
- Use the function `evaluate_score` to evaluate.
## 5. Baselines
### DS Agent
```
git clone https://github.com/guosyjlu/DS-Agent.git
@ -257,53 +292,12 @@ #### Run
```
### Base DI
For setup, check 5.
For setup, check 4.
- `python run_experiment.py --exp_mode base --task titanic --num_experiments 10`
- Ask DI to use AutoGluon: `--special_instruction ag`
- Ask DI to use the stacking ensemble method: `--special_instruction stacking`
## 5. DI MCTS
### Run DI MCTS
#### Setup
In the root directory,
```
pip install -e .
cd expo
pip install -r requirements.txt
```
#### Run
- `python run_experiment.py --exp_mode mcts --task titanic --rollout 10`
If the dataset has reg metric, remember to use `--low_is_better`:
- `python run_experiment.py --exp_mode mcts --task househouse_prices --rollout 10 --low_is_better`
In addition to the generated insights, include the fixed insights saved in `expo/insights/fixed_insights.json`
- `--use_fixed_insights`
#### Ablation Study
**DI RandomSearch**
- Single insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode single`
- Set insight
`python run_experiment.py --exp_mode aug --task titanic --aug_mode set`
- Specifically instruct DI to use AutoGluon: `--special_instruction ag`
- Specifically instruct DI to use the stacking ensemble method: `--special_instruction stacking`

View file

@ -1,160 +1,3 @@
datasets_dir: "D:/work/automl/datasets" # path to the datasets directory
datasets:
titanic:
dataset: 04_titanic
metric: f1
target_col: Survived
user_requirement: "This is a 04_titanic dataset. Your goal is to predict the target\
\ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
house-prices:
dataset: 05_house-prices-advanced-regression-techniques
metric: rmse
target_col: SalePrice
user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\
\ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\
\ data preprocessing, feature engineering, and modeling to predict the target.\
\ \nReport rmse on the eval data. Do not plot or make any visualizations.\n"
santander-customer:
dataset: 06_santander-customer-transaction-prediction
metric: f1
target_col: target
user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\
\ Your goal is to predict the target column `target`.\nPerform data analysis,\
\ data preprocessing, feature engineering, and modeling to predict the target.\
\ \nReport f1 on the eval data. Do not plot or make any visualizations.\n"
icr:
dataset: 07_icr-identify-age-related-conditions
metric: f1
target_col: Class
user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\
\ goal is to predict the target column `Class`.\nPerform data analysis, data\
\ preprocessing, feature engineering, and modeling to predict the target. \n\
Report f1 on the eval data. Do not plot or make any visualizations.\n"
Click_prediction_small:
dataset: Click_prediction_small
metric: f1
target_col: click
user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\
\ the target column `click`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 on the eval data.\
\ Do not plot or make any visualizations.\n"
GesturePhaseSegmentationProcessed:
dataset: GesturePhaseSegmentationProcessed
metric: f1 weighted
target_col: Phase
user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. Your goal\
\ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\
\ feature engineering, and modeling to predict the target. \nReport f1 weighted\
\ on the eval data. Do not plot or make any visualizations.\n"
Moneyball:
dataset: Moneyball
metric: rmse
target_col: RS
user_requirement: "This is a Moneyball dataset. Your goal is to predict the target\
\ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
\ plot or make any visualizations.\n"
SAT11-HAND-runtime-regression:
dataset: SAT11-HAND-runtime-regression
metric: rmse
target_col: runtime
user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\
\ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\
\ feature engineering, and modeling to predict the target. \nReport rmse on\
\ the eval data. Do not plot or make any visualizations.\n"
boston:
dataset: boston
metric: rmse
target_col: MEDV
user_requirement: "This is a boston dataset. Your goal is to predict the target\
\ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
\ plot or make any visualizations.\n"
colleges:
dataset: colleges
metric: rmse
target_col: percent_pell_grant
user_requirement: "This is a colleges dataset. Your goal is to predict the target\
\ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport rmse on the eval\
\ data. Do not plot or make any visualizations.\n"
credit-g:
dataset: credit-g
metric: f1
target_col: class
user_requirement: "This is a credit-g dataset. Your goal is to predict the target\
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
diamonds:
dataset: diamonds
metric: rmse
target_col: price
user_requirement: "This is a diamonds dataset. Your goal is to predict the target\
\ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport rmse on the eval data. Do not\
\ plot or make any visualizations.\n"
jasmine:
dataset: jasmine
metric: f1
target_col: class
user_requirement: "This is a jasmine dataset. Your goal is to predict the target\
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
kc1:
dataset: kc1
metric: f1
target_col: defects
user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\
\ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
kick:
dataset: kick
metric: f1
target_col: IsBadBuy
user_requirement: "This is a kick dataset. Your goal is to predict the target\
\ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\
\ or make any visualizations.\n"
mfeat-factors:
dataset: mfeat-factors
metric: f1 weighted
target_col: class
user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\
\ target column `class`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
\ eval data. Do not plot or make any visualizations.\n"
segment:
dataset: segment
metric: f1 weighted
target_col: class
user_requirement: "This is a segment dataset. Your goal is to predict the target\
\ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\
\ and modeling to predict the target. \nReport f1 weighted on the eval data.\
\ Do not plot or make any visualizations.\n"
steel-plates-fault:
dataset: steel-plates-fault
metric: f1 weighted
target_col: target
user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\
\ the target column `target`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
\ eval data. Do not plot or make any visualizations.\n"
wine-quality-white:
dataset: wine-quality-white
metric: f1 weighted
target_col: Class
user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\
\ the target column `Class`.\nPerform data analysis, data preprocessing, feature\
\ engineering, and modeling to predict the target. \nReport f1 weighted on the\
\ eval data. Do not plot or make any visualizations.\n"
work_dir: ../workspace # path to the workspace directory
role_dir: storage/team/environment/roles/ResearchAssistant_David
# analysis_pool_dir: D:/work/MG-open/MetaGPT/examples/MCTS_test/analysis_pool_sample.json
role_dir: storage/SELA # path to the role directory

View file

@ -9,6 +9,7 @@ import yaml
from sklearn.model_selection import train_test_split
from expo.insights.solution_designer import SolutionDesigner
from expo.utils import DATA_CONFIG
BASE_USER_REQUIREMENT = """
This is a {datasetname} dataset. Your goal is to predict the target column `{target_col}`.
@ -361,7 +362,7 @@ async def process_dataset(dataset, solution_designer: SolutionDesigner, save_ana
if __name__ == "__main__":
datasets_dir = "D:/work/automl/datasets"
datasets_dir = DATA_CONFIG["datasets_dir"]
force_update = False
save_analysis_pool = True
datasets_dict = {"datasets": {}}

View file

@ -9,6 +9,7 @@ from PIL import Image
from expo.data.dataset import ExpDataset, process_dataset, save_datasets_dict_to_yaml
from expo.insights.solution_designer import SolutionDesigner
from expo.utils import DATA_CONFIG
HFDATSETS = [
{"name": "sms_spam", "dataset_name": "ucirvine/sms_spam", "target_col": "label", "modality": "text"},
@ -114,7 +115,7 @@ class HFExpDataset(ExpDataset):
if __name__ == "__main__":
dataset_dir = "D:/work/automl/datasets"
dataset_dir = DATA_CONFIG["datasets_dir"]
save_analysis_pool = True
force_update = False
datasets_dict = {"datasets": {}}

View file

@ -34,7 +34,9 @@ class AugExperimenter(Experimenter):
results = []
for i in range(self.args.num_experiments):
di = ResearchAssistant(node_id=str(i), use_reflection=self.args.reflection)
di = ResearchAssistant(
node_id=str(i), use_reflection=self.args.reflection, role_timeout=self.args.role_timeout
)
di.role_dir = f"{di.role_dir}_{self.args.task}"
requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
print(requirement)

View file

@ -27,6 +27,7 @@ class Experimenter:
low_is_better=self.args.low_is_better,
name=self.args.name,
special_instruction=self.args.special_instruction,
args=self.args,
)
async def run_di(self, di, user_requirement, run_idx):
@ -82,7 +83,9 @@ class Experimenter:
results = []
for i in range(self.args.num_experiments):
di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
di = ResearchAssistant(
node_id="0", use_reflection=self.args.reflection, role_timeout=self.args.role_timeout
)
score_dict = await self.run_di(di, user_requirement, run_idx=i)
results.append(
{"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}

View file

@ -8,9 +8,12 @@ from expo.MCTS import MCTS
class MCTSExperimenter(Experimenter):
result_path: str = "results/mcts"
start_task_id = 2
def __init__(self, args, tree_mode=None, **kwargs):
if args.special_instruction == "image":
self.start_task_id = 1 # start from datapreprocessing if it is image task
else:
self.start_task_id = args.start_task_id
super().__init__(args, **kwargs)
self.tree_mode = tree_mode

View file

@ -2,6 +2,7 @@ import json
import os
import random
from expo.insights.solution_designer import SolutionDesigner
from expo.utils import clean_json_from_rsp, load_data_config, mcts_logger
from metagpt.llm import LLM
from metagpt.schema import Message
@ -32,6 +33,12 @@ DATA_CONFIG = load_data_config()
class InstructionGenerator:
data_config = DATA_CONFIG
def __init__(self, file_path, use_fixed_insights=False):
self.file_path = file_path
self.use_fixed_insights = use_fixed_insights
self.analysis_pool = self.load_insight_pool(file_path, use_fixed_insights)
self.proposer = SolutionDesigner()
@staticmethod
def load_json_data(json_dir):
with open(json_dir, "r") as file:
@ -69,7 +76,7 @@ class InstructionGenerator:
return new_data
@staticmethod
def load_analysis_pool(file_path, use_fixed_insights, task_id=None):
def load_insight_pool(file_path, use_fixed_insights, task_id=None):
data = InstructionGenerator.load_json_data(file_path)
if use_fixed_insights:
current_directory = os.path.dirname(__file__)
@ -83,13 +90,8 @@ class InstructionGenerator:
data = [item for item in data if int(item["task_id"]) == int(task_id)]
return data
@staticmethod
async def generate_new_instructions(
task_id, original_instruction, max_num, file_path, ext_info=None, use_fixed_insights=False
):
data = InstructionGenerator.load_analysis_pool(
file_path, task_id=task_id, use_fixed_insights=use_fixed_insights
)
async def generate_new_instructions(self, task_id, original_instruction, max_num, ext_info=None):
data = self.analysis_pool
new_instructions = []
if len(data) == 0:
mcts_logger.log("MCTS", f"No insights available for task {task_id}")

View file

@ -21,6 +21,7 @@ The insights should be proposed based on the dataset description with different
Each task type should have at least 5 insights.
Make sure each method is diverse enough and can be implemented separately.
Be specific about models' choices, ensemble and tuning techniques, and preprocessing & feature engineering techniques.
Your model choices should be advanced enough to be helpful.
# Format
```json

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import json
import os
@ -34,11 +35,33 @@ If you cannot find the scores, please still return a dictionary with the keys 't
"""
class TimeoutException(Exception):
pass
def async_timeout():
def decorator(func):
async def wrapper(self, *args, **kwargs):
try:
result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=self.role_timeout)
except asyncio.TimeoutError:
text = f"Function timed out after {self.role_timeout} seconds"
mcts_logger.error(text)
self.save_state()
raise TimeoutException(text)
return result
return wrapper
return decorator
class ResearchAssistant(DataInterpreter):
node_id: str = "0"
start_task_id: int = 1
state_saved: bool = False
role_dir: str = SERDESER_PATH.joinpath("team", "environment", "roles", "Experimenter")
role_timeout: int = 1000
def get_node_name(self):
return f"Node-{self.node_id}"
@ -117,6 +140,12 @@ class ResearchAssistant(DataInterpreter):
return task_result
def save_state(self, static_save=False):
"""
attribute:
state_saved - the state has been saved
input:
static_save - saving the state without changing the state_saved flag - used when a new role is created
"""
if self.state_saved and not static_save:
return
if not static_save:
@ -135,18 +164,14 @@ class ResearchAssistant(DataInterpreter):
self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys())
]
@async_timeout()
async def run(self, with_message=None) -> Message | None:
"""Observe, and think and act based on the results of the observation"""
if with_message == "continue":
# self.set_todo(None)
# working_memory = self.working_memory
# self.remap_tasks()
mcts_logger.info("Continue to run")
self.rc.working_memory.clear()
self.working_memory.clear()
# self.rc.todo = WriteAnalysisCode()
rsp = await self.react()
# 发送响应消息给 Environment 对象,以便它将消息传递给订阅者
self.set_todo(None)
self.publish_message(rsp)
return rsp

View file

@ -18,6 +18,7 @@ def get_args():
default="mcts",
choices=["mcts", "aug", "base", "custom", "greedy", "autogluon", "random", "autosklearn"],
)
parser.add_argument("--role_timeout", type=int, default=1000)
get_di_args(parser)
get_mcts_args(parser)
get_aug_exp_args(parser)
@ -30,6 +31,7 @@ def get_mcts_args(parser):
parser.set_defaults(load_tree=False)
parser.add_argument("--rollouts", type=int, default=5)
parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true")
parser.add_argument("--start_task_id", type=int, default=2)
def get_aug_exp_args(parser):

15
expo/scripts/run_cls.sh Normal file
View file

@ -0,0 +1,15 @@
#!/bin/bash
tasks=("smoker-status" "software-defects" "jasmine" "credit-g" "Click_prediction_small" "kick" "kc1" "titanic" "icr" "wine-quality-white" "mfeat-factors" "segment" "GesturePhaseSegmentationProcessed")
for i in {1..3}
do
for task in "${tasks[@]}"; do
echo "Running experiment for task: $task"
python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --special_instruction stacking
echo "Experiment for task $task completed."
done
done
echo "All experiments completed."

View file

@ -0,0 +1,13 @@
#!/bin/bash
tasks=("banking77" "gnad10" "sms_spam" "oxford-iiit-pet" "stanford_cars" "fashion_mnist" )
for i in {1..3}
do
for task in "${tasks[@]}"; do
echo "Running experiment for task: $task"
python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10
echo "Experiment for task $task completed."
done
done
echo "All experiments completed."

14
expo/scripts/run_reg.sh Normal file
View file

@ -0,0 +1,14 @@
#!/bin/bash
tasks=("concrete-strength" "Moneyball" "colleges" "SAT11-HAND-runtime-regression" "diamonds" "boston" "house-prices")
for i in {1..3}
do
for task in "${tasks[@]}"; do
echo "Running experiment for task: $task"
python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --low_is_better --special_instruction stacking
echo "Experiment for task $task completed."
done
done
echo "All experiments completed."

View file

@ -1,6 +1,5 @@
import os
import re
import sys
from datetime import datetime
from pathlib import Path
@ -21,22 +20,19 @@ def load_data_config(file_path="data.yaml"):
DATASET_CONFIG = load_data_config("datasets.yaml")
DATA_CONFIG = load_data_config()
DATA_CONFIG["datasets"].update(DATASET_CONFIG["datasets"])
DATA_CONFIG["datasets"] = DATASET_CONFIG["datasets"]
def get_mcts_logger():
print_level = "INFO"
print_level2 = "MCTS"
logfile_level = "MCTS"
logfile_level = "DEBUG"
name: str = None
current_date = datetime.now()
formatted_date = current_date.strftime("%Y%m%d")
log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name
_logger.remove()
_logger.level(logfile_level, color="<green>", no=25)
_logger.add(sys.stderr, level=print_level)
_logger.add(sys.stderr, level=print_level2)
# _logger.remove()
_logger.level("MCTS", color="<green>", no=25)
# _logger.add(sys.stderr, level=print_level)
_logger.add(Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt", level=logfile_level)
_logger.propagate = False
return _logger