format code

This commit is contained in:
Yizhou Chi 2024-09-04 17:52:02 +08:00
parent fcd1ba66a6
commit ab8a1d6824
17 changed files with 433 additions and 396 deletions

View file

@ -1,5 +1,6 @@
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, mean_squared_error
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, roc_auc_score
def evaluate_score(pred, gt, metric):
if metric == "accuracy":
@ -20,4 +21,4 @@ def evaluate_score(pred, gt, metric):
elif metric == "log rmse":
return mean_squared_error(np.log1p(gt), np.log1p(pred), squared=False)
else:
raise ValueError(f"Metric {metric} not supported")
raise ValueError(f"Metric {metric} not supported")

View file

@ -1,7 +1,7 @@
from expo.MCTS import Node, MCTS
import textwrap
from expo.MCTS import Node
NODE_TEMPLATE = """\
[Node {id}]
Plans:
@ -11,21 +11,23 @@ Score: {score}, Visits: {num_visits}
"""
def get_role_plans(role):
plans = role.planner.plan.tasks
instruct_plans = [f"{i+1}. {task.instruction}" for i, task in enumerate(plans)]
return instruct_plans
def get_tree_text(node : Node):
def get_tree_text(node: Node):
role_dict = {}
code_set = set()
def load_role(node):
if node.id not in role_dict:
role_dict[node.id] = node.load_role()
return role_dict[node.id]
def visualize_node(node : Node, previous_plans=None):
def visualize_node(node: Node, previous_plans=None):
role = load_role(node)
node_id = node.id
plans = role.planner.plan.tasks
@ -36,7 +38,9 @@ def get_tree_text(node : Node):
simulated = role.state_saved
score = f"avg score: {node.avg_value()}, simulated score: {node.raw_reward}"
num_visits = node.visited
return NODE_TEMPLATE.format(id=node_id, plans=instruct_plans_text, simulated=simulated, score=score, num_visits=num_visits)
return NODE_TEMPLATE.format(
id=node_id, plans=instruct_plans_text, simulated=simulated, score=score, num_visits=num_visits
)
def visualize_tree(node, depth=0, previous_plans=None):
text = ""
@ -46,11 +50,10 @@ def get_tree_text(node : Node):
code_set.update({task.instruction for task in role.planner.plan.tasks})
previous_plans = get_role_plans(role)
for child in node.children:
text += textwrap.indent(visualize_tree(child, depth+1, previous_plans), "\t")
text += textwrap.indent(visualize_tree(child, depth + 1, previous_plans), "\t")
return text
num_simulations = node.visited
text = f"Number of simulations: {num_simulations}\n"
text += visualize_tree(node)
return text, len(code_set)