mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
add step score
This commit is contained in:
parent
68c672d438
commit
2b67355358
2 changed files with 32 additions and 1 deletions
31
expo/MCTS.py
31
expo/MCTS.py
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
|
|
@ -240,6 +241,7 @@ class MCTS:
|
|||
max_depth: int = 5
|
||||
c_explore: float = 1.4
|
||||
c_unvisited: float = 0.8
|
||||
node_order: list = []
|
||||
|
||||
def __init__(self, root_node, max_depth, use_fixed_insights):
|
||||
self.root_node = root_node
|
||||
|
|
@ -306,11 +308,32 @@ class MCTS:
|
|||
_, global_best_child = bfs(root, global_best_score, best_child, "test_score")
|
||||
_, dev_best_child = bfs(root, dev_best_score, best_child, "dev_score")
|
||||
|
||||
return {"dev_best": dev_best_child, "global_best": global_best_child}
|
||||
return {"dev_best": dev_best_child, "global_best": global_best_child, "scores": self.get_score_order_dict()}
|
||||
|
||||
def get_num_simulations(self):
|
||||
return self.root_node.visited
|
||||
|
||||
def save_node_order(self, node_id):
|
||||
self.node_order.append(node_id)
|
||||
with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "w") as f:
|
||||
json.dump(self.node_order, f)
|
||||
|
||||
def load_node_order(self):
|
||||
with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "r") as f:
|
||||
self.node_order = json.load(f)
|
||||
|
||||
def get_score_order_dict(self):
|
||||
scores = {"dev": [], "test": [], "dev_raw": [], "test_raw": []}
|
||||
for node_id in self.node_order:
|
||||
node = Node(parent=None, state=self.root_node.state, action=None, value=0)
|
||||
node.id = node_id
|
||||
node = node.load_node()
|
||||
scores["dev"].append(node.normalized_reward["dev_score"])
|
||||
scores["test"].append(node.normalized_reward["test_score"])
|
||||
scores["dev_raw"].append(node.raw_reward["dev_score"])
|
||||
scores["test_raw"].append(node.raw_reward["test_score"])
|
||||
return scores
|
||||
|
||||
async def search(self, state, rollouts, load_tree=False, reflection=False):
|
||||
role, root = initialize_di_root_node(state, reflection=reflection)
|
||||
self.root_node = root
|
||||
|
|
@ -329,8 +352,12 @@ class MCTS:
|
|||
self.backpropagate(root, reward)
|
||||
node, reward = await self.expand_and_simulate(root)
|
||||
# self.backpropagate(node, reward)
|
||||
self.save_node_order(root.id)
|
||||
self.save_node_order(node.id)
|
||||
else:
|
||||
root = self.root_node
|
||||
self.load_node_order()
|
||||
|
||||
for _ in range(rollouts): # number of rollouts
|
||||
mcts_logger.log("MCTS", f"Start the next rollout {_+1}")
|
||||
node = self.select(root)
|
||||
|
|
@ -344,6 +371,7 @@ class MCTS:
|
|||
else:
|
||||
node, reward = await self.expand_and_simulate(node)
|
||||
# self.backpropagate(node, reward)
|
||||
self.save_node_order(node.id)
|
||||
return self.best_path(root)
|
||||
|
||||
async def expand_and_simulate(self, node):
|
||||
|
|
@ -373,6 +401,7 @@ class MCTS:
|
|||
self.root_node = pickle.load(f)
|
||||
self.children[self.root_node] = self.root_node.children
|
||||
load_children_node(self.root_node)
|
||||
|
||||
if self.children:
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class MCTSExperimenter(Experimenter):
|
|||
)
|
||||
best_node = best_nodes["global_best"]
|
||||
dev_best_node = best_nodes["dev_best"]
|
||||
score_dict = best_nodes["scores"]
|
||||
|
||||
self.copy_notebook(best_node, "best")
|
||||
self.copy_notebook(dev_best_node, "dev_best")
|
||||
|
|
@ -50,6 +51,7 @@ class MCTSExperimenter(Experimenter):
|
|||
"user_requirement": best_node.state["requirement"],
|
||||
"tree_text": text,
|
||||
"args": vars(self.args),
|
||||
"scores": score_dict,
|
||||
}
|
||||
]
|
||||
self.save_result(results)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue