rename classes and functions

This commit is contained in:
Cyzus Chi 2024-10-26 01:06:48 +08:00
parent 76029782cc
commit 4d1a6f4c2b
5 changed files with 86 additions and 47 deletions

View file

@ -1,19 +0,0 @@
import random
from metagpt.ext.sela.MCTS import MCTS
class Greedy(MCTS):
    """Tree search whose selection step always picks the best dev-scoring node."""

    def best_child(self):
        """Return the expanded node with the highest normalized dev score.

        Falls back to the root node when ``self.children`` is empty. A node
        whose ``normalized_reward`` has no ``"dev_score"`` entry is scored as 0.
        """
        if not self.children:
            return self.root_node

        def dev_score(node):
            return node.normalized_reward.get("dev_score", 0)

        # Flatten every per-parent child list into a single candidate pool.
        candidates = []
        for siblings in self.children.values():
            candidates.extend(siblings)
        return max(candidates, key=dev_score)
class Random(MCTS):
    """Tree search whose selection step picks an expanded node at random."""

    def best_child(self):
        """Return a randomly chosen node from all recorded children.

        Falls back to the root node when ``self.children`` is empty.
        """
        if not self.children:
            return self.root_node

        # Flatten every per-parent child list into a single candidate pool.
        candidates = []
        for siblings in self.children.values():
            candidates.extend(siblings)
        return random.choice(candidates)

View file

@ -4,7 +4,7 @@ import pandas as pd
from metagpt.ext.sela.evaluation.evaluation import evaluate_score
from metagpt.ext.sela.experimenter.experimenter import Experimenter
from metagpt.ext.sela.MCTS import create_initial_state
from metagpt.ext.sela.search.tree_search import create_initial_state
class CustomExperimenter(Experimenter):

View file

@ -6,8 +6,7 @@ from metagpt.ext.sela.evaluation.evaluation import (
)
from metagpt.ext.sela.evaluation.visualize_mcts import get_tree_text
from metagpt.ext.sela.experimenter.experimenter import Experimenter
from metagpt.ext.sela.Greedy import Greedy, Random
from metagpt.ext.sela.MCTS import MCTS
from metagpt.ext.sela.search.search_algorithm import Greedy, Random, MCTS
class MCTSExperimenter(Experimenter):

View file

@ -0,0 +1,31 @@
import numpy as np
from metagpt.ext.sela.search.tree_search import BaseTreeSearch, Node
class Greedy(BaseTreeSearch):
    """Tree search whose selection step always picks the best dev-scoring node."""

    def best_child(self):
        """Return the expanded node with the highest normalized dev score.

        Falls back to the root node when ``self.children`` is empty. A node
        whose ``normalized_reward`` has no ``"dev_score"`` entry is scored as 0.
        """
        if not self.children:
            return self.root_node

        def dev_score(node):
            return node.normalized_reward.get("dev_score", 0)

        # Flatten every per-parent child list into a single candidate pool.
        candidates = []
        for siblings in self.children.values():
            candidates.extend(siblings)
        return max(candidates, key=dev_score)
class Random(BaseTreeSearch):
    """Tree search whose selection step picks an expanded node at random."""

    def best_child(self):
        """Return a randomly chosen node from all recorded children.

        Falls back to the root node when ``self.children`` is empty.
        """
        if not self.children:
            return self.root_node

        # Flatten every per-parent child list into a single candidate pool.
        candidates = []
        for siblings in self.children.values():
            candidates.extend(siblings)
        # NOTE(review): np.random.choice first converts the list to an array;
        # this presumes Node objects are not themselves sequence-like - confirm.
        return np.random.choice(candidates)
class MCTS(BaseTreeSearch):
    """Tree search whose selection step maximises a UCT score."""

    def best_child(self):
        """Return the recorded child with the highest UCT score.

        Falls back to the root node when ``self.children`` is empty.
        """

        def uct(node: Node):
            # Unvisited nodes have no average yet, so ``self.c_unvisited``
            # stands in for the visit count in both terms.
            if node.visited:
                n_visits = node.visited
                avg_value = node.avg_value()
            else:
                n_visits = self.c_unvisited
                avg_value = node.value / self.c_unvisited
            # Exploration bonus grows with the parent's visits and shrinks
            # with this node's own visit count.
            exploration = self.c_explore * np.sqrt(np.log(node.parent.visited) / n_visits)
            return avg_value + exploration

        if not self.children:
            return self.root_node

        # Flatten every per-parent child list into a single candidate pool.
        candidates = []
        for siblings in self.children.values():
            candidates.extend(siblings)
        return max(candidates, key=uct)

View file

@ -1,8 +1,6 @@
import json
import math
import os
import pickle
import random
import shutil
import numpy as np
@ -18,7 +16,30 @@ from metagpt.tools.tool_recommend import ToolRecommender
from metagpt.utils.common import read_json_file
def initialize_di_root_node(state, reflection: bool = True):
def initialize_di_root_node(state: dict, reflection: bool = True):
"""
Initialize the root node of the decision tree.
Args:
state (dict): The initial state of the tree, containing:
- task (str): The task to be performed (e.g., "titanic").
- work_dir (str): The working directory.
- node_dir (str): The directory for the node.
- dataset_config (dict): The configuration of the dataset.
- datasets_dir (str): The directory of the datasets.
- exp_pool_path (str): The path to the experiment pool.
- requirement (str): The requirement for the task.
- has_run (bool): Whether the task has run.
- start_task_id (int): The ID of the starting task.
- low_is_better (bool): Whether a lower score is better.
- role_timeout (int): The timeout for the role.
- external_eval (bool): Whether to use external evaluation.
- custom_dataset_dir (str): The directory of the custom dataset.
reflection (bool, optional): Whether to use reflection. Defaults to True.
Returns:
tuple: A tuple containing the ResearchAssistant role and the root Node.
"""
role = ResearchAssistant(
node_id="0",
start_task_id=state["start_task_id"],
@ -29,7 +50,21 @@ def initialize_di_root_node(state, reflection: bool = True):
return role, Node(parent=None, state=state, action=None, value=0)
def create_initial_state(task, start_task_id, data_config, args):
def create_initial_state(task: str, start_task_id: int, data_config: dict, args):
"""
Create the initial state of the tree.
Args:
task (str): The task to be performed.
start_task_id (int): The ID of the starting task.
data_config (dict): The configuration of the data.
Expected keys: 'datasets', 'work_dir', 'role_dir'.
args (Namespace): The arguments passed to the program.
Expected attributes: 'external_eval', 'custom_dataset_dir', 'special_instruction', 'name', 'low_is_better', 'role_timeout'.
Returns:
dict: The initial state of the tree.
"""
external_eval = args.external_eval
if args.custom_dataset_dir:
@ -69,7 +104,6 @@ def create_initial_state(task, start_task_id, data_config, args):
os.makedirs(initial_state["node_dir"], exist_ok=True)
return initial_state
class Node:
state: dict = {}
action: str = None
@ -225,7 +259,7 @@ class Node:
self.get_and_move_predictions("test")
return score_dict
async def run_node(self, role=None):
async def run_node(self, role: ResearchAssistant = None):
if self.is_terminal() and role is not None:
if role.state_saved:
return self.raw_reward
@ -272,7 +306,9 @@ class Node:
return score_dict, result_dict
class MCTS:
class BaseTreeSearch:
# data_path
root_node: Node = None
children: dict = {}
@ -283,7 +319,7 @@ class MCTS:
# insight generator
instruction_generator: InstructionGenerator = None
def __init__(self, root_node, max_depth, use_fixed_insights):
def __init__(self, root_node: Node, max_depth: int, use_fixed_insights: bool):
self.root_node = root_node
self.max_depth = max_depth
self.use_fixed_insights = use_fixed_insights
@ -294,15 +330,7 @@ class MCTS:
return node
def best_child(self):
def uct(node: Node):
n_visits = node.visited if node.visited else self.c_unvisited
avg_value = node.avg_value() if node.visited else node.value / self.c_unvisited
return avg_value + self.c_explore * math.sqrt(math.log(node.parent.visited) / n_visits)
if len(self.children) == 0:
return self.root_node
all_children = [child for children in self.children.values() for child in children]
return max(all_children, key=uct)
raise NotImplementedError
async def expand(self, node: Node, max_children=5):
await node.expand(max_children, self.instruction_generator)
@ -314,13 +342,13 @@ class MCTS:
"Returns the reward for a random simulation (to completion) of `node`"
mcts_logger.log("MCTS", f"Start simulating node {node.id}:")
while node.children:
node = random.choice(node.children)
node = np.random.choice(node.children)
reward, result_dict = await node.run_node(role)
mcts_logger.log("MCTS", f"Simulated node's reward: {reward}")
# TODO: add new insights
return reward
def backpropagate(self, node: Node, reward):
def backpropagate(self, node: Node, reward: dict):
child_node = node
node.update(reward)
node = node.parent
@ -333,7 +361,7 @@ class MCTS:
global_best_score = root.normalized_reward["test_score"]
dev_best_score = root.normalized_reward["dev_score"]
def bfs(node: Node, best_score, best_child: Node, split):
def bfs(node: Node, best_score: float, best_child: Node, split: str):
assert split in ["test_score", "dev_score"]
if node not in self.children:
return best_score, best_child
@ -354,7 +382,7 @@ class MCTS:
def get_num_simulations(self):
return self.root_node.visited
def save_node_order(self, node_id):
def save_node_order(self, node_id: str):
self.node_order.append(node_id)
with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "w") as f:
json.dump(self.node_order, f)
@ -375,7 +403,7 @@ class MCTS:
scores["test_raw"].append(node.raw_reward["test_score"])
return scores
async def search(self, state, args):
async def search(self, state: dict, args):
reflection = args.reflection
load_tree = args.load_tree
rollouts = args.rollouts
@ -424,17 +452,17 @@ class MCTS:
self.save_node_order(node.id)
return self.best_path(root)
async def expand_and_simulate(self, node):
async def expand_and_simulate(self, node: Node):
# Expand and randomly select a child node, then simulate it
if node.visited > 0:
children = await self.expand(node)
node = random.choice(children)
node = np.random.choice(children)
reward = await self.simulate(node)
self.backpropagate(node, reward)
return node, reward
def load_tree(self):
def load_children_node(node):
def load_children_node(node: Node):
mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}")
if node.is_terminal() or not node.children:
return