diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index 099934c5a..c9acb32b9 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -270,9 +270,18 @@ class MakeTools(WriteCodeByGenerate): saved_path.write_text(tool_code, encoding='utf-8') @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) - async def run(self, code_message: List[Message | Dict], **kwargs) -> str: - msgs = self.process_msg(code_message, self.DEFAULT_SYSTEM_MSG) + async def run(self, code: str, code_desc: str = None, **kwargs) -> str: + # 拼接code prompt + code_prompt = f"The following code is about {code_desc}, convert it to be a General Function, {code}" + msgs = self.process_msg(code_prompt, self.DEFAULT_SYSTEM_MSG) logger.info(f"\n\nAsk to Make tools:\n{'-'*60}\n {msgs[-1]}") + + # 更新kwargs + if 'code' in kwargs: + kwargs.pop('code') + if 'code_desc' in kwargs: + kwargs.pop('code_desc') + tool_code = await self.llm.aask_code(msgs, **kwargs) max_tries, current_try = 3, 1 func_name = self.parse_function_name(tool_code['code']) diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index f44d42554..db2dfeeff 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -291,15 +291,14 @@ class MLEngineer(Role): logger.warning(f"Making tools for task_id {self.plan.current_task_id}: \ `{self.plan.current_task.instruction}` \n code: \n {code}") make_tools = MakeTools() - code_prompt = f"The following code is about {self.plan.current_task.instruction},\ - convert it to be a General Function, {code}" - tool_code = await make_tools.run(code_prompt) + tool_code = await make_tools.run(code, self.plan.current_task.instruction) # check tool_code by execute_code logger.info(f"Checking task_id {self.plan.current_task_id} tool code by executor...") _, success = await self.execute_code.run(tool_code) make_tool_retries, make_tool_current_retry = 3, 1 while not success: - tool_code = await make_tools.run(code_prompt) + # tool_code = await make_tools.run(code_prompt) + tool_code = await make_tools.run(code) _, success = await self.execute_code.run(tool_code) if make_tool_current_retry > make_tool_retries: logger.error(f"We have tried the maximum number of attempts {make_tool_retries}\