1. role.py - 如果已经有plan,便不再重复生成

2. 修改prompt,让predictions.csv生成的格式与原gt格式一样
This commit is contained in:
Yizhou Chi 2024-09-02 15:15:53 +08:00
parent 0e5db1c364
commit f28908e4c5
5 changed files with 14 additions and 11 deletions

View file

@ -42,6 +42,7 @@ class Node():
value : float = 0
visited : int = 0
children : list = []
normalized_reward : dict = {"train_score": 0, "dev_score": 0, "test_score": 0}
parent = None
def __init__(self, parent=None, state = None, action=None, value = 0, max_depth=4, **kwargs):
@ -274,7 +275,7 @@ class MCTS():
if score > best_score:
best_score = score
best_child = child
best_score, best_child = bfs(child, best_score, best_child)
best_score, best_child = bfs(child, best_score, best_child, split)
return best_score, best_child
_, best_child = bfs(root, best_score, best_child, "test_score")
_, dev_best_child = bfs(root, best_score, best_child, "dev_score")