mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
format code
This commit is contained in:
parent
fcd1ba66a6
commit
ab8a1d6824
17 changed files with 433 additions and 396 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, mean_squared_error
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, roc_auc_score
|
||||
|
||||
|
||||
def evaluate_score(pred, gt, metric):
|
||||
if metric == "accuracy":
|
||||
|
|
@ -20,4 +21,4 @@ def evaluate_score(pred, gt, metric):
|
|||
elif metric == "log rmse":
|
||||
return mean_squared_error(np.log1p(gt), np.log1p(pred), squared=False)
|
||||
else:
|
||||
raise ValueError(f"Metric {metric} not supported")
|
||||
raise ValueError(f"Metric {metric} not supported")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
from expo.MCTS import Node, MCTS
|
||||
import textwrap
|
||||
|
||||
from expo.MCTS import Node
|
||||
|
||||
NODE_TEMPLATE = """\
|
||||
[Node {id}]
|
||||
Plans:
|
||||
|
|
@ -11,21 +11,23 @@ Score: {score}, Visits: {num_visits}
|
|||
|
||||
"""
|
||||
|
||||
|
||||
def get_role_plans(role):
|
||||
plans = role.planner.plan.tasks
|
||||
instruct_plans = [f"{i+1}. {task.instruction}" for i, task in enumerate(plans)]
|
||||
return instruct_plans
|
||||
|
||||
|
||||
def get_tree_text(node : Node):
|
||||
def get_tree_text(node: Node):
|
||||
role_dict = {}
|
||||
code_set = set()
|
||||
|
||||
def load_role(node):
|
||||
if node.id not in role_dict:
|
||||
role_dict[node.id] = node.load_role()
|
||||
return role_dict[node.id]
|
||||
|
||||
def visualize_node(node : Node, previous_plans=None):
|
||||
|
||||
def visualize_node(node: Node, previous_plans=None):
|
||||
role = load_role(node)
|
||||
node_id = node.id
|
||||
plans = role.planner.plan.tasks
|
||||
|
|
@ -36,7 +38,9 @@ def get_tree_text(node : Node):
|
|||
simulated = role.state_saved
|
||||
score = f"avg score: {node.avg_value()}, simulated score: {node.raw_reward}"
|
||||
num_visits = node.visited
|
||||
return NODE_TEMPLATE.format(id=node_id, plans=instruct_plans_text, simulated=simulated, score=score, num_visits=num_visits)
|
||||
return NODE_TEMPLATE.format(
|
||||
id=node_id, plans=instruct_plans_text, simulated=simulated, score=score, num_visits=num_visits
|
||||
)
|
||||
|
||||
def visualize_tree(node, depth=0, previous_plans=None):
|
||||
text = ""
|
||||
|
|
@ -46,11 +50,10 @@ def get_tree_text(node : Node):
|
|||
code_set.update({task.instruction for task in role.planner.plan.tasks})
|
||||
previous_plans = get_role_plans(role)
|
||||
for child in node.children:
|
||||
text += textwrap.indent(visualize_tree(child, depth+1, previous_plans), "\t")
|
||||
text += textwrap.indent(visualize_tree(child, depth + 1, previous_plans), "\t")
|
||||
return text
|
||||
|
||||
num_simulations = node.visited
|
||||
text = f"Number of simulations: {num_simulations}\n"
|
||||
text += visualize_tree(node)
|
||||
return text, len(code_set)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue