mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
add greedy to run_experiment; add save_notebook to experimenter.py
This commit is contained in:
parent
df6fe9854d
commit
9728b3a891
6 changed files with 30 additions and 11 deletions
|
|
@ -111,12 +111,12 @@ def get_split_dataset_path(dataset_name, config):
|
|||
|
||||
|
||||
def get_user_requirement(task_name, config):
|
||||
datasets_dir = config["datasets_dir"]
|
||||
# datasets_dir = config["datasets_dir"]
|
||||
if task_name in config["datasets"]:
|
||||
dataset = config["datasets"][task_name]
|
||||
data_path = os.path.join(datasets_dir, dataset["dataset"])
|
||||
# data_path = os.path.join(datasets_dir, dataset["dataset"])
|
||||
user_requirement = dataset["user_requirement"]
|
||||
return data_path, user_requirement
|
||||
return user_requirement
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dataset {task_name} not found in config file. Available datasets: {config['datasets'].keys()}"
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class AugExperimenter(Experimenter):
|
|||
di.role_dir = f"{di.role_dir}_{self.args.task}"
|
||||
requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
|
||||
print(requirement)
|
||||
score_dict = await self.run_di(di, requirement)
|
||||
score_dict = await self.run_di(di, requirement, run_idx=i)
|
||||
results.append(
|
||||
{
|
||||
"idx": i,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import pandas as pd
|
|||
from expo.evaluation.evaluation import evaluate_score
|
||||
from expo.MCTS import create_initial_state
|
||||
from expo.research_assistant import ResearchAssistant
|
||||
from expo.utils import DATA_CONFIG
|
||||
from expo.utils import DATA_CONFIG, save_notebook
|
||||
|
||||
|
||||
class Experimenter:
|
||||
|
|
@ -26,7 +26,7 @@ class Experimenter:
|
|||
name="",
|
||||
)
|
||||
|
||||
async def run_di(self, di, user_requirement):
|
||||
async def run_di(self, di, user_requirement, run_idx):
|
||||
max_retries = 3
|
||||
num_runs = 1
|
||||
run_finished = False
|
||||
|
|
@ -39,6 +39,7 @@ class Experimenter:
|
|||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
num_runs += 1
|
||||
save_notebook(role=di, save_dir=self.result_path, name=f"{self.args.task}_{self.start_time}_{run_idx}")
|
||||
if not run_finished:
|
||||
score_dict = {"train_score": -1, "dev_score": -1, "test_score": -1, "score": -1}
|
||||
return score_dict
|
||||
|
|
@ -50,7 +51,7 @@ class Experimenter:
|
|||
|
||||
for i in range(self.args.num_experiments):
|
||||
di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
|
||||
score_dict = await self.run_di(di, user_requirement)
|
||||
score_dict = await self.run_di(di, user_requirement, run_idx=i)
|
||||
results.append(
|
||||
{"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,21 @@
|
|||
from expo.evaluation.visualize_mcts import get_tree_text
|
||||
from expo.experimenter.experimenter import Experimenter
|
||||
from expo.Greedy import Greedy
|
||||
from expo.MCTS import MCTS
|
||||
|
||||
|
||||
class MCTSExperimenter(Experimenter):
|
||||
result_path: str = "results/mcts"
|
||||
|
||||
def __init__(self, args, greedy=False, **kwargs):
|
||||
super().__init__(args, **kwargs)
|
||||
self.greedy = greedy
|
||||
|
||||
async def run_experiment(self):
|
||||
mcts = MCTS(root_node=None, max_depth=5)
|
||||
if self.greedy:
|
||||
mcts = Greedy(root_node=None, max_depth=5)
|
||||
else:
|
||||
mcts = MCTS(root_node=None, max_depth=5)
|
||||
best_nodes = await mcts.search(
|
||||
self.args.task,
|
||||
self.data_config,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from expo.experimenter.mcts import MCTSExperimenter
|
|||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--name", type=str, default="")
|
||||
parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom"])
|
||||
parser.add_argument("--exp_mode", type=str, default="mcts", choices=["mcts", "aug", "base", "custom", "greedy"])
|
||||
get_di_args(parser)
|
||||
get_mcts_args(parser)
|
||||
get_aug_exp_args(parser)
|
||||
|
|
@ -41,6 +41,8 @@ def get_di_args(parser):
|
|||
async def main(args):
|
||||
if args.exp_mode == "mcts":
|
||||
experimenter = MCTSExperimenter(args)
|
||||
elif args.exp_mode == "greedy":
|
||||
experimenter = MCTSExperimenter(args, greedy=True)
|
||||
elif args.exp_mode == "aug":
|
||||
experimenter = AugExperimenter(args)
|
||||
elif args.exp_mode == "base":
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ from pathlib import Path
|
|||
import nbformat
|
||||
import yaml
|
||||
from loguru import logger as _logger
|
||||
|
||||
# from nbclient import NotebookClient
|
||||
from nbclient import NotebookClient
|
||||
from nbformat.notebooknode import NotebookNode
|
||||
|
||||
from metagpt.roles.role import Role
|
||||
|
|
@ -92,15 +91,24 @@ def process_cells(nb: NotebookNode) -> NotebookNode:
|
|||
|
||||
def save_notebook(role: Role, save_dir: str = "", name: str = ""):
|
||||
save_dir = Path(save_dir)
|
||||
tasks = role.planner.plan.tasks
|
||||
codes = [task.code for task in tasks if task.code]
|
||||
clean_nb = nbformat.v4.new_notebook()
|
||||
for code in codes:
|
||||
clean_nb.cells.append(nbformat.v4.new_code_cell(code))
|
||||
nb = process_cells(role.execute_code.nb)
|
||||
file_path = save_dir / f"{name}.ipynb"
|
||||
clean_file_path = save_dir / f"{name}_clean.ipynb"
|
||||
nbformat.write(nb, file_path)
|
||||
nbformat.write(clean_nb, clean_file_path)
|
||||
|
||||
|
||||
async def load_execute_notebook(role):
|
||||
tasks = role.planner.plan.tasks
|
||||
codes = [task.code for task in tasks if task.code]
|
||||
executor = role.execute_code
|
||||
executor.nb = nbformat.v4.new_notebook()
|
||||
executor.nb.client = NotebookClient(executor.nb)
|
||||
# await executor.build()
|
||||
for code in codes:
|
||||
outputs, success = await executor.run(code)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue