diff --git a/metagpt/ext/sela/data.yaml b/metagpt/ext/sela/data.yaml index 989a07966..5f4a290ea 100644 --- a/metagpt/ext/sela/data.yaml +++ b/metagpt/ext/sela/data.yaml @@ -1,3 +1,3 @@ datasets_dir: "path/to/datasets" # path to the datasets directory -work_dir: ../workspace # path to the workspace directory +work_dir: ../../workspace # path to the workspace directory role_dir: storage/SELA # path to the role directory \ No newline at end of file diff --git a/metagpt/ext/sela/evaluation/visualize_mcts.py b/metagpt/ext/sela/evaluation/visualize_mcts.py index 62efed917..6f803a91c 100644 --- a/metagpt/ext/sela/evaluation/visualize_mcts.py +++ b/metagpt/ext/sela/evaluation/visualize_mcts.py @@ -3,7 +3,7 @@ import textwrap import matplotlib.pyplot as plt import networkx as nx -from metagpt.ext.sela.MCTS import Node +from metagpt.ext.sela.search.tree_search import Node NODE_TEMPLATE = """\ [Node {id}] diff --git a/metagpt/ext/sela/experimenter/experimenter.py b/metagpt/ext/sela/experimenter/experimenter.py index 03671c9ff..fd9122d29 100644 --- a/metagpt/ext/sela/experimenter/experimenter.py +++ b/metagpt/ext/sela/experimenter/experimenter.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd from metagpt.ext.sela.evaluation.evaluation import evaluate_score -from metagpt.ext.sela.MCTS import create_initial_state +from metagpt.ext.sela.search.tree_search import create_initial_state from metagpt.ext.sela.research_assistant import ResearchAssistant from metagpt.ext.sela.utils import DATA_CONFIG, save_notebook diff --git a/metagpt/ext/sela/experimenter/random_search.py b/metagpt/ext/sela/experimenter/random_search.py index 94fb6960b..5617ee601 100644 --- a/metagpt/ext/sela/experimenter/random_search.py +++ b/metagpt/ext/sela/experimenter/random_search.py @@ -17,7 +17,7 @@ class RandomSearchExperimenter(Experimenter): # state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="") user_requirement = self.state["requirement"] exp_pool_path = get_exp_pool_path(self.args.task, self.data_config, pool_name="ds_analysis_pool") - exp_pool = InstructionGenerator.load_analysis_pool( + exp_pool = InstructionGenerator.load_insight_pool( exp_pool_path, use_fixed_insights=self.args.use_fixed_insights ) if self.args.rs_mode == "single": diff --git a/metagpt/ext/sela/research_assistant.py b/metagpt/ext/sela/research_assistant.py index 5ce1be68e..21cc46447 100644 --- a/metagpt/ext/sela/research_assistant.py +++ b/metagpt/ext/sela/research_assistant.py @@ -71,7 +71,7 @@ class ResearchAssistant(DataInterpreter): return f"Node-{self.node_id}" def get_next_instruction(self): - return self.planner.plan.tasks[self.start_task_id] + return self.planner.plan.tasks[self.start_task_id].instruction def change_next_instruction(self, new_instruction): if new_instruction is not None: diff --git a/metagpt/ext/sela/search/tree_search.py b/metagpt/ext/sela/search/tree_search.py index a27495230..08f4abb5d 100644 --- a/metagpt/ext/sela/search/tree_search.py +++ b/metagpt/ext/sela/search/tree_search.py @@ -113,7 +113,7 @@ class Node: normalized_reward: dict = {"train_score": 0, "dev_score": 0, "test_score": 0} parent = None - def __init__(self, parent=None, state=None, action=None, value=0, max_depth=4, **kwargs): + def __init__(self, parent=None, state: dict = None, action: str = None, value: float = 0, max_depth: int = 4, **kwargs): self.state = state self.action = action self.value = value