增加try catch

This commit is contained in:
Yizhou Chi 2024-09-04 16:38:33 +08:00
parent aea524b4ea
commit fcd1ba66a6
4 changed files with 52 additions and 18 deletions

View file

@ -194,13 +194,25 @@ class Node():
if self.is_terminal() and role is not None:
if role.state_saved:
return self.raw_reward
if not role:
role = self.load_role()
await load_execute_notebook(role) # execute previous notebook's code
await role.run(with_message='continue')
else:
await role.run(with_message=self.state['requirement'])
max_retries = 3
num_runs = 1
run_finished = False
while num_runs <= max_retries and not run_finished:
try:
if not role:
role = self.load_role()
await load_execute_notebook(role) # execute previous notebook's code
await role.run(with_message='continue')
else:
await role.run(with_message=self.state['requirement'])
run_finished = True
except Exception as e:
mcts_logger.log("MCTS", f"Error in running the role: {e}")
num_runs += 1
if not run_finished:
mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
return {"test_score": 0, "dev_score": 0, "score": 0}
score_dict = await role.get_score()
score_dict = self.evaluate_simulation(score_dict)
self.raw_reward = score_dict

View file

@ -35,7 +35,8 @@ ### Budget
### 提示词使用
通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词
- 通过执行`dataset.py`中的`generate_task_requirement`函数获取提示词
- 每个数据集里都有`dataset_info.json`，其中的内容需要提供给baselines以保证公平
## 3. Evaluation
@ -74,7 +75,7 @@ #### Setup
### Base DI
For setup, see Section 5.
- `python run_experiment.py --exp_mode base --task titanic`
- `python run_experiment.py --exp_mode base --task titanic --num_experiments 10`
### DI RandomSearch

View file

@ -18,8 +18,8 @@ class AugExperimenter(Experimenter):
result_path : str = "results/aug"
async def run_experiment(self):
state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
user_requirement = state["requirement"]
# state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
user_requirement = self.state["requirement"]
exp_pool_path = get_exp_pool_path(self.args.task, self.data_config, pool_name="ds_analysis_pool")
exp_pool = InstructionGenerator.load_analysis_pool(exp_pool_path)
if self.args.aug_mode == "single":
@ -38,9 +38,7 @@ class AugExperimenter(Experimenter):
di.role_dir = f"{di.role_dir}_{self.args.task}"
requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
print(requirement)
await di.run(requirement)
score_dict = await di.get_score()
score_dict = self.evaluate(score_dict, state)
score_dict = await self.run_di(di, requirement)
results.append({
"idx": i,
"score_dict": score_dict,

View file

@ -16,17 +16,40 @@ class Experimenter:
def __init__(self, args, **kwargs):
    """Store the run arguments and build the shared experiment state.

    Args:
        args: Parsed run options; ``args.task`` and ``args.low_is_better``
            are read here.
        **kwargs: Accepted but not used by this constructor.
    """
    self.args = args
    # Minute-resolution timestamp (YYYYmmddHHMM) taken at construction time.
    self.start_time = datetime.datetime.now().strftime("%Y%m%d%H%M")
    # NOTE(review): ``self.data_config`` is not assigned in this method, so it
    # is presumably a class-level attribute -- confirm.
    self.state = create_initial_state(
        self.args.task,
        start_task_id=1,
        data_config=self.data_config,
        low_is_better=self.args.low_is_better,
        name="",
    )
async def run_di(self, di, user_requirement, max_retries=3):
    """Run ``di`` on a requirement and return its evaluated scores, retrying on failure.

    Each attempt awaits ``di.run(user_requirement)``, then ``di.get_score()``,
    and passes the result through ``self.evaluate`` together with
    ``self.state``. An exception raised by any of the three steps is printed
    and the whole attempt is repeated.

    Args:
        di: Object exposing awaitable ``run(requirement)`` and ``get_score()``.
        user_requirement: Requirement text forwarded to ``di.run``.
        max_retries: Maximum number of attempts (default 3, the previously
            hard-coded value). A value <= 0 makes no attempt at all.

    Returns:
        The dict returned by ``self.evaluate`` on the first successful
        attempt, or a sentinel dict with every score set to ``-1`` when all
        attempts fail.
    """
    for _ in range(max_retries):
        try:
            await di.run(user_requirement)
            score_dict = await di.get_score()
            # Evaluation stays inside the try so that a failure here also
            # triggers a retry, as in the original loop.
            return self.evaluate(score_dict, self.state)
        except Exception as e:
            print(f"Error: {e}")
    # Every attempt failed: report -1 for all scores instead of raising.
    return {
        "train_score": -1,
        "dev_score": -1,
        "test_score": -1,
        "score": -1,
    }
async def run_experiment(self):
state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
state = self.state
user_requirement = state["requirement"]
results = []
for i in range(self.args.num_experiments):
di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
await di.run(user_requirement)
score_dict = await di.get_score()
score_dict = self.evaluate(score_dict, state)
score_dict = await self.run_di(di, user_requirement)
results.append({
"idx": i,
"score_dict": score_dict,