update context

This commit is contained in:
yzlin 2023-11-28 13:50:15 +08:00
parent a9b46579b4
commit 608126e1f9
2 changed files with 40 additions and 15 deletions

View file

@ -15,8 +15,6 @@ class WritePlan(Action):
PROMPT_TEMPLATE = """
# Context:
__context__
# Current Plan:
__current_plan__
# Task:
Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks.
If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes.
@ -32,10 +30,11 @@ class WritePlan(Action):
]
```
"""
async def run(self, context: List[Message], current_plan: str = "", max_tasks: int = 5) -> str:
async def run(self, context: List[Message], max_tasks: int = 5) -> str:
prompt = (
self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context]))
.replace("__current_plan__", current_plan).replace("__max_tasks__", str(max_tasks))
# .replace("__current_plan__", current_plan)
.replace("__max_tasks__", str(max_tasks))
)
rsp = await self._aask(prompt)
rsp = CodeParser.parse_code(block=None, text=rsp)

View file

@ -12,18 +12,27 @@ from metagpt.actions.write_plan import WritePlan
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
from metagpt.actions.execute_code import ExecutePyCode
STRUCTURAL_CONTEXT = """
## User Requirement
{user_requirement}
## Current Plan
{tasks}
## Current Task
{current_task}
"""
class AskReview(Action):
async def run(self, context: List[Message], plan: Plan = None):
logger.info("Current overall plan:")
logger.info("\n".join([f"{task.task_id}: {task.instruction}" for task in plan.tasks]))
logger.info("\n".join([f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks]))
logger.info("most recent context:")
# prompt = "\n".join(
# [f"{msg.cause_by.__name__ if msg.cause_by else 'Main Requirement'}: {msg.content}" for msg in context]
# )
prompt = ""
latest_action = context[-1].cause_by.__name__
latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else ""
prompt += f"\nPlease review output from {latest_action}:\n" \
"If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \
"If you confirm the output and wish to continue with the current process, type CONFIRM:\n"
@ -44,6 +53,7 @@ class MLEngineer(Role):
self.plan = Plan(goal=goal)
self.use_tools = False
self.use_task_guide = False
self.execute_code_action = ExecutePyCode()
async def _plan_and_act(self):
@ -65,6 +75,7 @@ class MLEngineer(Role):
task.code = code
task.result = result
self.plan.finish_current_task()
self.working_memory.clear()
else:
# update plan according to user's feedback and to take on changed tasks
@ -79,6 +90,11 @@ class MLEngineer(Role):
while not success and counter < max_retry:
context = self.get_useful_memories()
# print("*" * 10)
# print(context)
# print("*" * 10)
# breakpoint()
if not self.use_tools:
# code = "print('abc')"
code = await WriteCodeByGenerate().run(context=context, plan=self.plan, task_guide=task_guide)
@ -88,11 +104,11 @@ class MLEngineer(Role):
code = await WriteCodeWithTools().run(context=context, plan=self.plan, task_guide=task_guide)
cause_by = WriteCodeWithTools
self._rc.memory.add(Message(content=code, role="assistant", cause_by=cause_by))
self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
result, success = await ExecutePyCode().run(code)
result, success = await self.execute_code_action.run(code)
print(result)
self._rc.memory.add(Message(content=result, role="user", cause_by=ExecutePyCode))
self.working_memory.add(Message(content=result, role="user", cause_by=ExecutePyCode))
# if not success:
# await self._ask_review()
@ -108,21 +124,31 @@ class MLEngineer(Role):
return confirmed
async def _update_plan(self, max_tasks: int = 3):
current_plan = str([task.json() for task in self.plan.tasks])
plan_confirmed = False
while not plan_confirmed:
context = self.get_useful_memories()
rsp = await WritePlan().run(context, current_plan=current_plan, max_tasks=max_tasks)
self._rc.memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan))
rsp = await WritePlan().run(context, max_tasks=max_tasks)
self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan))
plan_confirmed = await self._ask_review()
tasks = WritePlan.rsp_to_tasks(rsp)
self.plan.add_tasks(tasks)
self.working_memory.clear()
def get_useful_memories(self, current_task_memories: List[str] = []) -> List[Message]:
def get_useful_memories(self) -> List[Message]:
"""find useful memories only to reduce context length and improve performance"""
memories = super().get_memories()
return memories
user_requirement = self.plan.goal
tasks = json.dumps([task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False)
current_task = self.plan.current_task.json() if self.plan.current_task else {}
context = STRUCTURAL_CONTEXT.format(user_requirement=user_requirement, tasks=tasks, current_task=current_task)
context_msg = [Message(content=context, role="user")]
return context_msg + self.working_memory.get()
@property
def working_memory(self):
return self._rc.memory
if __name__ == "__main__":