Merge branch 'dev' into fix_truncate

This commit is contained in:
刘棒棒 2023-12-13 16:07:35 +08:00
commit 233b143da8
5 changed files with 112 additions and 12 deletions

View file

@ -1,13 +1,11 @@
from typing import Dict, List, Union
from typing import List
import json
import subprocess
from datetime import datetime
import fire
import re
from metagpt.roles import Role
from metagpt.actions import Action
from metagpt.schema import Message, Task, Plan
from metagpt.schema import Message, Plan
from metagpt.memory import Memory
from metagpt.logs import logger
from metagpt.actions.write_plan import WritePlan, update_plan_from_rsp, precheck_update_plan_from_rsp
@ -17,6 +15,7 @@ from metagpt.actions.execute_code import ExecutePyCode
from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
from metagpt.prompts.ml_engineer import STRUCTURAL_CONTEXT
from metagpt.actions.write_code_steps import WriteCodeSteps
from metagpt.utils.save_code import save_code_file
class MLEngineer(Role):
def __init__(
@ -99,6 +98,9 @@ class MLEngineer(Role):
rsp = Message(content=summary, cause_by=SummarizeAnalysis)
self._rc.memory.add(rsp)
# save code using datetime.now or keywords related to the goal of your project (plan.goal).
project_record = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_code_file(name=project_record, code_context=self.execute_code.nb, file_format="ipynb")
return rsp
async def _write_and_exec_code(self, max_retry: int = 3):
@ -223,14 +225,13 @@ class MLEngineer(Role):
if __name__ == "__main__":
# requirement = "Run data analysis on sklearn Iris dataset, include a plot"
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
# requirement = "Run data analysis on sklearn Diabetes dataset, include a plot"
# requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy"
# requirement = "Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy"
# requirement = "Run EDA and visualization on this dataset, train a model to predict survival, report metrics on validation set (20%), dataset: workspace/titanic/train.csv"
requirement = "This is a house price dataset, your goal is to predict the sale price of a property based on its features. The target column is SalePrice. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report RMSE between the logarithm of the predicted value and the logarithm of the observed sales price on the eval data. Train data path: 'workspace/house-prices-advanced-regression-techniques/split_train.csv', eval data path: 'workspace/house-prices-advanced-regression-techniques/split_eval.csv'."
async def main(requirement: str = requirement, auto_run: bool = False):
async def main(requirement: str = requirement, auto_run: bool = True):
role = MLEngineer(goal=requirement, auto_run=auto_run)
await role.run(requirement)

View file

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# @Date : 12/12/2023 4:14 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import os
import json
import nbformat
from metagpt.const import DATA_PATH
def save_code_file(name: str, code_context: str, file_format: str = "py") -> None:
"""
Save code files to a specified path.
Args:
- name (str): The name of the folder to save the files.
- code_context (str): The code content.
- file_format (str, optional): The file format. Supports 'py' (Python file), 'json' (JSON file), and 'ipynb' (Jupyter Notebook file). Default is 'py'.
Returns:
- None
"""
# Create the folder path if it doesn't exist
os.makedirs(name=DATA_PATH / "output" / f"{name}", exist_ok=True)
# Choose to save as a Python file or a JSON file based on the file format
file_path = DATA_PATH / "output" / f"{name}/code.{file_format}"
if file_format == "py":
with open(file_path, "w", encoding="utf-8") as fp:
fp.write(code_context + "\n\n")
elif file_format == "json":
# Parse the code content as JSON and save
data = {"code": code_context}
with open(file_path, "w", encoding="utf-8") as fp:
json.dump(data, fp, indent=2)
elif file_format == "ipynb":
nbformat.write(code_context, file_path)
else:
raise ValueError("Unsupported file format. Please choose 'py', 'json', or 'ipynb'.")