add data_desc to WriteCodeWithTools

This commit is contained in:
lidanyang 2023-11-29 14:55:54 +08:00
parent 2def00d851
commit 25c01abaf4

View file

@ -21,10 +21,11 @@ STRUCTURAL_CONTEXT = """
{current_task}
"""
def truncate(result: str, keep_len: int = 1000) -> str:
desc = """I truncated the result to only keep the last 1000 characters\n"""
if result.startswith(desc):
result = result[-len(desc):]
result = result[-len(desc) :]
if len(result) > keep_len:
result = result[-keep_len:]
@ -35,10 +36,16 @@ def truncate(result: str, keep_len: int = 1000) -> str:
class AskReview(Action):
async def run(self, context: List[Message], plan: Plan = None):
logger.info("Current overall plan:")
logger.info("\n".join([f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks]))
logger.info(
"\n".join(
[
f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}"
for task in plan.tasks
]
)
)
logger.info("most recent context:")
# prompt = "\n".join(
@ -46,21 +53,26 @@ class AskReview(Action):
# )
prompt = ""
latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else ""
prompt += f"\nPlease review output from {latest_action}:\n" \
"If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \
prompt += (
f"\nPlease review output from {latest_action}:\n"
"If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n"
"If you confirm the output and wish to continue with the current process, type CONFIRM:\n"
)
rsp = input(prompt)
confirmed = "confirm" in rsp.lower()
return rsp, confirmed
class WriteTaskGuide(Action):
class WriteTaskGuide(Action):
async def run(self, task_instruction: str, data_desc: str = "") -> str:
return ""
class MLEngineer(Role):
def __init__(self, name="ABC", profile="MLEngineer", goal="", auto_run: bool = False):
def __init__(
self, name="ABC", profile="MLEngineer", goal="", auto_run: bool = False
):
super().__init__(name=name, profile=profile, goal=goal)
self._set_react_mode(react_mode="plan_and_act")
self.plan = Plan(goal=goal)
@ -70,7 +82,6 @@ class MLEngineer(Role):
self.auto_run = auto_run
async def _plan_and_act(self):
# create initial plan and update until confirmation
await self._update_plan()
@ -96,8 +107,11 @@ class MLEngineer(Role):
await self._update_plan()
async def _write_and_exec_code(self, max_retry: int = 3):
task_guide = await WriteTaskGuide().run(self.plan.current_task.instruction) if self.use_task_guide else ""
task_guide = (
await WriteTaskGuide().run(self.plan.current_task.instruction)
if self.use_task_guide
else ""
)
counter = 0
success = False
@ -109,22 +123,29 @@ class MLEngineer(Role):
# print("*" * 10)
# breakpoint()
if not self.use_tools:
if not self.use_tools or self.plan.current_task.task_type == "unknown":
# code = "print('abc')"
code = await WriteCodeByGenerate().run(context=context, plan=self.plan, task_guide=task_guide)
code = await WriteCodeByGenerate().run(
context=context, plan=self.plan, task_guide=task_guide
)
cause_by = WriteCodeByGenerate
else:
code = await WriteCodeWithTools().run(context=context, plan=self.plan, task_guide=task_guide)
code = await WriteCodeWithTools().run(
context=context, plan=self.plan, task_guide=task_guide, data_desc=""
)
cause_by = WriteCodeWithTools
self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
self.working_memory.add(
Message(content=code, role="assistant", cause_by=cause_by)
)
result, success = await self.execute_code.run(code)
# truncated the result
print(truncate(result))
# print(result)
self.working_memory.add(Message(content=result, role="user", cause_by=ExecutePyCode))
self.working_memory.add(
Message(content=result, role="user", cause_by=ExecutePyCode)
)
if code.startswith("!pip"):
success = False
@ -138,9 +159,13 @@ class MLEngineer(Role):
async def _ask_review(self):
if not self.auto_run:
context = self.get_useful_memories()
review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan)
review, confirmed = await AskReview().run(
context=context[-5:], plan=self.plan
)
if review.lower() not in ("confirm", "y", "yes"):
self._rc.memory.add(Message(content=review, role="user", cause_by=AskReview))
self._rc.memory.add(
Message(content=review, role="user", cause_by=AskReview)
)
return confirmed
return True
@ -149,7 +174,9 @@ class MLEngineer(Role):
while not plan_confirmed:
context = self.get_useful_memories()
rsp = await WritePlan().run(context, max_tasks=max_tasks)
self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan))
self.working_memory.add(
Message(content=rsp, role="assistant", cause_by=WritePlan)
)
plan_confirmed = await self._ask_review()
tasks = WritePlan.rsp_to_tasks(rsp)
@ -160,9 +187,13 @@ class MLEngineer(Role):
"""find useful memories only to reduce context length and improve performance"""
user_requirement = self.plan.goal
tasks = json.dumps([task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False)
tasks = json.dumps(
[task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False
)
current_task = self.plan.current_task.json() if self.plan.current_task else {}
context = STRUCTURAL_CONTEXT.format(user_requirement=user_requirement, tasks=tasks, current_task=current_task)
context = STRUCTURAL_CONTEXT.format(
user_requirement=user_requirement, tasks=tasks, current_task=current_task
)
context_msg = [Message(content=context, role="user")]
return context_msg + self.working_memory.get()
@ -171,6 +202,7 @@ class MLEngineer(Role):
def working_memory(self):
return self._rc.memory
if __name__ == "__main__":
# requirement = "create a normal distribution and visualize it"
requirement = "run some analysis on iris dataset"