From c8858cd8d464ef2c477770f927310e1a84cc7b3c Mon Sep 17 00:00:00 2001 From: yzlin Date: Tue, 16 Jan 2024 17:54:38 +0800 Subject: [PATCH] minimize ml_engineer --- metagpt/actions/ml_da_action.py | 2 +- metagpt/actions/write_analysis_code.py | 23 +++-- metagpt/prompts/ml_engineer.py | 8 +- metagpt/roles/ml_engineer.py | 111 +++++++------------------ metagpt/tools/tool_data_type.py | 1 + metagpt/tools/tool_types.py | 6 ++ 6 files changed, 51 insertions(+), 100 deletions(-) diff --git a/metagpt/actions/ml_da_action.py b/metagpt/actions/ml_da_action.py index d4e77773f..584c4db7a 100644 --- a/metagpt/actions/ml_da_action.py +++ b/metagpt/actions/ml_da_action.py @@ -63,4 +63,4 @@ class UpdateDataColumns(Action): prompt = UPDATE_DATA_COLUMNS.format(history_code=code_context) tool_config = create_func_config(PRINT_DATA_COLUMNS) rsp = await self.llm.aask_code(prompt, **tool_config) - return rsp + return rsp["code"] diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index f4ae1e572..efd1ea163 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -155,10 +155,6 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): ) code_steps = plan.current_task.code_steps - finished_tasks = plan.get_finished_tasks() - code_context = [remove_comments(task.code) for task in finished_tasks] - code_context = "\n\n".join(code_context) - tool_catalog = {} if available_tools: @@ -189,26 +185,28 @@ class WriteCodeWithToolsML(WriteCodeWithTools): column_info: str = "", **kwargs, ) -> Tuple[List[Message], str]: - tool_type = plan.current_task.task_type - available_tools = self.available_tools.get(tool_type, {}) - special_prompt = TOOL_TYPE_USAGE_PROMPT.get(tool_type, "") + tool_type = ( + plan.current_task.task_type + ) # find tool type from task type through exact match, can extend to retrieval in the future + available_tools = TOOL_REGISTRY.get_tools_by_type(tool_type) + special_prompt = ( + TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else "" + ) code_steps = plan.current_task.code_steps finished_tasks = plan.get_finished_tasks() code_context = [remove_comments(task.code) for task in finished_tasks] code_context = "\n\n".join(code_context) - if len(available_tools) > 0: - available_tools = {k: v["description"] for k, v in available_tools.items()} + if available_tools: + available_tools = {tool_name: tool.schema["description"] for tool_name, tool in available_tools.items()} recommend_tools = await self._tool_recommendation( plan.current_task.instruction, code_steps, available_tools ) - tool_catalog = self._parse_recommend_tools(tool_type, recommend_tools) + tool_catalog = self._parse_recommend_tools(recommend_tools) logger.info(f"Recommended tools: \n{recommend_tools}") - module_name = TOOL_TYPE_MODULE[tool_type] - prompt = ML_TOOL_USAGE_PROMPT.format( user_requirement=plan.goal, history_code=code_context, @@ -216,7 +214,6 @@ class WriteCodeWithToolsML(WriteCodeWithTools): column_info=column_info, special_prompt=special_prompt, code_steps=code_steps, - module_name=module_name, tool_catalog=tool_catalog, ) diff --git a/metagpt/prompts/ml_engineer.py b/metagpt/prompts/ml_engineer.py index ff29d5ed4..3fd895e6e 100644 --- a/metagpt/prompts/ml_engineer.py +++ b/metagpt/prompts/ml_engineer.py @@ -134,16 +134,12 @@ PRINT_DATA_COLUMNS = { "parameters": { "type": "object", "properties": { - "is_update": { - "type": "boolean", - "description": "Whether need to update the column info.", - }, "code": { "type": "string", "description": "The code to be added to a new cell in jupyter.", }, }, - "required": ["is_update", "code"], + "required": ["code"], }, } @@ -240,7 +236,7 @@ Strictly follow steps below when you writing code if it's convenient. - You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc.. # Available Tools: -Each Class tool is described in JSON format. When you call a tool, import the tool from `{module_name}` first. +Each Class tool is described in JSON format. When you call a tool, import the tool from its path first. {tool_catalog} # Output Example: diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index a60642bff..aeea39c0c 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -1,64 +1,43 @@ -from metagpt.actions.ask_review import ReviewConst from metagpt.actions.debug_code import DebugCode from metagpt.actions.execute_code import ExecutePyCode -from metagpt.actions.ml_da_action import Reflect, SummarizeAnalysis, UpdateDataColumns +from metagpt.actions.ml_da_action import UpdateDataColumns from metagpt.actions.write_analysis_code import WriteCodeWithToolsML -from metagpt.actions.write_code_steps import WriteCodeSteps from metagpt.logs import logger from metagpt.roles.code_interpreter import CodeInterpreter -from metagpt.roles.kaggle_manager import DownloadData, SubmitResult -from metagpt.schema import Message +from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.utils.common import any_to_str class MLEngineer(CodeInterpreter): - use_code_steps: bool = False - use_udfs: bool = False - data_desc: dict = {} debug_context: list = [] latest_code: str = "" def __init__(self, name="Mark", profile="MLEngineer", **kwargs): super().__init__(name=name, profile=profile, **kwargs) - # self._watch([DownloadData, SubmitResult]) # in multi-agent settings - - async def _plan_and_act(self): - ### a new attempt on the data, relevant in a multi-agent multi-turn setting ### - await self._prepare_data_context() - - ### general plan process ### - await super()._plan_and_act() - - ### summarize analysis ### - summary = await SummarizeAnalysis().run(self.planner.plan) - rsp = Message(content=summary, cause_by=SummarizeAnalysis) - self.rc.memory.add(rsp) - - return rsp - - async def _write_and_exec_code(self, max_retry: int = 3): - self.planner.current_task.code_steps = ( - await WriteCodeSteps().run(self.planner.plan) if self.use_code_steps else "" - ) - - code, result, success = await super()._write_and_exec_code(max_retry=max_retry) - - if success: - if self.use_tools and self.planner.current_task.task_type in ["data_preprocess", "feature_engineering"]: - update_success, new_code = await self._update_data_columns() - if update_success: - code = code + "\n\n" + new_code - - return code, result, success async def _write_code(self): if not self.use_tools: return await super()._write_code() - code_execution_count = sum([msg.cause_by == any_to_str(ExecutePyCode) for msg in self.working_memory.get()]) + # In a trial and errors settings, check whether this is our first attempt to tackle the task. If there is no code execution before, then it is. + is_first_trial = any_to_str(ExecutePyCode) not in [msg.cause_by for msg in self.working_memory.get()] - if code_execution_count > 0: - logger.warning("We got a bug code, now start to debug...") + if is_first_trial: + # For the first trial, write task code from scratch + column_info = await self._update_data_columns() + + logger.info("Write code with tools") + tool_context, code = await WriteCodeWithToolsML().run( + context=[], # context assembled inside the Action + plan=self.planner.plan, + column_info=column_info, + ) + self.debug_context = tool_context + cause_by = WriteCodeWithToolsML + + else: + # Previous trials resulted in error, debug and rewrite the code + logger.warning("We got a bug, now start to debug...") code = await DebugCode().run( code=self.latest_code, runtime_result=self.working_memory.get(), @@ -67,49 +46,21 @@ class MLEngineer(CodeInterpreter): logger.info(f"new code \n{code}") cause_by = DebugCode - else: - logger.info("Write code with tools") - tool_context, code = await WriteCodeWithToolsML().run( - context=[], # context assembled inside the Action - plan=self.planner.plan, - column_info=self.data_desc.get("column_info", ""), - ) - self.debug_context = tool_context - cause_by = WriteCodeWithToolsML - self.latest_code = code return code, cause_by async def _update_data_columns(self): + current_task = self.planner.plan.current_task + if current_task.task_type not in [ + ToolTypeEnum.DATA_PREPROCESS.value, + ToolTypeEnum.FEATURE_ENGINEERING.value, + ToolTypeEnum.MODEL_TRAIN.value, + ]: + return "" logger.info("Check columns in updated data") - rsp = await UpdateDataColumns().run(self.planner.plan) - is_update, code = rsp["is_update"], rsp["code"] + code = await UpdateDataColumns().run(self.planner.plan) success = False - if is_update: - result, success = await self.execute_code.run(code) - if success: - print(result) - self.data_desc["column_info"] = result - return success, code - - async def _prepare_data_context(self): - memories = self.get_memories() - if memories: - latest_event = memories[-1].cause_by - if latest_event == DownloadData: - self.planner.plan.context = memories[-1].content - elif latest_event == SubmitResult: - # self reflect on previous plan outcomes and think about how to improve the plan, add to working memory - await self._reflect() - - # get feedback for improvement from human, add to working memory - await self.planner.ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER) - - async def _reflect(self): - context = self.get_memories() - context = "\n".join([str(msg) for msg in context]) - - reflection = await Reflect().run(context=context) - self.working_memory.add(Message(content=reflection, role="assistant")) - self.working_memory.add(Message(content=Reflect.REWRITE_PLAN_INSTRUCTION, role="user")) + result, success = await self.execute_code.run(code) + print(result) + return result if success else "" diff --git a/metagpt/tools/tool_data_type.py b/metagpt/tools/tool_data_type.py index c767fef9b..a3ab20a4e 100644 --- a/metagpt/tools/tool_data_type.py +++ b/metagpt/tools/tool_data_type.py @@ -4,6 +4,7 @@ from pydantic import BaseModel class ToolTypeEnum(Enum): + EDA = "eda" DATA_PREPROCESS = "data_preprocess" FEATURE_ENGINEERING = "feature_engineering" MODEL_TRAIN = "model_train" diff --git a/metagpt/tools/tool_types.py b/metagpt/tools/tool_types.py index 289271985..2e22adc40 100644 --- a/metagpt/tools/tool_types.py +++ b/metagpt/tools/tool_types.py @@ -8,6 +8,12 @@ from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum from metagpt.tools.tool_registry import register_tool_type +@register_tool_type +class EDA(ToolType): + name: str = ToolTypeEnum.EDA.value + desc: str = "Useful for performing exploratory data analysis" + + @register_tool_type class DataPreprocess(ToolType): name: str = ToolTypeEnum.DATA_PREPROCESS.value