mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
1. Rewrite logger messages
2. Fix import path
This commit is contained in:
parent
16d1bf0da0
commit
27bbc927b0
3 changed files with 25 additions and 29 deletions
50
expo/MCTS.py
50
expo/MCTS.py
|
|
@ -3,7 +3,7 @@ import math
|
|||
import os
|
||||
import pandas as pd
|
||||
from expo.research_assistant import ResearchAssistant
|
||||
from exp_optimizer.expo.insights.instruction_generator import InstructionGenerator
|
||||
from expo.insights.instruction_generator import InstructionGenerator
|
||||
from expo.dataset import get_split_dataset_path, generate_task_requirement
|
||||
from expo.evaluation.evaluation import evaluate_score
|
||||
from expo.utils import mcts_logger, load_execute_notebook, get_exp_pool_path
|
||||
|
|
@ -135,13 +135,13 @@ class Node():
|
|||
role.start_task_id = self.state['start_task_id']
|
||||
role.state_saved = False
|
||||
role.change_next_instruction(self.action)
|
||||
mcts_logger.log("MCTS", f"保存新的role: {role.node_id}")
|
||||
mcts_logger.log("MCTS", f"Saving new role: {role.node_id}")
|
||||
role.save_state(static_save=True)
|
||||
|
||||
async def expand(self, max_children):
|
||||
if self.is_fully_expanded():
|
||||
return
|
||||
insight_geneartor = InsightGenerator()
|
||||
insight_geneartor = InstructionGenerator()
|
||||
role = self.load_role()
|
||||
original_instruction = role.get_next_instruction()
|
||||
insights = await insight_geneartor.generate_new_instructions(task_id=role.start_task_id + 1,
|
||||
|
|
@ -224,7 +224,7 @@ class MCTS():
|
|||
|
||||
def select(self, node: Node):
|
||||
node = self.best_child()
|
||||
mcts_logger.log("MCTS", f"选择的叶子节点id: {node.id}")
|
||||
mcts_logger.log("MCTS", f"Selected node id: {node.id}")
|
||||
return node
|
||||
|
||||
def best_child(self):
|
||||
|
|
@ -245,9 +245,11 @@ class MCTS():
|
|||
|
||||
async def simulate(self, node : Node, role=None):
|
||||
"Returns the reward for a random simulation (to completion) of `node`"
|
||||
mcts_logger.log("MCTS", f"Start simulating node {node.id}:")
|
||||
while node.children:
|
||||
node = random.choice(node.children)
|
||||
reward = await node.run_node(role)
|
||||
reward = await node.run_node(role)
|
||||
mcts_logger.log("MCTS", f"Simulated node's reward: {reward}")
|
||||
return reward
|
||||
|
||||
|
||||
|
|
@ -292,7 +294,7 @@ class MCTS():
|
|||
if load_tree:
|
||||
tree_loaded = self.load_tree()
|
||||
mcts_logger.log("MCTS", f"Number of simulations: {self.get_num_simulations()}")
|
||||
|
||||
mcts_logger.log("MCTS", f"Tree loaded: {tree_loaded}")
|
||||
|
||||
if not tree_loaded:
|
||||
rollouts -= 2
|
||||
|
|
@ -301,41 +303,36 @@ class MCTS():
|
|||
self.children[root] = []
|
||||
reward = await self.simulate(root, role)
|
||||
self.backpropagate(root, reward)
|
||||
mcts_logger.log("MCTS", f"Root node's value: {reward}")
|
||||
children = await self.expand(root)
|
||||
#目前是随机选择1个,后续可以改成多个
|
||||
first_leaf = random.choice(children)
|
||||
mcts_logger.log("MCTS", f"随机选择的叶子节点id: {first_leaf.id}")
|
||||
reward = await self.simulate(first_leaf)
|
||||
mcts_logger.log("MCTS", f"模拟完毕的叶子节点的Normalized score: {reward}")
|
||||
self.backpropagate(first_leaf, reward)
|
||||
else:
|
||||
root = self.root_node
|
||||
# 后续迭代:使用UCT进行选择,expand并模拟和反向传播
|
||||
for _ in range(rollouts): # 迭代次数
|
||||
mcts_logger.log("MCTS", f"开始第{_+1}次迭代")
|
||||
leaf = self.select(root)
|
||||
if leaf.is_terminal():
|
||||
if leaf.raw_value == 0:
|
||||
reward = await self.simulate(leaf)
|
||||
for _ in range(rollouts): # number of rollouts
|
||||
mcts_logger.log("MCTS", f"Start the next rollout {_+1}")
|
||||
node = self.select(root)
|
||||
if node.is_terminal():
|
||||
if node.raw_value == 0:
|
||||
reward = await self.simulate(node)
|
||||
else:
|
||||
reward = {"test_score": leaf.raw_value, "score": leaf.value}
|
||||
mcts_logger.log("MCTS", f"终止节点的得分为: {reward}")
|
||||
self.backpropagate(leaf, reward)
|
||||
reward = {"test_score": node.raw_value, "score": node.value}
|
||||
mcts_logger.log("MCTS", f"Terminal node's reward: {reward}")
|
||||
self.backpropagate(node, reward)
|
||||
else:
|
||||
if leaf.visited > 0:
|
||||
children = await self.expand(leaf)
|
||||
leaf = random.choice(children)
|
||||
mcts_logger.log("MCTS", f"随机选择的叶子节点id: {leaf.id}")
|
||||
reward = await self.simulate(leaf)
|
||||
mcts_logger.log("MCTS", f"模拟完毕的叶子节点{leaf.id}的Normalized score: {reward}")
|
||||
self.backpropagate(leaf, reward)
|
||||
if node.visited > 0:
|
||||
children = await self.expand(node)
|
||||
node = random.choice(children)
|
||||
reward = await self.simulate(node)
|
||||
self.backpropagate(node, reward)
|
||||
return self.best_path(root)
|
||||
|
||||
|
||||
def load_tree(self):
|
||||
def load_children_node(node):
|
||||
mcts_logger.log("MCTS", f"加载节点{node.id}的子节点:{node.children}")
|
||||
mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}")
|
||||
if node.is_terminal() or not node.children:
|
||||
return
|
||||
for child in node.children:
|
||||
|
|
@ -351,6 +348,5 @@ class MCTS():
|
|||
self.children[self.root_node] = self.root_node.children
|
||||
load_children_node(self.root_node)
|
||||
if self.children:
|
||||
mcts_logger.log("MCTS", "成功加载树")
|
||||
return True
|
||||
return False
|
||||
|
|
@ -2,7 +2,7 @@ from experimenter import Experimenter
|
|||
from expo.MCTS import create_initial_state
|
||||
from expo.dataset import generate_task_requirement
|
||||
from expo.utils import mcts_logger, load_execute_notebook, get_exp_pool_path
|
||||
from exp_optimizer.expo.insights.instruction_generator import InstructionGenerator
|
||||
from expo.insights.instruction_generator import InstructionGenerator
|
||||
from expo.research_assistant import ResearchAssistant
|
||||
|
||||
EXPS_PROMPT = """
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from expo.research_assistant import ResearchAssistant
|
|||
import asyncio
|
||||
from expo.utils import DATA_CONFIG, get_exp_pool_path
|
||||
from expo.dataset import generate_task_requirement
|
||||
from exp_optimizer.expo.insights.instruction_generator import InstructionGenerator
|
||||
from expo.insights.instruction_generator import InstructionGenerator
|
||||
from expo.MCTS import create_initial_state
|
||||
from expo.evaluation.evaluation import evaluate_score
|
||||
import json
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue