Mirror of https://github.com/FoundationAgents/MetaGPT.git, synced 2026-05-05 22:02:38 +02:00
Add try/catch
This commit is contained in:
parent aea524b4ea
commit fcd1ba66a6
4 changed files with 52 additions and 18 deletions
expo/MCTS.py (26 changes)

@@ -194,13 +194,25 @@ class Node():
         if self.is_terminal() and role is not None:
             if role.state_saved:
                 return self.raw_reward
 
-        if not role:
-            role = self.load_role()
-            await load_execute_notebook(role)  # execute previous notebook's code
-            await role.run(with_message='continue')
-        else:
-            await role.run(with_message=self.state['requirement'])
+        max_retries = 3
+        num_runs = 1
+        run_finished = False
+        while num_runs <= max_retries and not run_finished:
+            try:
+                if not role:
+                    role = self.load_role()
+                    await load_execute_notebook(role)  # execute previous notebook's code
+                    await role.run(with_message='continue')
+                else:
+                    await role.run(with_message=self.state['requirement'])
+                run_finished = True
+            except Exception as e:
+                mcts_logger.log("MCTS", f"Error in running the role: {e}")
+                num_runs += 1
+        if not run_finished:
+            mcts_logger.log("MCTS", f"Role {role.node_id} failed to run")
+            return {"test_score": 0, "dev_score": 0, "score": 0}
         score_dict = await role.get_score()
         score_dict = self.evaluate_simulation(score_dict)
         self.raw_reward = score_dict
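The heart of this commit is the bounded-retry wrapper around `role.run(...)`. A minimal standalone sketch of the same pattern, assuming only an arbitrary async callable — the helper name `run_with_retries` and its signature are illustrative, not part of the commit:

```python
import asyncio

async def run_with_retries(action, max_retries=3):
    """Retry an async callable up to max_retries times, mirroring the
    commit's while/try/except loop. Returns True once `action` succeeds,
    False when every attempt raised."""
    num_runs = 1
    while num_runs <= max_retries:
        try:
            await action()
            return True  # corresponds to run_finished = True in the diff
        except Exception as e:
            print(f"Error in run {num_runs}: {e}")  # the commit logs via mcts_logger
            num_runs += 1
    return False  # caller falls back to sentinel scores, as the diff does

async def main():
    attempts = 0

    async def flaky():
        nonlocal attempts
        attempts += 1
        if attempts < 3:
            raise RuntimeError("transient failure")

    print(await run_with_retries(flaky))  # True on the third attempt

asyncio.run(main())
```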
@@ -35,7 +35,8 @@ ### Budget
 
 ### Prompt usage
 
-Obtain the prompt by running the `generate_task_requirement` function in `dataset.py`
+- Obtain the prompt by running the `generate_task_requirement` function in `dataset.py`
+- Each dataset contains a `dataset_info.json`; its contents need to be provided to the baselines to ensure fairness
 
 
 ## 3. Evaluation

@@ -74,7 +75,7 @@ #### Setup
 ### Base DI
 For setup, check 5.
 
-- `python run_experiment.py --exp_mode base --task titanic`
+- `python run_experiment.py --exp_mode base --task titanic --num_experiments 10`
 
 
 ### DI RandomSearch
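For orientation, the prompt-generation step mentioned above might be invoked roughly as follows. Only the function name and the file `dataset.py` come from the README; the import path, argument, and return type are assumptions:

```python
# Hypothetical sketch: generate_task_requirement's real signature lives in
# the repo's dataset.py; the task-name argument and returned string are assumed.
from dataset import generate_task_requirement

user_requirement = generate_task_requirement("titanic")  # assumed call shape
print(user_requirement)  # the prompt text handed to the agent / baselines
```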
@@ -18,8 +18,8 @@ class AugExperimenter(Experimenter):
     result_path : str = "results/aug"
 
     async def run_experiment(self):
-        state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
-        user_requirement = state["requirement"]
+        # state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
+        user_requirement = self.state["requirement"]
         exp_pool_path = get_exp_pool_path(self.args.task, self.data_config, pool_name="ds_analysis_pool")
         exp_pool = InstructionGenerator.load_analysis_pool(exp_pool_path)
         if self.args.aug_mode == "single":

@@ -38,9 +38,7 @@ class AugExperimenter(Experimenter):
             di.role_dir = f"{di.role_dir}_{self.args.task}"
             requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i])
             print(requirement)
-            await di.run(requirement)
-            score_dict = await di.get_score()
-            score_dict = self.evaluate(score_dict, state)
+            score_dict = await self.run_di(di, requirement)
             results.append({
                 "idx": i,
                 "score_dict": score_dict,
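For context, the augmented flow above appends an experience snippet to the base requirement before each run. A toy illustration of that concatenation — the EXPS_PROMPT template string below is a stand-in; only the `.format(experience=...)` call pattern appears in the diff:

```python
# Stand-in template: the repo defines the real EXPS_PROMPT; only the
# `experience` placeholder is implied by the diff's .format() call.
EXPS_PROMPT = "\n\n## Experience from the analysis pool\n{experience}\n"

user_requirement = "Train a classifier on the titanic dataset and report test accuracy."
exps = ["Median-impute Age before scaling; it lifted dev accuracy in prior runs."]

for i, exp in enumerate(exps):
    requirement = user_requirement + EXPS_PROMPT.format(experience=exp)
    print(f"--- requirement for run {i} ---\n{requirement}")
```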
@@ -16,17 +16,40 @@ class Experimenter:
     def __init__(self, args, **kwargs):
         self.args = args
         self.start_time = datetime.datetime.now().strftime("%Y%m%d%H%M")
+        self.state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
+
+    async def run_di(self, di, user_requirement):
+        max_retries = 3
+        num_runs = 1
+        run_finished = False
+        while num_runs <= max_retries and not run_finished:
+            try:
+                await di.run(user_requirement)
+                score_dict = await di.get_score()
+                score_dict = self.evaluate(score_dict, self.state)
+                run_finished = True
+            except Exception as e:
+                print(f"Error: {e}")
+                num_runs += 1
+        if not run_finished:
+            score_dict = {
+                "train_score": -1,
+                "dev_score": -1,
+                "test_score": -1,
+                "score": -1
+            }
+        return score_dict
+
+
 
     async def run_experiment(self):
-        state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="")
+        state = self.state
         user_requirement = state["requirement"]
         results = []
 
         for i in range(self.args.num_experiments):
             di = ResearchAssistant(node_id="0", use_reflection=self.args.reflection)
-            await di.run(user_requirement)
-            score_dict = await di.get_score()
-            score_dict = self.evaluate(score_dict, state)
+            score_dict = await self.run_di(di, user_requirement)
             results.append({
                 "idx": i,
                 "score_dict": score_dict,
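With this change, a crashed run no longer aborts the experiment; it yields sentinel scores instead (-1 here, 0 in MCTS.py). A downstream consumer could separate failed runs from finished ones like this — the filtering is a sketch, not part of the commit; only the `results` shape matches the diff:

```python
# Example results list in the shape run_experiment accumulates.
results = [
    {"idx": 0, "score_dict": {"train_score": 0.91, "dev_score": 0.84, "test_score": 0.82, "score": 0.82}},
    {"idx": 1, "score_dict": {"train_score": -1, "dev_score": -1, "test_score": -1, "score": -1}},  # retries exhausted
]

# A score of -1 marks a run where all three attempts raised.
finished = [r for r in results if r["score_dict"]["score"] != -1]
avg = sum(r["score_dict"]["score"] for r in finished) / len(finished)
print(f"{len(finished)}/{len(results)} runs finished, mean score {avg:.2f}")
```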