mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
Refactor MCTS class to handle role running errors and improve error logging
This commit is contained in:
parent
58d7b14007
commit
c16286a006
2 changed files with 15 additions and 6 deletions
13
expo/MCTS.py
13
expo/MCTS.py
|
|
@@ -3,6 +3,7 @@ import os
|
|||
import pickle
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from expo.dataset import generate_task_requirement, get_split_dataset_path
|
||||
|
|
@@ -209,16 +210,20 @@ class Node:
|
|||
await role.run(with_message="continue")
|
||||
else:
|
||||
await role.run(with_message=self.state["requirement"])
|
||||
score_dict = await role.get_score()
|
||||
score_dict = self.evaluate_simulation(score_dict)
|
||||
self.raw_reward = score_dict
|
||||
run_finished = True
|
||||
except Exception as e:
|
||||
mcts_logger.log("MCTS", f"Error in running the role: {e}")
|
||||
num_runs += 1
|
||||
if not run_finished:
|
||||
mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
|
||||
return {"test_score": 0, "dev_score": 0, "score": 0}
|
||||
score_dict = await role.get_score()
|
||||
score_dict = self.evaluate_simulation(score_dict)
|
||||
self.raw_reward = score_dict
|
||||
if self.state["low_is_better"]:
|
||||
score_dict = {"test_score": np.inf, "dev_score": np.inf, "score": np.inf}
|
||||
else:
|
||||
score_dict = {"test_score": 0, "dev_score": 0, "score": 0}
|
||||
self.raw_reward = score_dict
|
||||
if self.state["low_is_better"]:
|
||||
# normalized the score to be between 0 and 1, and higher is better
|
||||
def normalize_score(score):
|
||||
|
|
|
|||
|
|
@@ -54,11 +54,15 @@ class Experimenter:
|
|||
{"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}
|
||||
)
|
||||
self.save_result(results) # save intermediate results
|
||||
dev_scores = [result["score_dict"]["dev_score"] for result in results]
|
||||
dev_scores = [
|
||||
result["score_dict"]["dev_score"] for result in results if result["score_dict"]["dev_score"] != -1
|
||||
]
|
||||
best_dev_score = max(dev_scores) if not self.args.low_is_better else min(dev_scores)
|
||||
best_score_idx = dev_scores.index(best_dev_score)
|
||||
|
||||
test_scores = [result["score_dict"]["test_score"] for result in results]
|
||||
test_scores = [
|
||||
result["score_dict"]["test_score"] for result in results if result["score_dict"]["dev_score"] != -1
|
||||
]
|
||||
avg_score = sum(test_scores) / len(test_scores)
|
||||
global_best_score = max(test_scores) if not self.args.low_is_better else min(test_scores)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue