增加try catch

This commit is contained in:
Yizhou Chi 2024-09-04 16:38:33 +08:00
parent aea524b4ea
commit fcd1ba66a6
4 changed files with 52 additions and 18 deletions

View file

@ -18,8 +18,8 @@ class AugExperimenter(Experimenter):
result_path : str = "results/aug"
async def run_experiment(self):
state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
user_requirement = state["requirement"]
# state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
user_requirement = self.state["requirement"]
exp_pool_path = get_exp_pool_path(self.args.task, self.data_config, pool_name="ds_analysis_pool")
exp_pool = InstructionGenerator.load_analysis_pool(exp_pool_path)
if self.args.aug_mode == "single":
@ -38,9 +38,7 @@ class AugExperimenter(Experimenter):
di.role_dir = f"{di.role_dir}_{self.args.task}"
requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
print(requirement)
await di.run(requirement)
score_dict = await di.get_score()
score_dict = self.evaluate(score_dict, state)
score_dict = await self.run_di(di, requirement)
results.append({
"idx": i,
"score_dict": score_dict,

View file

@ -16,17 +16,40 @@ class Experimenter:
def __init__(self, args, **kwargs):
self.args = args
self.start_time = datetime.datetime.now().strftime("%Y%m%d%H%M")
self.state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
async def run_di(self, di, user_requirement):
max_retries = 3
num_runs = 1
run_finished = False
while num_runs <= max_retries and not run_finished:
try:
await di.run(user_requirement)
score_dict = await di.get_score()
score_dict = self.evaluate(score_dict, self.state)
run_finished = True
except Exception as e:
print(f"Error: {e}")
num_runs += 1
if not run_finished:
score_dict = {
"train_score": -1,
"dev_score": -1,
"test_score": -1,
"score": -1
}
return score_dict
async def run_experiment(self):
state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
state = self.state
user_requirement = state["requirement"]
results = []
for i in range(self.args.num_experiments):
di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
await di.run(user_requirement)
score_dict = await di.get_score()
score_dict = self.evaluate(score_dict, state)
score_dict = await self.run_di(di, user_requirement)
results.append({
"idx": i,
"score_dict": score_dict,