diff --git a/.gitignore b/.gitignore
index e3f7c5e86..841306429 100644
--- a/.gitignore
+++ b/.gitignore
@@ -192,3 +192,6 @@ cov.xml
 *.csv
 metagpt/ext/sela/results/*
 .chainlit/
+
+metagpt/ext/aflow/data
+metagpt/ext/aflow/scripts/optimized
diff --git a/examples/aflow/optimize.py b/examples/aflow/optimize.py
index d07eab993..8f13cfad2 100644
--- a/examples/aflow/optimize.py
+++ b/examples/aflow/optimize.py
@@ -78,23 +78,48 @@ def parse_args():
         default=True,
         help="Whether to download dataset for the first time",
     )
+    parser.add_argument(
+        "--opt_model_name",
+        type=str,
+        default="claude-3-5-sonnet-20240620",
+        help="Specifies the name of the model used for optimization tasks.",
+    )
+    parser.add_argument(
+        "--exec_model_name",
+        type=str,
+        default="gpt-4o-mini",
+        help="Specifies the name of the model used for execution tasks.",
+    )
     return parser.parse_args()


 if __name__ == "__main__":
     args = parse_args()
-    download(["datasets", "initial_rounds"], if_first_download=args.if_first_optimize)
     config = EXPERIMENT_CONFIGS[args.dataset]

-    mini_llm_config = ModelsConfig.default().get("gpt-4o-mini")
-    claude_llm_config = ModelsConfig.default().get("claude-3-5-sonnet-20240620")
+    models_config = ModelsConfig.default()

+    opt_llm_config = models_config.get(args.opt_model_name)
+    if opt_llm_config is None:
+        raise ValueError(
+            f"The optimization model '{args.opt_model_name}' was not found in the 'models' section of the configuration file. "
+            "Please add it to the configuration file or specify a valid model using the --opt_model_name flag. "
+        )
+
+    exec_llm_config = models_config.get(args.exec_model_name)
+    if exec_llm_config is None:
+        raise ValueError(
+            f"The execution model '{args.exec_model_name}' was not found in the 'models' section of the configuration file. "
+            "Please add it to the configuration file or specify a valid model using the --exec_model_name flag. "
+        )
+
+    download(["datasets", "initial_rounds"], if_first_download=args.if_first_optimize)

     optimizer = Optimizer(
         dataset=config.dataset,
         question_type=config.question_type,
-        opt_llm_config=claude_llm_config,
-        exec_llm_config=mini_llm_config,
+        opt_llm_config=opt_llm_config,
+        exec_llm_config=exec_llm_config,
         check_convergence=args.check_convergence,
         operators=config.operators,
         optimized_path=args.optimized_path,