diff --git a/expo/data/dataset.py b/expo/data/dataset.py index 2efaf692b..4ac931d9f 100644 --- a/expo/data/dataset.py +++ b/expo/data/dataset.py @@ -111,12 +111,12 @@ def get_split_dataset_path(dataset_name, config): def get_user_requirement(task_name, config): - datasets_dir = config["datasets_dir"] + # datasets_dir = config["datasets_dir"] if task_name in config["datasets"]: dataset = config["datasets"][task_name] - data_path = os.path.join(datasets_dir, dataset["dataset"]) + # data_path = os.path.join(datasets_dir, dataset["dataset"]) user_requirement = dataset["user_requirement"] - return data_path, user_requirement + return user_requirement else: raise ValueError( f"Dataset {task_name} not found in config file. Available datasets: {config['datasets'].keys()}" diff --git a/expo/experimenter/aug.py b/expo/experimenter/aug.py index 1bf927cc1..8312f57fc 100644 --- a/expo/experimenter/aug.py +++ b/expo/experimenter/aug.py @@ -34,7 +34,7 @@ class AugExperimenter(Experimenter): di.role_dir = f"{di.role_dir}_{self.args.task}" requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i]) print(requirement) - score_dict = await self.run_di(di, requirement) + score_dict = await self.run_di(di, requirement, run_idx=i) results.append( { "idx": i, diff --git a/expo/experimenter/experimenter.py b/expo/experimenter/experimenter.py index e81c64701..b1b5a93c0 100644 --- a/expo/experimenter/experimenter.py +++ b/expo/experimenter/experimenter.py @@ -7,7 +7,7 @@ import pandas as pd from expo.evaluation.evaluation import evaluate_score from expo.MCTS import create_initial_state from expo.research_assistant import ResearchAssistant -from expo.utils import DATA_CONFIG +from expo.utils import DATA_CONFIG, save_notebook class Experimenter: @@ -26,7 +26,7 @@ class Experimenter: name="", ) - async def run_di(self, di, user_requirement): + async def run_di(self, di, user_requirement, run_idx): max_retries = 3 num_runs = 1 run_finished = False @@ -39,6 +39,7 @@ class Experimenter: except Exception as e: print(f"Error: {e}") num_runs += 1 + save_notebook(role=di, save_dir=self.result_path, name=f"{self.args.task}_{self.start_time}_{run_idx}") if not run_finished: score_dict = {"train_score": -1, "dev_score": -1, "test_score": -1, "score": -1} return score_dict @@ -50,7 +51,7 @@ class Experimenter: for i in range(self.args.num_experiments): di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection) - score_dict = await self.run_di(di, user_requirement) + score_dict = await self.run_di(di, user_requirement, run_idx=i) results.append( {"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)} ) diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index 9db6e0807..9bf7306c4 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -1,13 +1,21 @@ from expo.evaluation.visualize_mcts import get_tree_text from expo.experimenter.experimenter import Experimenter +from expo.Greedy import Greedy from expo.MCTS import MCTS class MCTSExperimenter(Experimenter): result_path: str = "results/mcts" + def __init__(self, args, greedy=False, **kwargs): + super().__init__(args, **kwargs) + self.greedy = greedy + async def run_experiment(self): - mcts = MCTS(root_node=None, max_depth=5) + if self.greedy: + mcts = Greedy(root_node=None, max_depth=5) + else: + mcts = MCTS(root_node=None, max_depth=5) best_nodes = await mcts.search( self.args.task, self.data_config, diff --git a/expo/run_experiment.py b/expo/run_experiment.py index 8871c04a6..83237741a 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -10,7 +10,7 @@ from expo.experimenter.mcts import MCTSExperimenter def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--name", type=str, default="") - parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom"]) + parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy"]) get_di_args(parser) get_mcts_args(parser) get_aug_exp_args(parser) @@ -41,6 +41,8 @@ def get_di_args(parser): async def main(args): if args.exp_mode == "mcts": experimenter = MCTSExperimenter(args) + elif args.exp_mode == "greedy": + experimenter = MCTSExperimenter(args, greedy=True) elif args.exp_mode == "aug": experimenter = AugExperimenter(args) elif args.exp_mode == "base": diff --git a/expo/utils.py b/expo/utils.py index d67ceb5a1..65701c3ec 100644 --- a/expo/utils.py +++ b/expo/utils.py @@ -7,8 +7,7 @@ from pathlib import Path import nbformat import yaml from loguru import logger as _logger - -# from nbclient import NotebookClient +from nbclient import NotebookClient from nbformat.notebooknode import NotebookNode from metagpt.roles.role import Role @@ -92,15 +91,24 @@ def process_cells(nb: NotebookNode) -> NotebookNode: def save_notebook(role: Role, save_dir: str = "", name: str = ""): save_dir = Path(save_dir) + tasks = role.planner.plan.tasks + codes = [task.code for task in tasks if task.code] + clean_nb = nbformat.v4.new_notebook() + for code in codes: + clean_nb.cells.append(nbformat.v4.new_code_cell(code)) nb = process_cells(role.execute_code.nb) file_path = save_dir / f"{name}.ipynb" + clean_file_path = save_dir / f"{name}_clean.ipynb" nbformat.write(nb, file_path) + nbformat.write(clean_nb, clean_file_path) async def load_execute_notebook(role): tasks = role.planner.plan.tasks codes = [task.code for task in tasks if task.code] executor = role.execute_code + executor.nb = nbformat.v4.new_notebook() + executor.nb.client = NotebookClient(executor.nb) # await executor.build() for code in codes: outputs, success = await executor.run(code)