diff --git a/metagpt/actions/write_plan.py b/metagpt/actions/write_plan.py index e35ba7a92..dcfa25d55 100644 --- a/metagpt/actions/write_plan.py +++ b/metagpt/actions/write_plan.py @@ -15,8 +15,6 @@ class WritePlan(Action): PROMPT_TEMPLATE = """ # Context: __context__ - # Current Plan: - __current_plan__ # Task: Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks. If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. @@ -32,10 +30,11 @@ class WritePlan(Action): ] ``` """ - async def run(self, context: List[Message], current_plan: str = "", max_tasks: int = 5) -> str: + async def run(self, context: List[Message], max_tasks: int = 5) -> str: prompt = ( self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context])) - .replace("__current_plan__", current_plan).replace("__max_tasks__", str(max_tasks)) + # .replace("__current_plan__", current_plan) + .replace("__max_tasks__", str(max_tasks)) ) rsp = await self._aask(prompt) rsp = CodeParser.parse_code(block=None, text=rsp) diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 2e4bbfc82..9c1fc2dc0 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -12,18 +12,40 @@ from metagpt.actions.write_plan import WritePlan from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools from metagpt.actions.execute_code import ExecutePyCode +STRUCTURAL_CONTEXT = """ +## User Requirement +{user_requirement} +## Current Plan +{tasks} +## Current Task +{current_task} +""" + +def truncate(result: str, keep_len: int = 1000) -> str: + desc = """I truncated the result to only keep the last 1000 characters\n""" + if result.startswith(desc): + result = result[-len(desc):] + + if len(result) > keep_len: + result = result[-keep_len:] + + if not result.startswith(desc): + return desc + result + return desc + + class AskReview(Action): async def run(self, context: List[Message], plan: Plan = None): logger.info("Current overall plan:") - logger.info("\n".join([f"{task.task_id}: {task.instruction}" for task in plan.tasks])) + logger.info("\n".join([f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks])) logger.info("most recent context:") # prompt = "\n".join( # [f"{msg.cause_by.__name__ if msg.cause_by else 'Main Requirement'}: {msg.content}" for msg in context] # ) prompt = "" - latest_action = context[-1].cause_by.__name__ + latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else "" prompt += f"\nPlease review output from {latest_action}:\n" \ "If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \ "If you confirm the output and wish to continue with the current process, type CONFIRM:\n" @@ -66,6 +88,7 @@ class MLEngineer(Role): task.code = code task.result = result self.plan.finish_current_task() + self.working_memory.clear() else: # update plan according to user's feedback and to take on changed tasks @@ -80,6 +103,11 @@ class MLEngineer(Role): while not success and counter < max_retry: context = self.get_useful_memories() + # print("*" * 10) + # print(context) + # print("*" * 10) + # breakpoint() + if not self.use_tools: # code = "print('abc')" code = await WriteCodeByGenerate().run(context=context, plan=self.plan, task_guide=task_guide) @@ -89,12 +117,13 @@ class MLEngineer(Role): code = await WriteCodeWithTools().run(context=context, plan=self.plan, task_guide=task_guide) cause_by = WriteCodeWithTools - self._rc.memory.add(Message(content=code, role="assistant", cause_by=cause_by)) + self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by)) result, success = await self.execute_code.run(code) # truncated the result - print(self.truncate(result)) - self._rc.memory.add(Message(content=self.truncate(result), role="user", cause_by=ExecutePyCode)) + print(truncate(result)) + # print(result) + self.working_memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) # if not success: # await self._ask_review() @@ -111,34 +140,31 @@ class MLEngineer(Role): return confirmed async def _update_plan(self, max_tasks: int = 3): - current_plan = str([task.json() for task in self.plan.tasks]) plan_confirmed = False while not plan_confirmed: context = self.get_useful_memories() - rsp = await WritePlan().run(context, current_plan=current_plan, max_tasks=max_tasks) - self._rc.memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan)) + rsp = await WritePlan().run(context, max_tasks=max_tasks) + self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan)) plan_confirmed = await self._ask_review() tasks = WritePlan.rsp_to_tasks(rsp) self.plan.add_tasks(tasks) + self.working_memory.clear() - def get_useful_memories(self, current_task_memories: List[str] = []) -> List[Message]: + def get_useful_memories(self) -> List[Message]: """find useful memories only to reduce context length and improve performance""" - memories = super().get_memories() - return memories - def truncate(self, result: str, keep_len: int = 1000) -> str: - desc = """I truncated the result to only keep the last 1000 characters\n""" - if result.startswith(desc): - result = result[-len(desc):] + user_requirement = self.plan.goal + tasks = json.dumps([task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False) + current_task = self.plan.current_task.json() if self.plan.current_task else {} + context = STRUCTURAL_CONTEXT.format(user_requirement=user_requirement, tasks=tasks, current_task=current_task) + context_msg = [Message(content=context, role="user")] - if len(result) > keep_len: - result = result[-keep_len:] - - if not result.startswith(desc): - return desc + result - return desc + return context_msg + self.working_memory.get() + @property + def working_memory(self): + return self._rc.memory if __name__ == "__main__": # requirement = "create a normal distribution and visualize it"