From abe65a062b54ba4530a6405beedd1b06bce4d876 Mon Sep 17 00:00:00 2001 From: lidanyang Date: Tue, 30 Jul 2024 15:52:22 +0800 Subject: [PATCH] add check_data for data_analyst --- metagpt/roles/di/data_analyst.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index 25e504e2f..326b41020 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -3,7 +3,7 @@ from __future__ import annotations from pydantic import Field, model_validator from metagpt.actions.di.execute_nb_code import ExecuteNbCode -from metagpt.actions.di.write_analysis_code import WriteAnalysisCode +from metagpt.actions.di.write_analysis_code import WriteAnalysisCode, CheckData from metagpt.logs import logger from metagpt.prompts.di.data_analyst import ( CODE_STATUS, @@ -11,9 +11,11 @@ from metagpt.prompts.di.data_analyst import ( TASK_TYPE_DESC, ) from metagpt.prompts.di.role_zero import ROLE_INSTRUCTION +from metagpt.prompts.di.write_analysis_code import DATA_INFO from metagpt.roles.di.role_zero import RoleZero from metagpt.schema import Message, TaskResult from metagpt.strategy.experience_retriever import ExpRetriever, KeywordExpRetriever +from metagpt.strategy.task_type import TaskType from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool @@ -68,6 +70,9 @@ class DataAnalyst(RoleZero): else: tool_info = "" + # data info + await self._check_data() + while not success and counter < 3: ### write code ### logger.info("ready to WriteAnalysisCode") @@ -101,3 +106,24 @@ class DataAnalyst(RoleZero): output += "The code written has been executed successfully." 
self.rc.working_memory.clear() return output
+
+    async def _check_data(self):
+        """Run CheckData-generated code and record the resulting data info in working memory."""
+        plan = self.planner.plan
+        if not plan.get_finished_tasks() or plan.current_task is None:
+            return  # no finished tasks, or no active task (assumed possible; avoids None.task_type)
+        if plan.current_task.task_type not in (
+            TaskType.DATA_PREPROCESS.type_name,
+            TaskType.FEATURE_ENGINEERING.type_name,
+            TaskType.MODEL_TRAIN.type_name,
+        ):
+            return  # only these task types trigger a data re-check
+        logger.info("Check updated data")
+        code = await CheckData().run(plan)
+        if not code.strip():
+            return  # CheckData produced no code to execute
+        result, success = await self.execute_code.run(code)
+        if success:  # a failed check is ignored; nothing is added to working memory
+            print(result)
+            data_info = DATA_INFO.format(info=result)
+            self.rc.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData))