From 608126e1f9906bc69c6ea489674c25c0198ec9ae Mon Sep 17 00:00:00 2001 From: yzlin Date: Tue, 28 Nov 2023 13:50:15 +0800 Subject: [PATCH] update context --- metagpt/actions/write_plan.py | 7 +++-- metagpt/roles/ml_engineer.py | 48 +++++++++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/metagpt/actions/write_plan.py b/metagpt/actions/write_plan.py index e35ba7a92..dcfa25d55 100644 --- a/metagpt/actions/write_plan.py +++ b/metagpt/actions/write_plan.py @@ -15,8 +15,6 @@ class WritePlan(Action): PROMPT_TEMPLATE = """ # Context: __context__ - # Current Plan: - __current_plan__ # Task: Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks. If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. @@ -32,10 +30,11 @@ class WritePlan(Action): ] ``` """ - async def run(self, context: List[Message], current_plan: str = "", max_tasks: int = 5) -> str: + async def run(self, context: List[Message], max_tasks: int = 5) -> str: prompt = ( self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context])) - .replace("__current_plan__", current_plan).replace("__max_tasks__", str(max_tasks)) + # .replace("__current_plan__", current_plan) + .replace("__max_tasks__", str(max_tasks)) ) rsp = await self._aask(prompt) rsp = CodeParser.parse_code(block=None, text=rsp) diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 480f6cecf..910b94432 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -12,18 +12,27 @@ from metagpt.actions.write_plan import WritePlan from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools from metagpt.actions.execute_code import ExecutePyCode +STRUCTURAL_CONTEXT = """ +## User Requirement +{user_requirement} +## Current Plan +{tasks} +## Current Task +{current_task} +""" + class AskReview(Action): async def run(self, context: List[Message], plan: Plan = None): logger.info("Current overall plan:") - logger.info("\n".join([f"{task.task_id}: {task.instruction}" for task in plan.tasks])) + logger.info("\n".join([f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks])) logger.info("most recent context:") # prompt = "\n".join( # [f"{msg.cause_by.__name__ if msg.cause_by else 'Main Requirement'}: {msg.content}" for msg in context] # ) prompt = "" - latest_action = context[-1].cause_by.__name__ + latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else "" prompt += f"\nPlease review output from {latest_action}:\n" \ "If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \ "If you confirm the output and wish to continue with the current process, type CONFIRM:\n" @@ -44,6 +53,7 @@ class MLEngineer(Role): self.plan = Plan(goal=goal) self.use_tools = False self.use_task_guide = False + self.execute_code_action = ExecutePyCode() async def _plan_and_act(self): @@ -65,6 +75,7 @@ class MLEngineer(Role): task.code = code task.result = result self.plan.finish_current_task() + self.working_memory.clear() else: # update plan according to user's feedback and to take on changed tasks @@ -79,6 +90,11 @@ class MLEngineer(Role): while not success and counter < max_retry: context = self.get_useful_memories() + # print("*" * 10) + # print(context) + # print("*" * 10) + # breakpoint() + if not self.use_tools: # code = "print('abc')" code = await WriteCodeByGenerate().run(context=context, plan=self.plan, task_guide=task_guide) @@ -88,11 +104,11 @@ class MLEngineer(Role): code = await WriteCodeWithTools().run(context=context, plan=self.plan, task_guide=task_guide) cause_by = WriteCodeWithTools - self._rc.memory.add(Message(content=code, role="assistant", cause_by=cause_by)) + self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by)) - result, success = await ExecutePyCode().run(code) + result, success = await self.execute_code_action.run(code) print(result) - self._rc.memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) + self.working_memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) # if not success: # await self._ask_review() @@ -108,21 +124,31 @@ class MLEngineer(Role): return confirmed async def _update_plan(self, max_tasks: int = 3): - current_plan = str([task.json() for task in self.plan.tasks]) plan_confirmed = False while not plan_confirmed: context = self.get_useful_memories() - rsp = await WritePlan().run(context, current_plan=current_plan, max_tasks=max_tasks) - self._rc.memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan)) + rsp = await WritePlan().run(context, max_tasks=max_tasks) + self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan)) plan_confirmed = await self._ask_review() tasks = WritePlan.rsp_to_tasks(rsp) self.plan.add_tasks(tasks) + self.working_memory.clear() - def get_useful_memories(self, current_task_memories: List[str] = []) -> List[Message]: + def get_useful_memories(self) -> List[Message]: """find useful memories only to reduce context length and improve performance""" - memories = super().get_memories() - return memories + + user_requirement = self.plan.goal + tasks = json.dumps([task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False) + current_task = self.plan.current_task.json() if self.plan.current_task else {} + context = STRUCTURAL_CONTEXT.format(user_requirement=user_requirement, tasks=tasks, current_task=current_task) + context_msg = [Message(content=context, role="user")] + + return context_msg + self.working_memory.get() + + @property + def working_memory(self): + return self._rc.memory if __name__ == "__main__":