fix start task id consistency

This commit is contained in:
Yizhou Chi 2024-09-14 20:33:40 +08:00
parent 9ff9d27ab0
commit 24db19fa13
3 changed files with 8 additions and 3 deletions

View file

@ -16,12 +16,11 @@ from metagpt.utils.common import read_json_file
def initialize_di_root_node(state, reflection: bool = True):
start_task_id = 2
# state = create_initial_state(
# task, start_task_id=start_task_id, data_config=data_config, low_is_better=low_is_better, name=name
# )
role = ResearchAssistant(
node_id="0", start_task_id=start_task_id, use_reflection=reflection, role_dir=state["node_dir"]
node_id="0", start_task_id=state["start_task_id"], use_reflection=reflection, role_dir=state["node_dir"]
)
return role, Node(parent=None, state=state, action=None, value=0)
@ -208,6 +207,10 @@ class Node:
self.raw_reward = score_dict
run_finished = True
except Exception as e:
print(f"Error: {e}")
import pdb
pdb.set_trace()
mcts_logger.log("MCTS", f"Error in running the role: {e}")
num_runs += 1
if not run_finished:

View file

@ -13,6 +13,7 @@ from expo.utils import DATA_CONFIG, save_notebook
class Experimenter:
result_path: str = "results/base"
data_config = DATA_CONFIG
start_task_id = 1
def __init__(self, args, **kwargs):
self.args = args
@ -20,7 +21,7 @@ class Experimenter:
self.start_time = self.start_time_raw.strftime("%Y%m%d%H%M")
self.state = create_initial_state(
self.args.task,
start_task_id=1,
start_task_id=self.start_task_id,
data_config=self.data_config,
low_is_better=self.args.low_is_better,
name=self.args.name,

View file

@ -6,6 +6,7 @@ from expo.MCTS import MCTS
class MCTSExperimenter(Experimenter):
result_path: str = "results/mcts"
start_task_id = 2
def __init__(self, args, tree_mode=None, **kwargs):
super().__init__(args, **kwargs)