From a6f71f449873eea1c1b3486e93d84c340279fada Mon Sep 17 00:00:00 2001
From: Yizhou Chi
Date: Fri, 6 Sep 2024 13:09:18 +0800
Subject: [PATCH] 1. avoid circular reference 2. add greedy

---
Review notes (this section is not part of the commit message):
* This copy of the patch had its line breaks and indentation collapsed.
  Both were restored from the hunk line counts and Python syntax; confirm
  the indentation of context/removed lines against expo/MCTS.py before
  applying.
* The From: header carries no e-mail address in this copy.
* Greedy.best_child: max() raises ValueError when self.children has keys
  but every child list is empty (the MCTS hunk below shows
  self.children[root] = []).
* Node.save_new_role: model_copy() runs after the role has been mutated,
  so only save_state() operates on the copy.
* The added "# self.backpropagate(node, reward)" lines are commented-out
  leftovers; backpropagation now happens inside expand_and_simulate().

 expo/Greedy.py |  9 +++++++++
 expo/MCTS.py   | 37 ++++++++++++++-----------------------
 2 files changed, 23 insertions(+), 23 deletions(-)
 create mode 100644 expo/Greedy.py

diff --git a/expo/Greedy.py b/expo/Greedy.py
new file mode 100644
index 000000000..f6f60db01
--- /dev/null
+++ b/expo/Greedy.py
@@ -0,0 +1,9 @@
+from expo.MCTS import MCTS
+
+
+class Greedy(MCTS):
+    def best_child(self):
+        if len(self.children) == 0:
+            return self.root_node
+        all_children = [child for children in self.children.values() for child in children]
+        return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0))
diff --git a/expo/MCTS.py b/expo/MCTS.py
index 365315330..3331f35fa 100644
--- a/expo/MCTS.py
+++ b/expo/MCTS.py
@@ -143,6 +143,7 @@ class Node:
         role.state_saved = False
         role.change_next_instruction(self.action)
         mcts_logger.log("MCTS", f"Saving new role: {role.node_id}")
+        role = role.model_copy()
         role.save_state(static_save=True)
 
     async def expand(self, max_children):
@@ -165,18 +166,6 @@ class Node:
             node.save_new_role(new_role)
             self.add_child(node)
 
-    # def evaluate_test(self):
-    #     prediction_fpath = os.path.join(self.state["work_dir"], self.state["task"], "predictions.csv")
-    #     predictions = pd.read_csv(prediction_fpath)["target"]
-    #     # copy predictions.csv to the node_dir
-    #     predictions_node_fpath = os.path.join(self.state["node_dir"], "Node-{self.id}-predictions.csv")
-    #     predictions.to_csv(predictions_node_fpath, index=False)
-    #     # load test_target.csv
-    #     split_datasets_dir = self.state["datasets_dir"]
-    #     gt = pd.read_csv(os.path.join(split_datasets_dir["test_target"]))["target"]
-    #     metric = self.state["dataset_config"]["metric"]
-    #     return evaluate_score(predictions, gt, metric)
-
     def evaluate_prediction(self, split):
         pred_path = os.path.join(self.state["work_dir"], self.state["task"], f"{split}_predictions.csv")
         pred_node_path = os.path.join(self.state["node_dir"], f"Node-{self.id}-{split}_predictions.csv")
@@ -331,14 +320,10 @@ class MCTS:
             self.children[root] = []
             reward = await self.simulate(root, role)
             self.backpropagate(root, reward)
-            children = await self.expand(root)
-            # 目前是随机选择1个,后续可以改成多个
-            first_leaf = random.choice(children)
-            reward = await self.simulate(first_leaf)
-            self.backpropagate(first_leaf, reward)
+            node, reward = await self.expand_and_simulate(root)
+            # self.backpropagate(node, reward)
         else:
             root = self.root_node
-        # 后续迭代:使用UCT进行选择,expand并模拟和反向传播
         for _ in range(rollouts):  # number of rollouts
             mcts_logger.log("MCTS", f"Start the next rollout {_+1}")
             node = self.select(root)
@@ -350,13 +335,19 @@ class MCTS:
                 mcts_logger.log("MCTS", f"Terminal node's reward: {reward}")
                 self.backpropagate(node, reward)
             else:
-                if node.visited > 0:
-                    children = await self.expand(node)
-                    node = random.choice(children)
-                reward = await self.simulate(node)
-                self.backpropagate(node, reward)
+                node, reward = await self.expand_and_simulate(node)
+                # self.backpropagate(node, reward)
         return self.best_path(root)
 
+    async def expand_and_simulate(self, node):
+        # Expand and randomly select a child node, then simulate it
+        if node.visited > 0:
+            children = await self.expand(node)
+            node = random.choice(children)
+        reward = await self.simulate(node)
+        self.backpropagate(node, reward)
+        return node, reward
+
     def load_tree(self):
         def load_children_node(node):
             mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}")