diff --git a/expo/MCTS.py b/expo/MCTS.py index ec5ef9da0..14f2c4e4b 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -300,7 +300,7 @@ class MCTS(): mcts_logger.log("MCTS", f"Tree loaded: {tree_loaded}") if not tree_loaded: - rollouts -= 2 + rollouts -= 2 # 2 rollouts for the initial tree if rollouts < 0: raise ValueError("Rollouts must be greater than 2 if there is no tree to load") self.children[root] = [] diff --git a/expo/README.md b/expo/README.md index 6e4081031..2ecf2fd2f 100644 --- a/expo/README.md +++ b/expo/README.md @@ -1,25 +1,76 @@ # Expo -## Setup -In the root directory, `pip install -e .` -`cd expo` -`pip install -r requirements.txt` - -## Instruction +## 1. Data Preparation - 下载数据集:https://deepwisdom.feishu.cn/drive/folder/RVyofv9cvlvtxKdddt2cyn3BnTc?from=from_copylink - 修改`data.yaml`的`datasets_dir`为数据集合集根目录存储位置 -## Examples -### Run Base DI - -`python run_experiment.py --exp_mode base --task titanic` +## 2. Configs -### Run DI RandExp +### Data Config + +`datasets.yaml` 提供数据集对应的指标和基础提示词 + +`data.yaml` 继承了`datasets.yaml`以及一些路径信息,需要将`datasets_dir`指到数据集合集的根目录下 + + +### LLM Config + +``` +llm: + api_type: 'openai' + model: deepseek-coder + base_url: "https://oneapi.deepwisdom.ai/v1" + api_key: sk-xxx + temperature: 0.5 +``` + +### Budget +实验轮次 k = 10, 20 + + +### 提示词使用 + +通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词 + + +## 3. Evaluation + +运行各个框架,运行后框架需要提供Dev和Test的`dev_predictions.csv`和`test_predictions.csv`, column name为target + +两种评估方式 + +1. `evaluation.py` 提供pred和原始的gt(1D iterable)以及需要使用的metric,返回evaluation score + +2. 使用`CustomExperimenter` +``` +experimenter = CustomExperimenter(task="titanic") +score_dict = experimenter.evaluate_pred_files(dev_pred_path, test_pred_path) +``` + +## 4. Baselines +### DS Agent +提供github链接,并说明使用的命令以及参数设置 + + +### AIDE +提供github链接,并说明使用的命令以及参数设置 + +### Autogluon +提供github链接,并说明使用的命令以及参数设置 + +### Base DI +For setup, check 5. + +- `python run_experiment.py --exp_mode base --task titanic` + + +### DI RandomSearch +For setup, check 5. - Single insight `python run_experiment.py --exp_mode aug --task titanic --aug_mode single` @@ -28,30 +79,36 @@ ### Run DI RandExp `python run_experiment.py --exp_mode aug --task titanic --aug_mode set` +## 5. DI MCTS ### Run DI MCTS -`python run_experiment.py --exp_mode mcts --task titanic --rollout 5` + +#### Setup +In the root directory, + +``` +pip install -e . + +cd expo + +pip install -r requirements.txt +``` + +#### Run + +- `python run_experiment.py --exp_mode mcts --task titanic --rollout 5` If the dataset has reg metric, remember to use `--low_is_better`: - `python run_experiment.py --exp_mode mcts --task househouse_prices --rollout 5 --low_is_better` -## Custom Experimenter -## Code and Configs Explanation - -`datasets.yaml` 提供数据集对应的指标和基础提示词 - -`data.yaml` 继承了`datasets.yaml`以及一些路径信息,需要将`datasets_dir`指到数据集合集的根目录下 - -完整的DI提示词参考`dataset.py`中的`generate_task_requirement`函数 -## Evaluation -`evaluation.py` 提供pred和原始的gt(1D iterable)以及需要使用的metric,返回evaluation score + diff --git a/expo/data.yaml b/expo/data.yaml index 050b0b893..d62e45309 100644 --- a/expo/data.yaml +++ b/expo/data.yaml @@ -4,22 +4,23 @@ datasets: titanic: dataset: 04_titanic metric: f1 + target_col: Survived user_requirement: "This is a 04_titanic dataset. Your goal is to predict the target\ \ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ \ or make any visualizations.\n" - - house_prices: + house-prices: dataset: 05_house-prices-advanced-regression-techniques metric: rmse + target_col: SalePrice user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\ \ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\ \ data preprocessing, feature engineering, and modeling to predict the target.\ \ \nReport rmse on the eval data. Do not plot or make any visualizations.\n" - - santander_customers: + santander-customer: dataset: 06_santander-customer-transaction-prediction metric: f1 + target_col: target user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\ \ Your goal is to predict the target column `target`.\nPerform data analysis,\ \ data preprocessing, feature engineering, and modeling to predict the target.\ @@ -27,126 +28,127 @@ datasets: icr: dataset: 07_icr-identify-age-related-conditions metric: f1 + target_col: Class user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\ \ goal is to predict the target column `Class`.\nPerform data analysis, data\ \ preprocessing, feature engineering, and modeling to predict the target. \n\ Report f1 on the eval data. Do not plot or make any visualizations.\n" - - lick_prediction_small: + Click_prediction_small: dataset: Click_prediction_small metric: f1 + target_col: click user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\ \ the target column `click`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 on the eval data.\ \ Do not plot or make any visualizations.\n" - GesturePhaseSegmentationProcessed: dataset: GesturePhaseSegmentationProcessed metric: f1 weighted + target_col: Phase user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. Your goal\ \ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\ \ feature engineering, and modeling to predict the target. \nReport f1 weighted\ \ on the eval data. Do not plot or make any visualizations.\n" - Moneyball: dataset: Moneyball metric: rmse + target_col: RS user_requirement: "This is a Moneyball dataset. Your goal is to predict the target\ \ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ \ plot or make any visualizations.\n" - SAT11-HAND-runtime-regression: dataset: SAT11-HAND-runtime-regression metric: rmse + target_col: runtime user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\ \ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\ \ feature engineering, and modeling to predict the target. \nReport rmse on\ \ the eval data. Do not plot or make any visualizations.\n" - boston: dataset: boston metric: rmse + target_col: MEDV user_requirement: "This is a boston dataset. Your goal is to predict the target\ \ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ \ plot or make any visualizations.\n" - colleges: dataset: colleges metric: rmse + target_col: percent_pell_grant user_requirement: "This is a colleges dataset. Your goal is to predict the target\ \ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport rmse on the eval\ \ data. Do not plot or make any visualizations.\n" - credit-g: dataset: credit-g metric: f1 + target_col: class user_requirement: "This is a credit-g dataset. Your goal is to predict the target\ \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ \ or make any visualizations.\n" - diamonds: dataset: diamonds metric: rmse + target_col: price user_requirement: "This is a diamonds dataset. Your goal is to predict the target\ \ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ \ plot or make any visualizations.\n" - jasmine: dataset: jasmine metric: f1 + target_col: class user_requirement: "This is a jasmine dataset. Your goal is to predict the target\ \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ \ or make any visualizations.\n" - kc1: dataset: kc1 metric: f1 + target_col: defects user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\ \ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ \ or make any visualizations.\n" - kick: dataset: kick metric: f1 + target_col: IsBadBuy user_requirement: "This is a kick dataset. Your goal is to predict the target\ \ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ \ or make any visualizations.\n" - mfeat-factors: dataset: mfeat-factors metric: f1 weighted + target_col: class user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\ \ target column `class`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ \ eval data. Do not plot or make any visualizations.\n" - segment: dataset: segment metric: f1 weighted + target_col: class user_requirement: "This is a segment dataset. Your goal is to predict the target\ \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 weighted on the eval data.\ \ Do not plot or make any visualizations.\n" - steel-plates-fault: dataset: steel-plates-fault metric: f1 weighted + target_col: target user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\ \ the target column `target`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ \ eval data. Do not plot or make any visualizations.\n" - wine-quality-white: dataset: wine-quality-white metric: f1 weighted + target_col: Class user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\ \ the target column `Class`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ diff --git a/expo/dataset.py b/expo/dataset.py index 3e292ba7c..62665d297 100644 --- a/expo/dataset.py +++ b/expo/dataset.py @@ -21,13 +21,13 @@ TASK_PROMPT = """\ 1. Please do not leak the target label in any form during training. 2. Dev and Test sets do not have the target column. 3. You should perform transformations on all sets at the same step. +4. If labels are transformed during training, they should be transformed back to the original format before saving the predictions. ## Saving Dev and Test Predictions 1. Save the prediction results of BOTH the dev set and test set in `dev_predictions.csv` and `test_predictions.csv` respectively in the output directory. - Both files should contain a single column named `target` with the predicted values. 2. Make sure the prediction results are in the same format as the target column in the training set. - The labels should be transformed back to the original format if any transformation was applied during training. -- If the original target column was categorical or string, the predictions MUST be in the same format. ## Output Training Set Performance Make sure the performance of the model is printed in python in the last step even if it has been printed in the previous steps. The value should be a float number. @@ -119,7 +119,8 @@ def create_dataset_dict(dataset): dataset_dict = { "dataset": dataset.name, "user_requirement": dataset.create_base_requirement(), - "metric": dataset.get_metric() + "metric": dataset.get_metric(), + "target_col": dataset.target_col } return dataset_dict @@ -289,23 +290,24 @@ class OpenMLExpDataset(ExpDataset): # def __init__(self, name, dataset_dir, dataset_name, **kwargs): # super().__init__(name, dataset_dir, **kwargs) - +async def process_dataset(dataset, solution_designer, save_analysis_pool, datasets_dict): + if save_analysis_pool: + asyncio.run(solution_designer.generate_solutions(dataset.get_dataset_info(), dataset.name)) + dataset_dict = create_dataset_dict(dataset) + datasets_dict["datasets"][dataset.name] = dataset_dict if __name__ == "__main__": datasets_dir = "D:/work/automl/datasets" - force_update = True + force_update = False + save_analysis_pool = False datasets_dict = {"datasets": {}} solution_designer = SolutionDesigner() for dataset_id in OPENML_DATASET_IDS: openml_dataset = OpenMLExpDataset("", datasets_dir, dataset_id, force_update=force_update) - asyncio.run(solution_designer.generate_solutions(openml_dataset.get_dataset_info(), openml_dataset.name)) - dataset_dict = create_dataset_dict(openml_dataset) - datasets_dict["datasets"][openml_dataset.name] = dataset_dict + asyncio.run(process_dataset(openml_dataset, solution_designer, save_analysis_pool, datasets_dict)) for dataset_name, target_col in CUSTOM_DATASETS: custom_dataset = ExpDataset(dataset_name, datasets_dir, target_col=target_col, force_update=force_update) - asyncio.run(solution_designer.generate_solutions(custom_dataset.get_dataset_info(), custom_dataset.name)) - dataset_dict = create_dataset_dict(custom_dataset) - datasets_dict["datasets"][custom_dataset.name] = dataset_dict - + asyncio.run(process_dataset(custom_dataset, solution_designer, save_analysis_pool, datasets_dict)) + save_datasets_dict_to_yaml(datasets_dict) diff --git a/expo/datasets.yaml b/expo/datasets.yaml index ec00e3d1f..8c28b03ca 100644 --- a/expo/datasets.yaml +++ b/expo/datasets.yaml @@ -1,28 +1,32 @@ datasets: - 04_titanic: + titanic: dataset: 04_titanic metric: f1 + target_col: Survived user_requirement: "This is a 04_titanic dataset. Your goal is to predict the target\ \ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ \ or make any visualizations.\n" - 05_house-prices-advanced-regression-techniques: + house-prices: dataset: 05_house-prices-advanced-regression-techniques metric: rmse + target_col: SalePrice user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\ \ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\ \ data preprocessing, feature engineering, and modeling to predict the target.\ \ \nReport rmse on the eval data. Do not plot or make any visualizations.\n" - 06_santander-customer-transaction-prediction: + santander-customer: dataset: 06_santander-customer-transaction-prediction metric: f1 + target_col: target user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\ \ Your goal is to predict the target column `target`.\nPerform data analysis,\ \ data preprocessing, feature engineering, and modeling to predict the target.\ \ \nReport f1 on the eval data. Do not plot or make any visualizations.\n" - 07_icr-identify-age-related-conditions: + icr: dataset: 07_icr-identify-age-related-conditions metric: f1 + target_col: Class user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\ \ goal is to predict the target column `Class`.\nPerform data analysis, data\ \ preprocessing, feature engineering, and modeling to predict the target. \n\ @@ -30,6 +34,7 @@ datasets: Click_prediction_small: dataset: Click_prediction_small metric: f1 + target_col: click user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\ \ the target column `click`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 on the eval data.\ @@ -37,6 +42,7 @@ datasets: GesturePhaseSegmentationProcessed: dataset: GesturePhaseSegmentationProcessed metric: f1 weighted + target_col: Phase user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. Your goal\ \ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\ \ feature engineering, and modeling to predict the target. \nReport f1 weighted\ @@ -44,6 +50,7 @@ datasets: Moneyball: dataset: Moneyball metric: rmse + target_col: RS user_requirement: "This is a Moneyball dataset. Your goal is to predict the target\ \ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ @@ -51,6 +58,7 @@ datasets: SAT11-HAND-runtime-regression: dataset: SAT11-HAND-runtime-regression metric: rmse + target_col: runtime user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\ \ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\ \ feature engineering, and modeling to predict the target. \nReport rmse on\ @@ -58,6 +66,7 @@ datasets: boston: dataset: boston metric: rmse + target_col: MEDV user_requirement: "This is a boston dataset. Your goal is to predict the target\ \ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ @@ -65,6 +74,7 @@ datasets: colleges: dataset: colleges metric: rmse + target_col: percent_pell_grant user_requirement: "This is a colleges dataset. Your goal is to predict the target\ \ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport rmse on the eval\ @@ -72,6 +82,7 @@ datasets: credit-g: dataset: credit-g metric: f1 + target_col: class user_requirement: "This is a credit-g dataset. Your goal is to predict the target\ \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ @@ -79,6 +90,7 @@ datasets: diamonds: dataset: diamonds metric: rmse + target_col: price user_requirement: "This is a diamonds dataset. Your goal is to predict the target\ \ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ @@ -86,6 +98,7 @@ datasets: jasmine: dataset: jasmine metric: f1 + target_col: class user_requirement: "This is a jasmine dataset. Your goal is to predict the target\ \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ @@ -93,6 +106,7 @@ datasets: kc1: dataset: kc1 metric: f1 + target_col: defects user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\ \ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ @@ -100,6 +114,7 @@ datasets: kick: dataset: kick metric: f1 + target_col: IsBadBuy user_requirement: "This is a kick dataset. Your goal is to predict the target\ \ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ @@ -107,6 +122,7 @@ datasets: mfeat-factors: dataset: mfeat-factors metric: f1 weighted + target_col: class user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\ \ target column `class`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ @@ -114,6 +130,7 @@ datasets: segment: dataset: segment metric: f1 weighted + target_col: class user_requirement: "This is a segment dataset. Your goal is to predict the target\ \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ \ and modeling to predict the target. \nReport f1 weighted on the eval data.\ @@ -121,6 +138,7 @@ datasets: steel-plates-fault: dataset: steel-plates-fault metric: f1 weighted + target_col: target user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\ \ the target column `target`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ @@ -128,6 +146,7 @@ datasets: wine-quality-white: dataset: wine-quality-white metric: f1 weighted + target_col: Class user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\ \ the target column `Class`.\nPerform data analysis, data preprocessing, feature\ \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ diff --git a/expo/experimenter/custom.py b/expo/experimenter/custom.py index 06df4efcf..ff5ba3546 100644 --- a/expo/experimenter/custom.py +++ b/expo/experimenter/custom.py @@ -9,10 +9,12 @@ class CustomExperimenter(Experimenter): def __init__(self, args, **kwargs): super().__init__(args, **kwargs) - self.framework = kwargs["framework"] + self.framework = kwargs["framework"] # todo + self.task = kwargs.get("task", self.args.task) + self.low_is_better = kwargs.get("low_is_better", self.args.low_is_better) self.name = kwargs.get("name", "") self.result_path = f"results/custom_{self.name}" - self.state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="") + self.state = create_initial_state(self.task, start_task_id=1, data_config=self.data_config, low_is_better=self.low_is_better, name=self.name) async def run_experiment(self): user_requirement = self.state["requirement"] @@ -30,6 +32,15 @@ class CustomExperimenter(Experimenter): } self.save_result(results) + def evaluate_pred_files(self, dev_pred_path, test_pred_path): + dev_preds = pd.read_csv(dev_pred_path)["target"] + test_preds = pd.read_csv(test_pred_path)["target"] + score_dict = { + "dev_score": self.evaluate_score(dev_preds, "dev"), + "test_score": self.evaluate_score(test_preds, "test") + } + return score_dict + def evaluate_predictions(self, preds, split): metric = self.state["dataset_config"]["metric"] gt_path = os.path.join(self.state["datasets_dir"][f"{split}_target"]) diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py index 4473866af..678d48d6a 100644 --- a/expo/experimenter/experimenter.py +++ b/expo/experimenter/experimenter.py @@ -20,16 +20,24 @@ class Experimenter: async def run_experiment(self): state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="") user_requirement = state["requirement"] - di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection) - await di.run(user_requirement) - - score_dict = await di.get_score() - score_dict = self.evaluate(score_dict, state) - results = { - "score_dict": score_dict, - "user_requirement": user_requirement, - "args": vars(self.args) - } + results = [] + + for i in range(self.args.num_experiments): + di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection) + await di.run(user_requirement) + score_dict = await di.get_score() + score_dict = self.evaluate(score_dict, state) + results.append({ + "idx": i, + "score_dict": score_dict, + "user_requirement": user_requirement, + "args": vars(self.args) + }) + scores = [result["score_dict"]["test_score"] for result in results] + avg_score = sum(scores) / len(scores) + best_score = max(scores) if not self.args.low_is_better else min(scores) + best_score_idx = scores.index(best_score) + results.insert(0, {"avg_score": avg_score, "best_score": best_score, "best_score_idx": best_score_idx}) self.save_result(results) def evaluate_prediction(self, split, state): diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index 43c5f9868..0159abe24 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -22,18 +22,19 @@ class MCTSExperimenter(Experimenter): text += f"Best node: {best_node}, score: {best_node.raw_reward}\n" text += f"Dev best node: {dev_best_node}, score: {dev_best_node.raw_reward}\n" print(text) - self.save_tree(text) + if self.args.rollouts > 0: + self.save_tree(text) - results = { - "best_node": best_node.id, - "best_node_score": best_node.raw_reward, - "dev_best_node": dev_best_node.id, - "dev_best_node_score": dev_best_node.raw_reward, - "num_generated_codes": num_generated_codes, - "user_requirement": best_node.state["requirement"], - "args": vars(self.args) - } - self.save_result(results) + results = { + "best_node": best_node.id, + "best_node_score": best_node.raw_reward, + "dev_best_node": dev_best_node.id, + "dev_best_node_score": dev_best_node.raw_reward, + "num_generated_codes": num_generated_codes, + "user_requirement": best_node.state["requirement"], + "args": vars(self.args) + } + self.save_result(results)