diff --git a/metagpt/ext/sela/Greedy.py b/metagpt/ext/sela/Greedy.py deleted file mode 100644 index c10c4248c..000000000 --- a/metagpt/ext/sela/Greedy.py +++ /dev/null @@ -1,19 +0,0 @@ -import random - -from metagpt.ext.sela.MCTS import MCTS - - -class Greedy(MCTS): - def best_child(self): - if len(self.children) == 0: - return self.root_node - all_children = [child for children in self.children.values() for child in children] - return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0)) - - -class Random(MCTS): - def best_child(self): - if len(self.children) == 0: - return self.root_node - all_children = [child for children in self.children.values() for child in children] - return random.choice(all_children) diff --git a/metagpt/ext/sela/experimenter/custom.py b/metagpt/ext/sela/experimenter/custom.py index 4d69e286f..70df1a78e 100644 --- a/metagpt/ext/sela/experimenter/custom.py +++ b/metagpt/ext/sela/experimenter/custom.py @@ -4,7 +4,7 @@ import pandas as pd from metagpt.ext.sela.evaluation.evaluation import evaluate_score from metagpt.ext.sela.experimenter.experimenter import Experimenter -from metagpt.ext.sela.MCTS import create_initial_state +from metagpt.ext.sela.search.tree_search import create_initial_state class CustomExperimenter(Experimenter): diff --git a/metagpt/ext/sela/experimenter/mcts.py b/metagpt/ext/sela/experimenter/mcts.py index b83aa1392..f8f9f9fd1 100644 --- a/metagpt/ext/sela/experimenter/mcts.py +++ b/metagpt/ext/sela/experimenter/mcts.py @@ -6,8 +6,7 @@ from metagpt.ext.sela.evaluation.evaluation import ( ) from metagpt.ext.sela.evaluation.visualize_mcts import get_tree_text from metagpt.ext.sela.experimenter.experimenter import Experimenter -from metagpt.ext.sela.Greedy import Greedy, Random -from metagpt.ext.sela.MCTS import MCTS +from metagpt.ext.sela.search.search_algorithm import Greedy, Random, MCTS class MCTSExperimenter(Experimenter): diff --git a/metagpt/ext/sela/search/search_algorithm.py b/metagpt/ext/sela/search/search_algorithm.py new file mode 100644 index 000000000..675cc7c5f --- /dev/null +++ b/metagpt/ext/sela/search/search_algorithm.py @@ -0,0 +1,31 @@ +import numpy as np +from metagpt.ext.sela.search.tree_search import BaseTreeSearch, Node + + +class Greedy(BaseTreeSearch): + def best_child(self): + if len(self.children) == 0: + return self.root_node + all_children = [child for children in self.children.values() for child in children] + return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0)) + + +class Random(BaseTreeSearch): + def best_child(self): + if len(self.children) == 0: + return self.root_node + all_children = [child for children in self.children.values() for child in children] + return np.random.choice(all_children) + + +class MCTS(BaseTreeSearch): + def best_child(self): + def uct(node: Node): + n_visits = node.visited if node.visited else self.c_unvisited + avg_value = node.avg_value() if node.visited else node.value / self.c_unvisited + return avg_value + self.c_explore * np.sqrt(np.log(node.parent.visited) / n_visits) + + if len(self.children) == 0: + return self.root_node + all_children = [child for children in self.children.values() for child in children] + return max(all_children, key=uct) diff --git a/metagpt/ext/sela/MCTS.py b/metagpt/ext/sela/search/tree_search.py similarity index 87% rename from metagpt/ext/sela/MCTS.py rename to metagpt/ext/sela/search/tree_search.py index ac4b68024..a27495230 100644 --- a/metagpt/ext/sela/MCTS.py +++ b/metagpt/ext/sela/search/tree_search.py @@ -1,8 +1,6 @@ import json -import math import os import pickle -import random import shutil import numpy as np @@ -18,7 +16,30 @@ from metagpt.tools.tool_recommend import ToolRecommender from metagpt.utils.common import read_json_file -def initialize_di_root_node(state, reflection: bool = True): +def initialize_di_root_node(state: dict, reflection: bool = True): + """ + Initialize the root node of the decision tree. + + Args: + state (dict): The initial state of the tree, containing: + - task (str): The task to be performed (e.g., "titanic"). + - work_dir (str): The working directory. + - node_dir (str): The directory for the node. + - dataset_config (dict): The configuration of the dataset. + - datasets_dir (str): The directory of the datasets. + - exp_pool_path (str): The path to the experiment pool. + - requirement (str): The requirement for the task. + - has_run (bool): Whether the task has run. + - start_task_id (int): The ID of the starting task. + - low_is_better (bool): Whether a lower score is better. + - role_timeout (int): The timeout for the role. + - external_eval (bool): Whether to use external evaluation. + - custom_dataset_dir (str): The directory of the custom dataset. + reflection (bool, optional): Whether to use reflection. Defaults to True. + + Returns: + tuple: A tuple containing the ResearchAssistant role and the root Node. + """ role = ResearchAssistant( node_id="0", start_task_id=state["start_task_id"], @@ -29,7 +50,21 @@ def initialize_di_root_node(state, reflection: bool = True): return role, Node(parent=None, state=state, action=None, value=0) -def create_initial_state(task, start_task_id, data_config, args): +def create_initial_state(task: str, start_task_id: int, data_config: dict, args): + """ + Create the initial state of the tree. + + Args: + task (str): The task to be performed. + start_task_id (int): The ID of the starting task. + data_config (dict): The configuration of the data. + Expected keys: 'datasets', 'work_dir', 'role_dir'. + args (Namespace): The arguments passed to the program. + Expected attributes: 'external_eval', 'custom_dataset_dir', 'special_instruction', 'name', 'low_is_better', 'role_timeout'. + + Returns: + dict: The initial state of the tree. + """ external_eval = args.external_eval if args.custom_dataset_dir: @@ -69,7 +104,6 @@ def create_initial_state(task, start_task_id, data_config, args): os.makedirs(initial_state["node_dir"], exist_ok=True) return initial_state - class Node: state: dict = {} action: str = None @@ -225,7 +259,7 @@ class Node: self.get_and_move_predictions("test") return score_dict - async def run_node(self, role=None): + async def run_node(self, role: ResearchAssistant = None): if self.is_terminal() and role is not None: if role.state_saved: return self.raw_reward @@ -272,7 +306,9 @@ class Node: return score_dict, result_dict -class MCTS: + + +class BaseTreeSearch: # data_path root_node: Node = None children: dict = {} @@ -283,7 +319,7 @@ class MCTS: # insight generator instruction_generator: InstructionGenerator = None - def __init__(self, root_node, max_depth, use_fixed_insights): + def __init__(self, root_node: Node, max_depth: int, use_fixed_insights: bool): self.root_node = root_node self.max_depth = max_depth self.use_fixed_insights = use_fixed_insights @@ -294,15 +330,7 @@ class MCTS: return node def best_child(self): - def uct(node: Node): - n_visits = node.visited if node.visited else self.c_unvisited - avg_value = node.avg_value() if node.visited else node.value / self.c_unvisited - return avg_value + self.c_explore * math.sqrt(math.log(node.parent.visited) / n_visits) - - if len(self.children) == 0: - return self.root_node - all_children = [child for children in self.children.values() for child in children] - return max(all_children, key=uct) + raise NotImplementedError async def expand(self, node: Node, max_children=5): await node.expand(max_children, self.instruction_generator) @@ -314,13 +342,13 @@ class MCTS: "Returns the reward for a random simulation (to completion) of `node`" mcts_logger.log("MCTS", f"Start simulating node {node.id}:") while node.children: - node = random.choice(node.children) + node = np.random.choice(node.children) reward, result_dict = await node.run_node(role) mcts_logger.log("MCTS", f"Simulated node's reward: {reward}") # TODO: add new insights return reward - def backpropagate(self, node: Node, reward): + def backpropagate(self, node: Node, reward: dict): child_node = node node.update(reward) node = node.parent @@ -333,7 +361,7 @@ class MCTS: global_best_score = root.normalized_reward["test_score"] dev_best_score = root.normalized_reward["dev_score"] - def bfs(node: Node, best_score, best_child: Node, split): + def bfs(node: Node, best_score: float, best_child: Node, split: str): assert split in ["test_score", "dev_score"] if node not in self.children: return best_score, best_child @@ -354,7 +382,7 @@ class MCTS: def get_num_simulations(self): return self.root_node.visited - def save_node_order(self, node_id): + def save_node_order(self, node_id: str): self.node_order.append(node_id) with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "w") as f: json.dump(self.node_order, f) @@ -375,7 +403,7 @@ class MCTS: scores["test_raw"].append(node.raw_reward["test_score"]) return scores - async def search(self, state, args): + async def search(self, state: dict, args): reflection = args.reflection load_tree = args.load_tree rollouts = args.rollouts @@ -424,17 +452,17 @@ class MCTS: self.save_node_order(node.id) return self.best_path(root) - async def expand_and_simulate(self, node): + async def expand_and_simulate(self, node: Node): # Expand and randomly select a child node, then simulate it if node.visited > 0: children = await self.expand(node) - node = random.choice(children) + node = np.random.choice(children) reward = await self.simulate(node) self.backpropagate(node, reward) return node, reward def load_tree(self): - def load_children_node(node): + def load_children_node(node: Node): mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}") if node.is_terminal() or not node.children: return