1. avoid circular reference

2. add greedy
This commit is contained in:
Yizhou Chi 2024-09-06 13:09:18 +08:00
parent e07ed0df8b
commit a6f71f4498
2 changed files with 23 additions and 23 deletions

9
expo/Greedy.py Normal file
View file

@ -0,0 +1,9 @@
from expo.MCTS import MCTS
class Greedy(MCTS):
def best_child(self):
if len(self.children) == 0:
return self.root_node
all_children = [child for children in self.children.values() for child in children]
return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0))

View file

@ -143,6 +143,7 @@ class Node:
role.state_saved = False
role.change_next_instruction(self.action)
mcts_logger.log("MCTS", f"Saving new role: {role.node_id}")
role = role.model_copy()
role.save_state(static_save=True)
async def expand(self, max_children):
@ -165,18 +166,6 @@ class Node:
node.save_new_role(new_role)
self.add_child(node)
# def evaluate_test(self):
# prediction_fpath = os.path.join(self.state["work_dir"], self.state["task"], "predictions.csv")
# predictions = pd.read_csv(prediction_fpath)["target"]
# # copy predictions.csv to the node_dir
# predictions_node_fpath = os.path.join(self.state["node_dir"], "Node-{self.id}-predictions.csv")
# predictions.to_csv(predictions_node_fpath, index=False)
# # load test_target.csv
# split_datasets_dir = self.state["datasets_dir"]
# gt = pd.read_csv(os.path.join(split_datasets_dir["test_target"]))["target"]
# metric = self.state["dataset_config"]["metric"]
# return evaluate_score(predictions, gt, metric)
def evaluate_prediction(self, split):
pred_path = os.path.join(self.state["work_dir"], self.state["task"], f"{split}_predictions.csv")
pred_node_path = os.path.join(self.state["node_dir"], f"Node-{self.id}-{split}_predictions.csv")
@ -331,14 +320,10 @@ class MCTS:
self.children[root] = []
reward = await self.simulate(root, role)
self.backpropagate(root, reward)
children = await self.expand(root)
# 目前是随机选择1个后续可以改成多个
first_leaf = random.choice(children)
reward = await self.simulate(first_leaf)
self.backpropagate(first_leaf, reward)
node, reward = await self.expand_and_simulate(root)
# self.backpropagate(node, reward)
else:
root = self.root_node
# 后续迭代使用UCT进行选择expand并模拟和反向传播
for _ in range(rollouts): # number of rollouts
mcts_logger.log("MCTS", f"Start the next rollout {_+1}")
node = self.select(root)
@ -350,13 +335,19 @@ class MCTS:
mcts_logger.log("MCTS", f"Terminal node's reward: {reward}")
self.backpropagate(node, reward)
else:
if node.visited > 0:
children = await self.expand(node)
node = random.choice(children)
reward = await self.simulate(node)
self.backpropagate(node, reward)
node, reward = await self.expand_and_simulate(node)
# self.backpropagate(node, reward)
return self.best_path(root)
async def expand_and_simulate(self, node):
# Expand and randomly select a child node, then simulate it
if node.visited > 0:
children = await self.expand(node)
node = random.choice(children)
reward = await self.simulate(node)
self.backpropagate(node, reward)
return node, reward
def load_tree(self):
def load_children_node(node):
mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}")