2023-12-01 00:44:47 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import fire
from metagpt . roles . kaggle_manager import KaggleManager
from metagpt . roles . ml_engineer import MLEngineer
from metagpt . team import Team
2024-01-10 14:15:30 +08:00
2023-12-01 00:44:47 +08:00
# Defaults reproduce the values that used to be hard-coded inside main(), so
# calling main() with no extra arguments behaves exactly as before.
_DEFAULT_COMPETITION = " titanic "
_DEFAULT_DATA_DESC = (
    " Training set is train.csv. \n Test set is test.csv. "
    "We also include gender_submission.csv, a set of predictions that assume "
    "all and only female passengers survive, as an example of what a "
    "submission file should look like. "
)
# Alternative requirement prompts kept for reference; pass one via the
# `requirement` argument instead of editing this file:
#   "Run EDA on the train dataset, train a model to predict survival (20% as validation) and save it, predict the test set using saved model, save the test result according to format"
#   "generate a random prediction, replace the Survived column of gender_submission.csv, and save the prediction to a new submission file"
_DEFAULT_REQUIREMENT = (
    " Score as high as possible for the provided dataset, save the test "
    "prediction to a csv with two columns PassengerId and Survived "
)


async def main(
    investment: float = 5.0,
    n_round: int = 10,
    auto_run: bool = False,
    competition: str = _DEFAULT_COMPETITION,
    data_desc: str = _DEFAULT_DATA_DESC,
    requirement: str = _DEFAULT_REQUIREMENT,
):
    """Hire a KaggleManager and an MLEngineer into a Team and run it.

    Args:
        investment: Amount passed to ``team.invest``.
        n_round: Number of rounds passed to ``team.run``.
        auto_run: Forwarded to ``MLEngineer`` as ``auto_run``.
        competition: Competition identifier forwarded to ``KaggleManager``.
        data_desc: Dataset description forwarded to ``KaggleManager``.
        requirement: Task statement; used both as the ``MLEngineer`` goal
            and as the argument to ``team.start_project``.

    The three trailing parameters were previously fixed inside the body; they
    are appended after the existing ones so positional callers are unaffected.
    """
    team = Team()
    team.hire(
        [
            KaggleManager(competition=competition, data_desc=data_desc),
            MLEngineer(goal=requirement, auto_run=auto_run),
        ]
    )

    team.invest(investment)
    team.start_project(requirement)
    await team.run(n_round=n_round)
2024-01-10 14:15:30 +08:00
# Script entry point. The module name must be compared against the exact
# string "__main__"; a space-padded literal never matches, so the CLI would
# silently never start.
if __name__ == "__main__":
    # NOTE(review): main is a coroutine function; this relies on Fire running
    # it to completion — confirm for the Fire version this project pins.
    fire.Fire(main)