update prompt (specify whether each set has target label)

This commit is contained in:
Yizhou Chi 2024-09-03 16:31:03 +08:00
parent df877c973e
commit 6972afb755
5 changed files with 22 additions and 21 deletions

View file

@ -240,7 +240,7 @@ class MCTS():
all_children = [child for children in self.children.values() for child in children]
return max(all_children, key=uct)
async def expand(self, node : Node, max_children=4):
async def expand(self, node : Node, max_children=5):
await node.expand(max_children)
if node not in self.children or not self.children[node]:
self.children[node] = node.children
@ -273,7 +273,7 @@ class MCTS():
return best_score, best_child
for child in self.children[node]:
score = child.normalized_reward[split]
print(child.id, score)
print(child.id, split, score)
if score > best_score:
best_score = score
best_child = child