reorder import

This commit is contained in:
Cyzus Chi 2024-10-28 17:47:27 +08:00
parent e41f9342cd
commit 4c541c2e53
6 changed files with 19 additions and 11 deletions

View file

@ -6,8 +6,8 @@ import numpy as np
import pandas as pd
from metagpt.ext.sela.evaluation.evaluation import evaluate_score
from metagpt.ext.sela.search.tree_search import create_initial_state
from metagpt.ext.sela.research_assistant import ResearchAssistant
from metagpt.ext.sela.search.tree_search import create_initial_state
from metagpt.ext.sela.utils import DATA_CONFIG, save_notebook

View file

@ -6,7 +6,7 @@ from metagpt.ext.sela.evaluation.evaluation import (
)
from metagpt.ext.sela.evaluation.visualize_mcts import get_tree_text
from metagpt.ext.sela.experimenter.experimenter import Experimenter
from metagpt.ext.sela.search.search_algorithm import Greedy, Random, MCTS
from metagpt.ext.sela.search.search_algorithm import MCTS, Greedy, Random
class MCTSExperimenter(Experimenter):

View file

@ -6,9 +6,9 @@ import os
from pydantic import model_validator
from metagpt.ext.sela.utils import mcts_logger, save_notebook
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.const import SERDESER_PATH
from metagpt.ext.sela.utils import mcts_logger, save_notebook
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.schema import Message, Task, TaskResult
from metagpt.utils.common import CodeParser, write_json_file

View file

@ -2,12 +2,12 @@ import argparse
import asyncio
from metagpt.ext.sela.data.custom_task import get_mle_is_lower_better, get_mle_task_id
from metagpt.ext.sela.experimenter.random_search import RandomSearchExperimenter
from metagpt.ext.sela.experimenter.autogluon import GluonExperimenter
from metagpt.ext.sela.experimenter.autosklearn import AutoSklearnExperimenter
from metagpt.ext.sela.experimenter.custom import CustomExperimenter
from metagpt.ext.sela.experimenter.experimenter import Experimenter
from metagpt.ext.sela.experimenter.mcts import MCTSExperimenter
from metagpt.ext.sela.experimenter.random_search import RandomSearchExperimenter
def get_args(cmd=True):

View file

@ -1,4 +1,5 @@
import numpy as np
from metagpt.ext.sela.search.tree_search import BaseTreeSearch, Node

View file

@ -6,8 +6,14 @@ import shutil
import numpy as np
import pandas as pd
from metagpt.ext.sela.data.custom_task import get_mle_bench_requirements, get_mle_task_id
from metagpt.ext.sela.data.dataset import generate_task_requirement, get_split_dataset_path
from metagpt.ext.sela.data.custom_task import (
get_mle_bench_requirements,
get_mle_task_id,
)
from metagpt.ext.sela.data.dataset import (
generate_task_requirement,
get_split_dataset_path,
)
from metagpt.ext.sela.evaluation.evaluation import evaluate_score
from metagpt.ext.sela.insights.instruction_generator import InstructionGenerator
from metagpt.ext.sela.research_assistant import ResearchAssistant, TimeoutException
@ -57,9 +63,9 @@ def create_initial_state(task: str, start_task_id: int, data_config: dict, args)
Args:
task (str): The task to be performed.
start_task_id (int): The ID of the starting task.
data_config (dict): The configuration of the data.
data_config (dict): The configuration of the data.
Expected keys: 'datasets', 'work_dir', 'role_dir'.
args (Namespace): The arguments passed to the program.
args (Namespace): The arguments passed to the program.
Expected attributes: 'external_eval', 'custom_dataset_dir', 'special_instruction', 'name', 'low_is_better', 'role_timeout'.
Returns:
@ -104,6 +110,7 @@ def create_initial_state(task: str, start_task_id: int, data_config: dict, args)
os.makedirs(initial_state["node_dir"], exist_ok=True)
return initial_state
class Node:
state: dict = {}
action: str = None
@ -113,7 +120,9 @@ class Node:
normalized_reward: dict = {"train_score": 0, "dev_score": 0, "test_score": 0}
parent = None
def __init__(self, parent=None, state: dict = None, action: str = None, value: float = 0, max_depth: int = 4, **kwargs):
def __init__(
self, parent=None, state: dict = None, action: str = None, value: float = 0, max_depth: int = 4, **kwargs
):
self.state = state
self.action = action
self.value = value
@ -306,8 +315,6 @@ class Node:
return score_dict, result_dict
class BaseTreeSearch:
# data_path
root_node: Node = None