diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 3f46b9451..b8a258b46 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -21,10 +21,11 @@ STRUCTURAL_CONTEXT = """ {current_task} """ + def truncate(result: str, keep_len: int = 1000) -> str: desc = """I truncated the result to only keep the last 1000 characters\n""" if result.startswith(desc): - result = result[-len(desc):] + result = result[-len(desc) :] if len(result) > keep_len: result = result[-keep_len:] @@ -35,10 +36,16 @@ def truncate(result: str, keep_len: int = 1000) -> str: class AskReview(Action): - async def run(self, context: List[Message], plan: Plan = None): logger.info("Current overall plan:") - logger.info("\n".join([f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks])) + logger.info( + "\n".join( + [ + f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" + for task in plan.tasks + ] + ) + ) logger.info("most recent context:") # prompt = "\n".join( @@ -46,21 +53,26 @@ class AskReview(Action): # ) prompt = "" latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else "" - prompt += f"\nPlease review output from {latest_action}:\n" \ - "If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \ + prompt += ( + f"\nPlease review output from {latest_action}:\n" + "If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" "If you confirm the output and wish to continue with the current process, type CONFIRM:\n" + ) rsp = input(prompt) confirmed = "confirm" in rsp.lower() return rsp, confirmed -class WriteTaskGuide(Action): +class WriteTaskGuide(Action): async def run(self, task_instruction: str, data_desc: str = "") -> str: return "" + class MLEngineer(Role): - def __init__(self, name="ABC", profile="MLEngineer", goal="", auto_run: bool = False): + def __init__( + self, name="ABC", profile="MLEngineer", goal="", auto_run: bool = False + ): super().__init__(name=name, profile=profile, goal=goal) self._set_react_mode(react_mode="plan_and_act") self.plan = Plan(goal=goal) @@ -70,7 +82,6 @@ class MLEngineer(Role): self.auto_run = auto_run async def _plan_and_act(self): - # create initial plan and update until confirmation await self._update_plan() @@ -96,8 +107,11 @@ class MLEngineer(Role): await self._update_plan() async def _write_and_exec_code(self, max_retry: int = 3): - - task_guide = await WriteTaskGuide().run(self.plan.current_task.instruction) if self.use_task_guide else "" + task_guide = ( + await WriteTaskGuide().run(self.plan.current_task.instruction) + if self.use_task_guide + else "" + ) counter = 0 success = False @@ -109,22 +123,29 @@ class MLEngineer(Role): # print("*" * 10) # breakpoint() - if not self.use_tools: + if not self.use_tools or self.plan.current_task.task_type == "unknown": # code = "print('abc')" - code = await WriteCodeByGenerate().run(context=context, plan=self.plan, task_guide=task_guide) + code = await WriteCodeByGenerate().run( + context=context, plan=self.plan, task_guide=task_guide + ) cause_by = WriteCodeByGenerate - else: - code = await WriteCodeWithTools().run(context=context, plan=self.plan, task_guide=task_guide) + code = await WriteCodeWithTools().run( + context=context, plan=self.plan, task_guide=task_guide, data_desc="" + ) cause_by = WriteCodeWithTools - self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by)) + self.working_memory.add( + Message(content=code, role="assistant", cause_by=cause_by) + ) result, success = await self.execute_code.run(code) # truncated the result print(truncate(result)) # print(result) - self.working_memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) + self.working_memory.add( + Message(content=result, role="user", cause_by=ExecutePyCode) + ) if code.startswith("!pip"): success = False @@ -138,9 +159,13 @@ class MLEngineer(Role): async def _ask_review(self): if not self.auto_run: context = self.get_useful_memories() - review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan) + review, confirmed = await AskReview().run( + context=context[-5:], plan=self.plan + ) if review.lower() not in ("confirm", "y", "yes"): - self._rc.memory.add(Message(content=review, role="user", cause_by=AskReview)) + self._rc.memory.add( + Message(content=review, role="user", cause_by=AskReview) + ) return confirmed return True @@ -149,7 +174,9 @@ class MLEngineer(Role): while not plan_confirmed: context = self.get_useful_memories() rsp = await WritePlan().run(context, max_tasks=max_tasks) - self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan)) + self.working_memory.add( + Message(content=rsp, role="assistant", cause_by=WritePlan) + ) plan_confirmed = await self._ask_review() tasks = WritePlan.rsp_to_tasks(rsp) @@ -160,9 +187,13 @@ class MLEngineer(Role): """find useful memories only to reduce context length and improve performance""" user_requirement = self.plan.goal - tasks = json.dumps([task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False) + tasks = json.dumps( + [task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False + ) current_task = self.plan.current_task.json() if self.plan.current_task else {} - context = STRUCTURAL_CONTEXT.format(user_requirement=user_requirement, tasks=tasks, current_task=current_task) + context = STRUCTURAL_CONTEXT.format( + user_requirement=user_requirement, tasks=tasks, current_task=current_task + ) context_msg = [Message(content=context, role="user")] return context_msg + self.working_memory.get() @@ -171,6 +202,7 @@ class MLEngineer(Role): def working_memory(self): return self._rc.memory + if __name__ == "__main__": # requirement = "create a normal distribution and visualize it" requirement = "run some analysis on iris dataset"