diff --git a/expo/MCTS.py b/expo/MCTS.py index 265356f65..ef408b2dd 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -148,7 +148,7 @@ class Node: role = role.model_copy() role.save_state(static_save=True) - async def expand(self, max_children): + async def expand(self, max_children, use_fixed_insights): if self.is_fully_expanded(): return insight_geneartor = InstructionGenerator() @@ -159,7 +159,7 @@ class Node: original_instruction=original_instruction, max_num=max_children, file_path=self.state["exp_pool_path"], - use_fixed_insights=self.use_fixed_insights, + use_fixed_insights=use_fixed_insights, ) new_state = self.state.copy() new_state["start_task_id"] += 1 @@ -259,7 +259,7 @@ class MCTS: return max(all_children, key=uct) async def expand(self, node: Node, max_children=5): - await node.expand(max_children) + await node.expand(max_children, self.use_fixed_insights) if node not in self.children or not self.children[node]: self.children[node] = node.children return node.children