add greedy to run_experiment; add save_notebook to experimenter.py

This commit is contained in:
Yizhou Chi 2024-09-09 13:44:38 +08:00
parent df6fe9854d
commit 9728b3a891
6 changed files with 30 additions and 11 deletions

View file

@ -111,12 +111,12 @@ def get_split_dataset_path(dataset_name, config):
def get_user_requirement(task_name, config):
datasets_dir = config["datasets_dir"]
# datasets_dir = config["datasets_dir"]
if task_name in config["datasets"]:
dataset = config["datasets"][task_name]
data_path = os.path.join(datasets_dir, dataset["dataset"])
# data_path = os.path.join(datasets_dir, dataset["dataset"])
user_requirement = dataset["user_requirement"]
return data_path, user_requirement
return user_requirement
else:
raise ValueError(
f"Dataset {task_name} not found in config file. Available datasets: {config['datasets'].keys()}"

View file

@ -34,7 +34,7 @@ class AugExperimenter(Experimenter):
di.role_dir = f"{di.role_dir}_{self.args.task}"
requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
print(requirement)
score_dict = await self.run_di(di, requirement)
score_dict = await self.run_di(di, requirement, run_idx=i)
results.append(
{
"idx": i,

View file

@ -7,7 +7,7 @@ import pandas as pd
from expo.evaluation.evaluation import evaluate_score
from expo.MCTS import create_initial_state
from expo.research_assistant import ResearchAssistant
from expo.utils import DATA_CONFIG
from expo.utils import DATA_CONFIG, save_notebook
class Experimenter:
@ -26,7 +26,7 @@ class Experimenter:
name="",
)
async def run_di(self, di, user_requirement):
async def run_di(self, di, user_requirement, run_idx):
max_retries = 3
num_runs = 1
run_finished = False
@ -39,6 +39,7 @@ class Experimenter:
except Exception as e:
print(f"Error: {e}")
num_runs += 1
save_notebook(role=di, save_dir=self.result_path, name=f"{self.args.task}_{self.start_time}_{run_idx}")
if not run_finished:
score_dict = {"train_score": -1, "dev_score": -1, "test_score": -1, "score": -1}
return score_dict
@ -50,7 +51,7 @@ class Experimenter:
for i in range(self.args.num_experiments):
di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
score_dict = await self.run_di(di, user_requirement)
score_dict = await self.run_di(di, user_requirement, run_idx=i)
results.append(
{"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}
)

View file

@ -1,13 +1,21 @@
from expo.evaluation.visualize_mcts import get_tree_text
from expo.experimenter.experimenter import Experimenter
from expo.Greedy import Greedy
from expo.MCTS import MCTS
class MCTSExperimenter(Experimenter):
result_path: str = "results/mcts"
def __init__(self, args, greedy=False, **kwargs):
super().__init__(args, **kwargs)
self.greedy = greedy
async def run_experiment(self):
mcts = MCTS(root_node=None, max_depth=5)
if self.greedy:
mcts = Greedy(root_node=None, max_depth=5)
else:
mcts = MCTS(root_node=None, max_depth=5)
best_nodes = await mcts.search(
self.args.task,
self.data_config,

View file

@ -10,7 +10,7 @@ from expo.experimenter.mcts import MCTSExperimenter
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, default="")
parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom"])
parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy"])
get_di_args(parser)
get_mcts_args(parser)
get_aug_exp_args(parser)
@ -41,6 +41,8 @@ def get_di_args(parser):
async def main(args):
if args.exp_mode == "mcts":
experimenter = MCTSExperimenter(args)
elif args.exp_mode == "greedy":
experimenter = MCTSExperimenter(args, greedy=True)
elif args.exp_mode == "aug":
experimenter = AugExperimenter(args)
elif args.exp_mode == "base":

View file

@ -7,8 +7,7 @@ from pathlib import Path
import nbformat
import yaml
from loguru import logger as _logger
# from nbclient import NotebookClient
from nbclient import NotebookClient
from nbformat.notebooknode import NotebookNode
from metagpt.roles.role import Role
@ -92,15 +91,24 @@ def process_cells(nb: NotebookNode) -> NotebookNode:
def save_notebook(role: Role, save_dir: str = "", name: str = ""):
    """Write the role's executed notebook and a code-only copy into ``save_dir``.

    Two files are produced:

    * ``<name>.ipynb`` -- ``role.execute_code.nb`` after it has been passed
      through ``process_cells``.
    * ``<name>_clean.ipynb`` -- a fresh v4 notebook holding one code cell for
      every task in ``role.planner.plan.tasks`` whose ``code`` is non-empty.

    Args:
        role: Role whose ``planner.plan.tasks`` and ``execute_code.nb`` are read.
        save_dir: Target directory. It is created (including parents) when it
            does not exist yet; the default ``""`` means the current directory.
        name: File stem shared by both notebooks.
    """
    save_dir = Path(save_dir)
    # The notebook writes below need an existing directory; without this the
    # first save into a fresh results folder fails.
    save_dir.mkdir(parents=True, exist_ok=True)

    # Code-only notebook: one cell per task that actually produced code.
    clean_nb = nbformat.v4.new_notebook()
    clean_nb.cells.extend(
        nbformat.v4.new_code_cell(task.code)
        for task in role.planner.plan.tasks
        if task.code
    )

    nb = process_cells(role.execute_code.nb)

    nbformat.write(nb, save_dir / f"{name}.ipynb")
    nbformat.write(clean_nb, save_dir / f"{name}_clean.ipynb")
async def load_execute_notebook(role):
tasks = role.planner.plan.tasks
codes = [task.code for task in tasks if task.code]
executor = role.execute_code
executor.nb = nbformat.v4.new_notebook()
executor.nb.client = NotebookClient(executor.nb)
# await executor.build()
for code in codes:
outputs, success = await executor.run(code)