add check_data for data_analyst

This commit is contained in:
lidanyang 2024-07-30 15:52:22 +08:00
parent 7ea23a48e4
commit abe65a062b

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from pydantic import Field, model_validator
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode, CheckData
from metagpt.logs import logger
from metagpt.prompts.di.data_analyst import (
CODE_STATUS,
@ -11,9 +11,11 @@ from metagpt.prompts.di.data_analyst import (
TASK_TYPE_DESC,
)
from metagpt.prompts.di.role_zero import ROLE_INSTRUCTION
from metagpt.prompts.di.write_analysis_code import DATA_INFO
from metagpt.roles.di.role_zero import RoleZero
from metagpt.schema import Message, TaskResult
from metagpt.strategy.experience_retriever import ExpRetriever, KeywordExpRetriever
from metagpt.strategy.task_type import TaskType
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
@ -68,6 +70,9 @@ class DataAnalyst(RoleZero):
else:
tool_info = ""
# data info
await self._check_data()
while not success and counter < 3:
### write code ###
logger.info("ready to WriteAnalysisCode")
@ -101,3 +106,24 @@ class DataAnalyst(RoleZero):
output += "The code written has been executed successfully."
self.rc.working_memory.clear()
return output
async def _check_data(self):
    """Inspect the current data state and record it in working memory.

    The check is skipped unless the plan already has at least one finished
    task and the current task's type is data preprocessing, feature
    engineering, or model training. When it runs, ``CheckData`` is asked for
    a snippet based on the plan, the snippet is executed through
    ``self.execute_code``, and on success the output is formatted with
    ``DATA_INFO`` and appended to ``self.rc.working_memory`` as a user
    message attributed to ``CheckData``.

    Returns:
        None. The only effects are the echoed output, the log records, and
        the message added to working memory.
    """
    plan = self.planner.plan

    # Nothing has been executed yet, so there is no updated data to inspect.
    if not plan.get_finished_tasks():
        return

    # Guard against a plan without an active task; dereferencing
    # ``.task_type`` on a missing task would raise AttributeError.
    current_task = plan.current_task
    if current_task is None:
        return

    # Only these task types get a data check before code is written.
    data_task_types = (
        TaskType.DATA_PREPROCESS.type_name,
        TaskType.FEATURE_ENGINEERING.type_name,
        TaskType.MODEL_TRAIN.type_name,
    )
    if current_task.task_type not in data_task_types:
        return

    logger.info("Check updated data")
    code = await CheckData().run(plan)
    # An empty / whitespace-only snippet means there is nothing to run.
    if not code.strip():
        return

    result, success = await self.execute_code.run(code)
    if not success:
        # Best-effort step: report the failure instead of dropping it
        # silently, but do not interrupt the caller.
        logger.warning("Check data code failed to execute; data info was not updated")
        return

    print(result)
    data_info = DATA_INFO.format(info=result)
    self.rc.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData))