mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
rename classes and functions
This commit is contained in:
parent
76029782cc
commit
4d1a6f4c2b
5 changed files with 86 additions and 47 deletions
|
|
@ -1,19 +0,0 @@
|
|||
import random
|
||||
|
||||
from metagpt.ext.sela.MCTS import MCTS
|
||||
|
||||
|
||||
class Greedy(MCTS):
|
||||
def best_child(self):
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0))
|
||||
|
||||
|
||||
class Random(MCTS):
|
||||
def best_child(self):
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return random.choice(all_children)
|
||||
|
|
@ -4,7 +4,7 @@ import pandas as pd
|
|||
|
||||
from metagpt.ext.sela.evaluation.evaluation import evaluate_score
|
||||
from metagpt.ext.sela.experimenter.experimenter import Experimenter
|
||||
from metagpt.ext.sela.MCTS import create_initial_state
|
||||
from metagpt.ext.sela.search.tree_search import create_initial_state
|
||||
|
||||
|
||||
class CustomExperimenter(Experimenter):
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ from metagpt.ext.sela.evaluation.evaluation import (
|
|||
)
|
||||
from metagpt.ext.sela.evaluation.visualize_mcts import get_tree_text
|
||||
from metagpt.ext.sela.experimenter.experimenter import Experimenter
|
||||
from metagpt.ext.sela.Greedy import Greedy, Random
|
||||
from metagpt.ext.sela.MCTS import MCTS
|
||||
from metagpt.ext.sela.search.search_algorithm import Greedy, Random, MCTS
|
||||
|
||||
|
||||
class MCTSExperimenter(Experimenter):
|
||||
|
|
|
|||
31
metagpt/ext/sela/search/search_algorithm.py
Normal file
31
metagpt/ext/sela/search/search_algorithm.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import numpy as np
|
||||
from metagpt.ext.sela.search.tree_search import BaseTreeSearch, Node
|
||||
|
||||
|
||||
class Greedy(BaseTreeSearch):
|
||||
def best_child(self):
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0))
|
||||
|
||||
|
||||
class Random(BaseTreeSearch):
|
||||
def best_child(self):
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return np.random.choice(all_children)
|
||||
|
||||
|
||||
class MCTS(BaseTreeSearch):
|
||||
def best_child(self):
|
||||
def uct(node: Node):
|
||||
n_visits = node.visited if node.visited else self.c_unvisited
|
||||
avg_value = node.avg_value() if node.visited else node.value / self.c_unvisited
|
||||
return avg_value + self.c_explore * np.sqrt(np.log(node.parent.visited) / n_visits)
|
||||
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return max(all_children, key=uct)
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -18,7 +16,30 @@ from metagpt.tools.tool_recommend import ToolRecommender
|
|||
from metagpt.utils.common import read_json_file
|
||||
|
||||
|
||||
def initialize_di_root_node(state, reflection: bool = True):
|
||||
def initialize_di_root_node(state: dict, reflection: bool = True):
|
||||
"""
|
||||
Initialize the root node of the decision tree.
|
||||
|
||||
Args:
|
||||
state (dict): The initial state of the tree, containing:
|
||||
- task (str): The task to be performed (e.g., "titanic").
|
||||
- work_dir (str): The working directory.
|
||||
- node_dir (str): The directory for the node.
|
||||
- dataset_config (dict): The configuration of the dataset.
|
||||
- datasets_dir (str): The directory of the datasets.
|
||||
- exp_pool_path (str): The path to the experiment pool.
|
||||
- requirement (str): The requirement for the task.
|
||||
- has_run (bool): Whether the task has run.
|
||||
- start_task_id (int): The ID of the starting task.
|
||||
- low_is_better (bool): Whether a lower score is better.
|
||||
- role_timeout (int): The timeout for the role.
|
||||
- external_eval (bool): Whether to use external evaluation.
|
||||
- custom_dataset_dir (str): The directory of the custom dataset.
|
||||
reflection (bool, optional): Whether to use reflection. Defaults to True.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the ResearchAssistant role and the root Node.
|
||||
"""
|
||||
role = ResearchAssistant(
|
||||
node_id="0",
|
||||
start_task_id=state["start_task_id"],
|
||||
|
|
@ -29,7 +50,21 @@ def initialize_di_root_node(state, reflection: bool = True):
|
|||
return role, Node(parent=None, state=state, action=None, value=0)
|
||||
|
||||
|
||||
def create_initial_state(task, start_task_id, data_config, args):
|
||||
def create_initial_state(task: str, start_task_id: int, data_config: dict, args):
|
||||
"""
|
||||
Create the initial state of the tree.
|
||||
|
||||
Args:
|
||||
task (str): The task to be performed.
|
||||
start_task_id (int): The ID of the starting task.
|
||||
data_config (dict): The configuration of the data.
|
||||
Expected keys: 'datasets', 'work_dir', 'role_dir'.
|
||||
args (Namespace): The arguments passed to the program.
|
||||
Expected attributes: 'external_eval', 'custom_dataset_dir', 'special_instruction', 'name', 'low_is_better', 'role_timeout'.
|
||||
|
||||
Returns:
|
||||
dict: The initial state of the tree.
|
||||
"""
|
||||
external_eval = args.external_eval
|
||||
|
||||
if args.custom_dataset_dir:
|
||||
|
|
@ -69,7 +104,6 @@ def create_initial_state(task, start_task_id, data_config, args):
|
|||
os.makedirs(initial_state["node_dir"], exist_ok=True)
|
||||
return initial_state
|
||||
|
||||
|
||||
class Node:
|
||||
state: dict = {}
|
||||
action: str = None
|
||||
|
|
@ -225,7 +259,7 @@ class Node:
|
|||
self.get_and_move_predictions("test")
|
||||
return score_dict
|
||||
|
||||
async def run_node(self, role=None):
|
||||
async def run_node(self, role: ResearchAssistant = None):
|
||||
if self.is_terminal() and role is not None:
|
||||
if role.state_saved:
|
||||
return self.raw_reward
|
||||
|
|
@ -272,7 +306,9 @@ class Node:
|
|||
return score_dict, result_dict
|
||||
|
||||
|
||||
class MCTS:
|
||||
|
||||
|
||||
class BaseTreeSearch:
|
||||
# data_path
|
||||
root_node: Node = None
|
||||
children: dict = {}
|
||||
|
|
@ -283,7 +319,7 @@ class MCTS:
|
|||
# insight generator
|
||||
instruction_generator: InstructionGenerator = None
|
||||
|
||||
def __init__(self, root_node, max_depth, use_fixed_insights):
|
||||
def __init__(self, root_node: Node, max_depth: int, use_fixed_insights: bool):
|
||||
self.root_node = root_node
|
||||
self.max_depth = max_depth
|
||||
self.use_fixed_insights = use_fixed_insights
|
||||
|
|
@ -294,15 +330,7 @@ class MCTS:
|
|||
return node
|
||||
|
||||
def best_child(self):
|
||||
def uct(node: Node):
|
||||
n_visits = node.visited if node.visited else self.c_unvisited
|
||||
avg_value = node.avg_value() if node.visited else node.value / self.c_unvisited
|
||||
return avg_value + self.c_explore * math.sqrt(math.log(node.parent.visited) / n_visits)
|
||||
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return max(all_children, key=uct)
|
||||
raise NotImplementedError
|
||||
|
||||
async def expand(self, node: Node, max_children=5):
|
||||
await node.expand(max_children, self.instruction_generator)
|
||||
|
|
@ -314,13 +342,13 @@ class MCTS:
|
|||
"Returns the reward for a random simulation (to completion) of `node`"
|
||||
mcts_logger.log("MCTS", f"Start simulating node {node.id}:")
|
||||
while node.children:
|
||||
node = random.choice(node.children)
|
||||
node = np.random.choice(node.children)
|
||||
reward, result_dict = await node.run_node(role)
|
||||
mcts_logger.log("MCTS", f"Simulated node's reward: {reward}")
|
||||
# TODO: add new insights
|
||||
return reward
|
||||
|
||||
def backpropagate(self, node: Node, reward):
|
||||
def backpropagate(self, node: Node, reward: dict):
|
||||
child_node = node
|
||||
node.update(reward)
|
||||
node = node.parent
|
||||
|
|
@ -333,7 +361,7 @@ class MCTS:
|
|||
global_best_score = root.normalized_reward["test_score"]
|
||||
dev_best_score = root.normalized_reward["dev_score"]
|
||||
|
||||
def bfs(node: Node, best_score, best_child: Node, split):
|
||||
def bfs(node: Node, best_score: float, best_child: Node, split: str):
|
||||
assert split in ["test_score", "dev_score"]
|
||||
if node not in self.children:
|
||||
return best_score, best_child
|
||||
|
|
@ -354,7 +382,7 @@ class MCTS:
|
|||
def get_num_simulations(self):
|
||||
return self.root_node.visited
|
||||
|
||||
def save_node_order(self, node_id):
|
||||
def save_node_order(self, node_id: str):
|
||||
self.node_order.append(node_id)
|
||||
with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "w") as f:
|
||||
json.dump(self.node_order, f)
|
||||
|
|
@ -375,7 +403,7 @@ class MCTS:
|
|||
scores["test_raw"].append(node.raw_reward["test_score"])
|
||||
return scores
|
||||
|
||||
async def search(self, state, args):
|
||||
async def search(self, state: dict, args):
|
||||
reflection = args.reflection
|
||||
load_tree = args.load_tree
|
||||
rollouts = args.rollouts
|
||||
|
|
@ -424,17 +452,17 @@ class MCTS:
|
|||
self.save_node_order(node.id)
|
||||
return self.best_path(root)
|
||||
|
||||
async def expand_and_simulate(self, node):
|
||||
async def expand_and_simulate(self, node: Node):
|
||||
# Expand and randomly select a child node, then simulate it
|
||||
if node.visited > 0:
|
||||
children = await self.expand(node)
|
||||
node = random.choice(children)
|
||||
node = np.random.choice(children)
|
||||
reward = await self.simulate(node)
|
||||
self.backpropagate(node, reward)
|
||||
return node, reward
|
||||
|
||||
def load_tree(self):
|
||||
def load_children_node(node):
|
||||
def load_children_node(node: Node):
|
||||
mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}")
|
||||
if node.is_terminal() or not node.children:
|
||||
return
|
||||
Loading…
Add table
Add a link
Reference in a new issue