mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
minimize ml_engineer
This commit is contained in:
parent
8a14dde219
commit
c8858cd8d4
6 changed files with 51 additions and 100 deletions
|
|
@ -63,4 +63,4 @@ class UpdateDataColumns(Action):
|
|||
prompt = UPDATE_DATA_COLUMNS.format(history_code=code_context)
|
||||
tool_config = create_func_config(PRINT_DATA_COLUMNS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
return rsp
|
||||
return rsp["code"]
|
||||
|
|
|
|||
|
|
@ -155,10 +155,6 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
|
|||
)
|
||||
code_steps = plan.current_task.code_steps
|
||||
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_context = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_context = "\n\n".join(code_context)
|
||||
|
||||
tool_catalog = {}
|
||||
|
||||
if available_tools:
|
||||
|
|
@ -189,26 +185,28 @@ class WriteCodeWithToolsML(WriteCodeWithTools):
|
|||
column_info: str = "",
|
||||
**kwargs,
|
||||
) -> Tuple[List[Message], str]:
|
||||
tool_type = plan.current_task.task_type
|
||||
available_tools = self.available_tools.get(tool_type, {})
|
||||
special_prompt = TOOL_TYPE_USAGE_PROMPT.get(tool_type, "")
|
||||
tool_type = (
|
||||
plan.current_task.task_type
|
||||
) # find tool type from task type through exact match, can extend to retrieval in the future
|
||||
available_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
|
||||
special_prompt = (
|
||||
TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
|
||||
)
|
||||
code_steps = plan.current_task.code_steps
|
||||
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_context = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_context = "\n\n".join(code_context)
|
||||
|
||||
if len(available_tools) > 0:
|
||||
available_tools = {k: v["description"] for k, v in available_tools.items()}
|
||||
if available_tools:
|
||||
available_tools = {tool_name: tool.schema["description"] for tool_name, tool in available_tools.items()}
|
||||
|
||||
recommend_tools = await self._tool_recommendation(
|
||||
plan.current_task.instruction, code_steps, available_tools
|
||||
)
|
||||
tool_catalog = self._parse_recommend_tools(tool_type, recommend_tools)
|
||||
tool_catalog = self._parse_recommend_tools(recommend_tools)
|
||||
logger.info(f"Recommended tools: \n{recommend_tools}")
|
||||
|
||||
module_name = TOOL_TYPE_MODULE[tool_type]
|
||||
|
||||
prompt = ML_TOOL_USAGE_PROMPT.format(
|
||||
user_requirement=plan.goal,
|
||||
history_code=code_context,
|
||||
|
|
@ -216,7 +214,6 @@ class WriteCodeWithToolsML(WriteCodeWithTools):
|
|||
column_info=column_info,
|
||||
special_prompt=special_prompt,
|
||||
code_steps=code_steps,
|
||||
module_name=module_name,
|
||||
tool_catalog=tool_catalog,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -134,16 +134,12 @@ PRINT_DATA_COLUMNS = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"is_update": {
|
||||
"type": "boolean",
|
||||
"description": "Whether need to update the column info.",
|
||||
},
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to be added to a new cell in jupyter.",
|
||||
},
|
||||
},
|
||||
"required": ["is_update", "code"],
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -240,7 +236,7 @@ Strictly follow steps below when you writing code if it's convenient.
|
|||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
# Available Tools:
|
||||
Each Class tool is described in JSON format. When you call a tool, import the tool from `{module_name}` first.
|
||||
Each Class tool is described in JSON format. When you call a tool, import the tool from its path first.
|
||||
{tool_catalog}
|
||||
|
||||
# Output Example:
|
||||
|
|
|
|||
|
|
@ -1,64 +1,43 @@
|
|||
from metagpt.actions.ask_review import ReviewConst
|
||||
from metagpt.actions.debug_code import DebugCode
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
from metagpt.actions.ml_da_action import Reflect, SummarizeAnalysis, UpdateDataColumns
|
||||
from metagpt.actions.ml_da_action import UpdateDataColumns
|
||||
from metagpt.actions.write_analysis_code import WriteCodeWithToolsML
|
||||
from metagpt.actions.write_code_steps import WriteCodeSteps
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.code_interpreter import CodeInterpreter
|
||||
from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
class MLEngineer(CodeInterpreter):
|
||||
use_code_steps: bool = False
|
||||
use_udfs: bool = False
|
||||
data_desc: dict = {}
|
||||
debug_context: list = []
|
||||
latest_code: str = ""
|
||||
|
||||
def __init__(self, name="Mark", profile="MLEngineer", **kwargs):
|
||||
super().__init__(name=name, profile=profile, **kwargs)
|
||||
# self._watch([DownloadData, SubmitResult]) # in multi-agent settings
|
||||
|
||||
async def _plan_and_act(self):
|
||||
### a new attempt on the data, relevant in a multi-agent multi-turn setting ###
|
||||
await self._prepare_data_context()
|
||||
|
||||
### general plan process ###
|
||||
await super()._plan_and_act()
|
||||
|
||||
### summarize analysis ###
|
||||
summary = await SummarizeAnalysis().run(self.planner.plan)
|
||||
rsp = Message(content=summary, cause_by=SummarizeAnalysis)
|
||||
self.rc.memory.add(rsp)
|
||||
|
||||
return rsp
|
||||
|
||||
async def _write_and_exec_code(self, max_retry: int = 3):
|
||||
self.planner.current_task.code_steps = (
|
||||
await WriteCodeSteps().run(self.planner.plan) if self.use_code_steps else ""
|
||||
)
|
||||
|
||||
code, result, success = await super()._write_and_exec_code(max_retry=max_retry)
|
||||
|
||||
if success:
|
||||
if self.use_tools and self.planner.current_task.task_type in ["data_preprocess", "feature_engineering"]:
|
||||
update_success, new_code = await self._update_data_columns()
|
||||
if update_success:
|
||||
code = code + "\n\n" + new_code
|
||||
|
||||
return code, result, success
|
||||
|
||||
async def _write_code(self):
|
||||
if not self.use_tools:
|
||||
return await super()._write_code()
|
||||
|
||||
code_execution_count = sum([msg.cause_by == any_to_str(ExecutePyCode) for msg in self.working_memory.get()])
|
||||
# In a trial and errors settings, check whether this is our first attempt to tackle the task. If there is no code execution before, then it is.
|
||||
is_first_trial = any_to_str(ExecutePyCode) not in [msg.cause_by for msg in self.working_memory.get()]
|
||||
|
||||
if code_execution_count > 0:
|
||||
logger.warning("We got a bug code, now start to debug...")
|
||||
if is_first_trial:
|
||||
# For the first trial, write task code from scratch
|
||||
column_info = await self._update_data_columns()
|
||||
|
||||
logger.info("Write code with tools")
|
||||
tool_context, code = await WriteCodeWithToolsML().run(
|
||||
context=[], # context assembled inside the Action
|
||||
plan=self.planner.plan,
|
||||
column_info=column_info,
|
||||
)
|
||||
self.debug_context = tool_context
|
||||
cause_by = WriteCodeWithToolsML
|
||||
|
||||
else:
|
||||
# Previous trials resulted in error, debug and rewrite the code
|
||||
logger.warning("We got a bug, now start to debug...")
|
||||
code = await DebugCode().run(
|
||||
code=self.latest_code,
|
||||
runtime_result=self.working_memory.get(),
|
||||
|
|
@ -67,49 +46,21 @@ class MLEngineer(CodeInterpreter):
|
|||
logger.info(f"new code \n{code}")
|
||||
cause_by = DebugCode
|
||||
|
||||
else:
|
||||
logger.info("Write code with tools")
|
||||
tool_context, code = await WriteCodeWithToolsML().run(
|
||||
context=[], # context assembled inside the Action
|
||||
plan=self.planner.plan,
|
||||
column_info=self.data_desc.get("column_info", ""),
|
||||
)
|
||||
self.debug_context = tool_context
|
||||
cause_by = WriteCodeWithToolsML
|
||||
|
||||
self.latest_code = code
|
||||
|
||||
return code, cause_by
|
||||
|
||||
async def _update_data_columns(self):
|
||||
current_task = self.planner.plan.current_task
|
||||
if current_task.task_type not in [
|
||||
ToolTypeEnum.DATA_PREPROCESS.value,
|
||||
ToolTypeEnum.FEATURE_ENGINEERING.value,
|
||||
ToolTypeEnum.MODEL_TRAIN.value,
|
||||
]:
|
||||
return ""
|
||||
logger.info("Check columns in updated data")
|
||||
rsp = await UpdateDataColumns().run(self.planner.plan)
|
||||
is_update, code = rsp["is_update"], rsp["code"]
|
||||
code = await UpdateDataColumns().run(self.planner.plan)
|
||||
success = False
|
||||
if is_update:
|
||||
result, success = await self.execute_code.run(code)
|
||||
if success:
|
||||
print(result)
|
||||
self.data_desc["column_info"] = result
|
||||
return success, code
|
||||
|
||||
async def _prepare_data_context(self):
|
||||
memories = self.get_memories()
|
||||
if memories:
|
||||
latest_event = memories[-1].cause_by
|
||||
if latest_event == DownloadData:
|
||||
self.planner.plan.context = memories[-1].content
|
||||
elif latest_event == SubmitResult:
|
||||
# self reflect on previous plan outcomes and think about how to improve the plan, add to working memory
|
||||
await self._reflect()
|
||||
|
||||
# get feedback for improvement from human, add to working memory
|
||||
await self.planner.ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
|
||||
async def _reflect(self):
|
||||
context = self.get_memories()
|
||||
context = "\n".join([str(msg) for msg in context])
|
||||
|
||||
reflection = await Reflect().run(context=context)
|
||||
self.working_memory.add(Message(content=reflection, role="assistant"))
|
||||
self.working_memory.add(Message(content=Reflect.REWRITE_PLAN_INSTRUCTION, role="user"))
|
||||
result, success = await self.execute_code.run(code)
|
||||
print(result)
|
||||
return result if success else ""
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
class ToolTypeEnum(Enum):
|
||||
EDA = "eda"
|
||||
DATA_PREPROCESS = "data_preprocess"
|
||||
FEATURE_ENGINEERING = "feature_engineering"
|
||||
MODEL_TRAIN = "model_train"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,12 @@ from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum
|
|||
from metagpt.tools.tool_registry import register_tool_type
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class EDA(ToolType):
|
||||
name: str = ToolTypeEnum.EDA.value
|
||||
desc: str = "Useful for performing exploratory data analysis"
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class DataPreprocess(ToolType):
|
||||
name: str = ToolTypeEnum.DATA_PREPROCESS.value
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue