Merge branch 'pipeline' into 'dev'

MLE Pipeline, WritePlan, WriteCodeByGenerate, ExecuteCode

See merge request agents/data_agents_opt!2
This commit is contained in:
林义章 2023-11-24 06:14:38 +00:00
commit f0ada24e3d
11 changed files with 751 additions and 0 deletions

View file

@ -23,6 +23,9 @@ from metagpt.actions.write_code_review import WriteCodeReview
from metagpt.actions.write_prd import WritePRD
from metagpt.actions.write_prd_review import WritePRDReview
from metagpt.actions.write_test import WriteTest
from metagpt.actions.execute_code import ExecutePyCode
from metagpt.actions.write_analysis_code import WriteCodeByGenerate
from metagpt.actions.write_plan import WritePlan
class ActionType(Enum):
@ -45,6 +48,9 @@ class ActionType(Enum):
COLLECT_LINKS = CollectLinks
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
CONDUCT_RESEARCH = ConductResearch
EXECUTE_PYCODE = ExecutePyCode
WRITE_CODE_BY_GENERATE = WriteCodeByGenerate
WRITE_PLAN = WritePlan
__all__ = [

View file

@ -0,0 +1,178 @@
# -*- encoding: utf-8 -*-
"""
@Date : 2023/11/17 14:22:15
@Author : orange-crow
@File : code_executor.py
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Tuple, Union
import traceback
import nbformat
from nbclient import NotebookClient
from nbformat.v4 import new_code_cell, new_output
from rich.console import Console
from rich.syntax import Syntax
from metagpt.actions import Action
from metagpt.schema import Message
class ExecuteCode(ABC):
@abstractmethod
async def build(self):
"""build code executor"""
...
@abstractmethod
async def run(self, code: str):
"""run code"""
...
@abstractmethod
async def terminate(self):
"""terminate executor"""
...
@abstractmethod
async def reset(self):
"""reset executor"""
...
class ExecutePyCode(ExecuteCode, Action):
"""execute code, return result to llm, and display it."""
def __init__(self, name: str = "python_executor", context=None, llm=None):
super().__init__(name, context, llm)
self.nb = nbformat.v4.new_notebook()
self.nb_client = NotebookClient(self.nb)
self.console = Console()
self.interaction = "ipython" if self.is_ipython() else "terminal"
async def build(self):
if self.nb_client.kc is None or not await self.nb_client.kc.is_alive():
self.nb_client.create_kernel_manager()
self.nb_client.start_new_kernel()
self.nb_client.start_new_kernel_client()
async def terminate(self):
"""kill NotebookClient"""
await self.nb_client._async_cleanup_kernel()
async def reset(self):
"""reset NotebookClient"""
await self.terminate()
self.nb_client = NotebookClient(self.nb)
def add_code_cell(self, code):
self.nb.cells.append(new_code_cell(source=code))
def _display(self, code, language: str = "python"):
if language == "python":
code = Syntax(code, "python", theme="paraiso-dark", line_numbers=True)
self.console.print("\n")
self.console.print(code)
def add_output_to_cell(self, cell, output):
if "outputs" not in cell:
cell["outputs"] = []
# TODO: show figures
else:
cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output)))
def parse_outputs(self, outputs: List) -> str:
assert isinstance(outputs, list)
parsed_output = ""
# empty outputs: such as 'x=1\ny=2'
if not outputs:
return parsed_output
for output in outputs:
if output["output_type"] == "stream":
parsed_output += output["text"]
elif output["output_type"] == "display_data":
self.show_bytes_figure(output["data"]["image/png"], self.interaction)
elif output["output_type"] == "execute_result":
parsed_output += output["data"]["text/plain"]
return parsed_output
def show_bytes_figure(self, image_base64: str, interaction_type: str = "ipython"):
import base64
image_bytes = base64.b64decode(image_base64)
if interaction_type == "ipython":
from IPython.display import Image, display
display(Image(data=image_bytes))
else:
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))
image.show()
def is_ipython(self) -> bool:
try:
# 如果在Jupyter Notebook中运行__file__ 变量不存在
from IPython import get_ipython
if get_ipython() is not None and "IPKernelApp" in get_ipython().config:
return True
else:
return False
except NameError:
# 如果在Python脚本中运行__file__ 变量存在
return False
def _process_code(self, code: Union[str, Dict, Message], language: str = None) -> Tuple:
language = language or 'python'
if isinstance(code, str) and Path(code).suffix in (".py", ".txt"):
code = Path(code).read_text(encoding="utf-8")
return code, language
if isinstance(code, str):
return code, language
if isinstance(code, dict):
assert "code" in code
if "language" not in code:
code['language'] = 'python'
code, language = code["code"], code["language"]
elif isinstance(code, Message):
if isinstance(code.content, dict) and "language" not in code.content:
code.content["language"] = 'python'
code, language = code.content["code"], code.content["language"]
elif isinstance(code.content, str):
code, language = code.content, language
else:
raise ValueError(f"Not support code type {type(code).__name__}.")
return code, language
async def run(self, code: Union[str, Dict, Message], language: str = "python") -> Tuple[str, bool]:
code, language = self._process_code(code, language)
self._display(code, language)
if language == "python":
# add code to the notebook
self.add_code_cell(code=code)
try:
# build code executor
await self.build()
# run code
# TODO: add max_tries for run code.
cell_index = len(self.nb.cells) - 1
await self.nb_client.async_execute_cell(self.nb.cells[-1], cell_index)
return self.parse_outputs(self.nb.cells[-1].outputs), True
except Exception as e:
# FIXME: CellExecutionError is hard to read. for example `1\0` raise ZeroDivisionError:
# CellExecutionError('An error occurred while executing the following cell:\n------------------\nz=1/0\n------------------\n\n\n\x1b[0;31m---------------------------------------------------------------------------\x1b[0m\n\x1b[0;31mZeroDivisionError\x1b[0m Traceback (most recent call last)\nCell \x1b[0;32mIn[1], line 1\x1b[0m\n\x1b[0;32m----> 1\x1b[0m z\x1b[38;5;241m=\x1b[39m\x1b[38;5;241;43m1\x1b[39;49m\x1b[38;5;241;43m/\x1b[39;49m\x1b[38;5;241;43m0\x1b[39;49m\n\n\x1b[0;31mZeroDivisionError\x1b[0m: division by zero\n')
return traceback.format_exc(), False
else:
# TODO: markdown
raise NotImplementedError(f"Not support this code type : {language}, Only support code!")

View file

@ -0,0 +1,71 @@
# -*- encoding: utf-8 -*-
"""
@Date : 2023/11/20 13:19:39
@Author : orange-crow
@File : write_code_v2.py
"""
from typing import Dict, List, Union
from metagpt.actions import Action
from metagpt.schema import Message, Plan
class BaseWriteAnalysisCode(Action):
async def run(self, context: List[Message], plan: Plan = None, task_guide: str = "") -> str:
"""Run of a code writing action, used in data analysis or modeling
Args:
context (List[Message]): Action output history, source action denoted by Message.cause_by
plan (Plan, optional): Overall plan. Defaults to None.
task_guide (str, optional): suggested step breakdown for the current task. Defaults to "".
Returns:
str: The code string.
"""
class WriteCodeByGenerate(BaseWriteAnalysisCode):
"""Write code fully by generation"""
def __init__(self, name: str = "", context=None, llm=None) -> str:
super().__init__(name, context, llm)
def process_msg(self, prompt: Union[str, List[Dict], Message, List[Message]], system_msg: str = None):
if isinstance(prompt, str):
return system_msg + prompt if system_msg else prompt
if isinstance(prompt, Message):
if isinstance(prompt.content, dict):
prompt.content = system_msg + str([(k, v) for k, v in prompt.content.items()])\
if system_msg else prompt.content
else:
prompt.content = system_msg + prompt.content if system_msg else prompt.content
return prompt
if isinstance(prompt, list):
_prompt = []
for msg in prompt:
if isinstance(msg, Message) and isinstance(msg.content, dict):
msg.content = str([(k, v) for k, v in msg.content.items()])
if isinstance(msg, Message):
msg = msg.to_dict()
_prompt.append(msg)
prompt = _prompt
if isinstance(prompt, list) and system_msg:
if system_msg not in prompt[0]['content']:
prompt[0]['content'] = system_msg + prompt[0]['content']
return prompt
async def run(
self, context: [List[Message]], plan: Plan = None, task_guide: str = "", system_msg: str = None, **kwargs
) -> str:
prompt = self.process_msg(context, system_msg)
code_content = await self.llm.aask_code(prompt, **kwargs)
return code_content["code"]
class WriteCodeWithTools(BaseWriteAnalysisCode):
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
async def run(self, context: List[Message], plan: Plan = None, task_guide: str = "") -> str:
return "print('abc')"

View file

@ -0,0 +1,48 @@
# -*- encoding: utf-8 -*-
"""
@Date : 2023/11/20 11:24:03
@Author : orange-crow
@File : plan.py
"""
from typing import List
import json
from metagpt.actions import Action
from metagpt.schema import Message, Task
from metagpt.utils.common import CodeParser
class WritePlan(Action):
PROMPT_TEMPLATE = """
# Context:
__context__
# Current Plan:
__current_plan__
# Task:
Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks.
If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes.
Output a list of jsons following the format:
```json
[
{
"task_id": str = "unique identifier for a task in plan, can be an ordinal",
"dependent_task_ids": list[str] = "ids of tasks prerequisite to this task",
"instruction": "what you should do in this task, one short phrase or sentence",
},
...
]
```
"""
async def run(self, context: List[Message], current_plan: str = "", max_tasks: int = 5) -> str:
prompt = (
self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context]))
.replace("__current_plan__", current_plan).replace("__max_tasks__", str(max_tasks))
)
rsp = await self._aask(prompt)
rsp = CodeParser.parse_code(block=None, text=rsp)
return rsp
@staticmethod
def rsp_to_tasks(rsp: str) -> List[Task]:
rsp = json.loads(rsp)
tasks = [Task(**task_config) for task_config in rsp]
return tasks

View file

@ -0,0 +1,136 @@
from typing import Dict, List, Union
import json
import subprocess
import fire
from metagpt.roles import Role
from metagpt.actions import Action
from metagpt.schema import Message, Task, Plan
from metagpt.logs import logger
from metagpt.actions.write_plan import WritePlan
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
from metagpt.actions.execute_code import ExecutePyCode
class AskReview(Action):
async def run(self, context: List[Message], plan: Plan = None):
logger.info("Current overall plan:")
logger.info("\n".join([f"{task.task_id}: {task.instruction}" for task in plan.tasks]))
logger.info("most recent context:")
# prompt = "\n".join(
# [f"{msg.cause_by.__name__ if msg.cause_by else 'Main Requirement'}: {msg.content}" for msg in context]
# )
prompt = ""
latest_action = context[-1].cause_by.__name__
prompt += f"\nPlease review output from {latest_action}:\n" \
"If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \
"If you confirm the output and wish to continue with the current process, type CONFIRM:\n"
rsp = input(prompt)
confirmed = "confirm" in rsp.lower()
return rsp, confirmed
class WriteTaskGuide(Action):
async def run(self, task_instruction: str, data_desc: str = "") -> str:
return ""
class MLEngineer(Role):
def __init__(self, name="ABC", profile="MLEngineer", goal=""):
super().__init__(name=name, profile=profile, goal=goal)
self._set_react_mode(react_mode="plan_and_act")
self.plan = Plan(goal=goal)
self.use_tools = False
self.use_task_guide = False
async def _plan_and_act(self):
# create initial plan and update until confirmation
await self._update_plan()
while self.plan.current_task:
task = self.plan.current_task
logger.info(f"ready to take on task {task}")
# take on current task
code, result, success = await self._write_and_exec_code()
# ask for acceptance, users can other refuse and change tasks in the plan
task_result_confirmed = await self._ask_review()
if success and task_result_confirmed:
# tick off this task and record progress
task.code = code
task.result = result
self.plan.finish_current_task()
else:
# update plan according to user's feedback and to take on changed tasks
await self._update_plan()
async def _write_and_exec_code(self, max_retry: int = 3):
task_guide = await WriteTaskGuide().run(self.plan.current_task.instruction) if self.use_task_guide else ""
counter = 0
success = False
while not success and counter < max_retry:
context = self.get_useful_memories()
if not self.use_tools:
# code = "print('abc')"
code = await WriteCodeByGenerate().run(context=context, plan=self.plan, task_guide=task_guide)
cause_by = WriteCodeByGenerate
else:
code = await WriteCodeWithTools().run(context=context, plan=self.plan, task_guide=task_guide)
cause_by = WriteCodeWithTools
self._rc.memory.add(Message(content=code, role="assistant", cause_by=cause_by))
result, success = await ExecutePyCode().run(code)
print(result)
self._rc.memory.add(Message(content=result, role="user", cause_by=ExecutePyCode))
# if not success:
# await self._ask_review()
counter += 1
return code, result, success
async def _ask_review(self):
context = self.get_useful_memories()
review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan)
self._rc.memory.add(Message(content=review, role="user", cause_by=AskReview))
return confirmed
async def _update_plan(self, max_tasks: int = 3):
current_plan = str([task.json() for task in self.plan.tasks])
plan_confirmed = False
while not plan_confirmed:
context = self.get_useful_memories()
rsp = await WritePlan().run(context, current_plan=current_plan, max_tasks=max_tasks)
self._rc.memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan))
plan_confirmed = await self._ask_review()
tasks = WritePlan.rsp_to_tasks(rsp)
self.plan.add_tasks(tasks)
def get_useful_memories(self, current_task_memories: List[str] = []) -> List[Message]:
"""find useful memories only to reduce context length and improve performance"""
memories = super().get_memories()
return memories
if __name__ == "__main__":
# requirement = "create a normal distribution and visualize it"
requirement = "run some analysis on iris dataset"
async def main(requirement: str = requirement):
role = MLEngineer(goal=requirement)
await role.run(requirement)
fire.Fire(main)

View file

@ -30,6 +30,7 @@ class Message:
sent_from: str = field(default="")
send_to: str = field(default="")
restricted_to: str = field(default="")
state: str = None # None, done, todo, doing, error
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
@ -72,6 +73,116 @@ class AIMessage(Message):
super().__init__(content, 'assistant')
class Task(BaseModel):
task_id: str = ""
dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task
instruction: str = ""
task_type: str = ""
code: str = ""
result: str = ""
is_finished: bool = False
class Plan(BaseModel):
goal: str
tasks: list[Task] = []
task_map: dict[str, Task] = {}
current_task_id = ""
def _topological_sort(self, tasks: list[Task]):
task_map = {task.task_id: task for task in tasks}
dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks}
sorted_tasks = []
visited = set()
def visit(task_id):
if task_id in visited:
return
visited.add(task_id)
for dependent_id in dependencies.get(task_id, []):
visit(dependent_id)
sorted_tasks.append(task_map[task_id])
for task in tasks:
visit(task.task_id)
return sorted_tasks
def add_tasks(self, tasks: list[Task]):
"""
Integrates new tasks into the existing plan, ensuring dependency order is maintained.
This method performs two primary functions based on the current state of the task list:
1. If there are no existing tasks, it topologically sorts the provided tasks to ensure
correct execution order based on dependencies, and sets these as the current tasks.
2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains
any common prefix of tasks (based on task_id and instruction) and appends the remainder
of the new tasks. The current task is updated to the first unfinished task in this merged list.
Args:
tasks (list[Task]): A list of tasks (may be unordered) to add to the plan.
Returns:
None: The method updates the internal state of the plan but does not return anything.
"""
if not tasks:
return
# Topologically sort the new tasks to ensure correct dependency order
new_tasks = self._topological_sort(tasks)
if not self.tasks:
# If there are no existing tasks, set the new tasks as the current tasks
self.tasks = new_tasks
else:
# Find the length of the common prefix between existing and new tasks
prefix_length = 0
for old_task, new_task in zip(self.tasks, new_tasks):
if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction:
break
prefix_length += 1
# Combine the common prefix with the remainder of the new tasks
final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:]
self.tasks = final_tasks
# Update current_task_id to the first unfinished task in the merged list
for task in self.tasks:
if not task.is_finished:
self.current_task_id = task.task_id
break
# Update the task map for quick access to tasks by ID
self.task_map = {task.task_id: task for task in self.tasks}
@property
def current_task(self) -> Task:
"""Find current task to execute
Returns:
Task: the current task to be executed
"""
return self.task_map.get(self.current_task_id, None)
def finish_current_task(self):
"""Finish current task, set Task.is_finished=True, set current task to next task
"""
if self.current_task_id:
current_task = self.current_task
current_task.is_finished = True
next_task_index = self.tasks.index(current_task) + 1
self.current_task_id = self.tasks[next_task_index].task_id if next_task_index < len(self.tasks) else None
def get_finished_tasks(self) -> list[Task]:
"""return all finished tasks in correct linearized order
Returns:
list[Task]: list of finished tasks
"""
return [task for task in self.tasks if task.is_finished]
if __name__ == '__main__':
test_content = 'test_message'
msgs = [

View file

@ -45,3 +45,8 @@ semantic-kernel==0.3.13.dev0
wrapt==1.15.0
websocket-client==0.58.0
zhipuai==1.0.7
rich==13.6.0
nbclient==0.9.0
nbformat==5.9.2
ipython==8.17.2
ipykernel==6.27.0

View file

@ -0,0 +1,57 @@
import pytest
from metagpt.actions import ExecutePyCode
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_code_running():
pi = ExecutePyCode()
output = await pi.run("print('hello world!')")
assert output.state == "done"
output = await pi.run({"code": "print('hello world!')", "language": "python"})
assert output.state == "done"
code_msg = Message("print('hello world!')")
output = await pi.run(code_msg)
assert output.state == "done"
@pytest.mark.asyncio
async def test_split_code_running():
pi = ExecutePyCode()
output = await pi.run("x=1\ny=2")
output = await pi.run("z=x+y")
output = await pi.run("assert z==3")
assert output.state == "done"
@pytest.mark.asyncio
async def test_execute_error():
pi = ExecutePyCode()
output = await pi.run("z=1/0")
assert output.state == "error"
@pytest.mark.asyncio
async def test_plotting_code():
pi = ExecutePyCode()
code = """
import numpy as np
import matplotlib.pyplot as plt
# 生成随机数据
random_data = np.random.randn(1000) # 生成1000个符合标准正态分布的随机数
# 绘制直方图
plt.hist(random_data, bins=30, density=True, alpha=0.7, color='blue', edgecolor='black')
# 添加标题和标签
plt.title('Histogram of Random Data')
plt.xlabel('Value')
plt.ylabel('Frequency')
# 显示图形
plt.show()
"""
output = await pi.run(code)
assert output.state == "done"

View file

@ -0,0 +1,41 @@
import pytest
from metagpt.actions.write_analysis_code import WriteCodeByGenerate
from metagpt.actions.execute_code import ExecutePyCode
@pytest.mark.asyncio
async def test_write_code():
write_code = WriteCodeByGenerate()
code = await write_code.run("Write a hello world code.")
assert "language" in code.content
assert "code" in code.content
print(code)
@pytest.mark.asyncio
async def test_write_code_by_list_prompt():
write_code = WriteCodeByGenerate()
msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
code = await write_code.run(msg)
assert "language" in code.content
assert "code" in code.content
print(code)
@pytest.mark.asyncio
async def test_write_code_by_list_plan():
write_code = WriteCodeByGenerate()
execute_code = ExecutePyCode()
messages = []
plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "求均值"]
for task in plan:
print(f"\n任务: {task}\n\n")
messages.append(task)
code = await write_code.run(messages)
messages.append(code)
assert "language" in code.content
assert "code" in code.content
output = await execute_code.run(code)
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
messages.append(output)

View file

@ -0,0 +1,13 @@
import pytest
from metagpt.actions.write_plan import WritePlan
@pytest.mark.asyncio
async def test_plan():
p = WritePlan()
task_desc = """Heres some background information on Cyclistic, a bike-sharing company designing a marketing strategy aimed at converting casual riders into annual members: So far, Cyclistics marketing strategy has relied on building general awareness and engaging a wide range of consumers. group. One way to help achieve these goals is the flexibility of its pricing plans: one-way passes, full-day passes, and annual memberships. Customers who purchase a one-way or full-day pass are known as recreational riders. Customers purchasing an annual membership are Cyclistic members. I will provide you with a data sheet that records user behavior: '/Users/vicis/Downloads/202103-divvy-tripdata.csv"""
rsp = await p.run(task_desc, role="data analyst")
assert len(rsp.content) > 0
assert rsp.sent_from == "WritePlan"
print(rsp)

View file

@ -6,6 +6,7 @@
@File : test_schema.py
"""
from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
from metagpt.schema import Task, Plan
def test_messages():
@ -19,3 +20,87 @@ def test_messages():
text = str(msgs)
roles = ['user', 'system', 'assistant', 'QA']
assert all([i in text for i in roles])
class TestPlan:
def test_add_tasks_ordering(self):
plan = Plan(goal="")
tasks = [
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
Task(task_id="2", instruction="First"),
Task(task_id="3", dependent_task_ids=["2"], instruction="Second")
] # 2 -> 3 -> 1
plan.add_tasks(tasks)
assert [task.task_id for task in plan.tasks] == ["2", "3", "1"]
def test_add_tasks_to_existing_no_common_prefix(self):
plan = Plan(goal="")
tasks = [
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
Task(task_id="2", instruction="First"),
Task(task_id="3", dependent_task_ids=["2"], instruction="Second", is_finished=True)
] # 2 -> 3 -> 1
plan.add_tasks(tasks)
new_tasks = [Task(task_id="3", instruction="")]
plan.add_tasks(new_tasks)
assert [task.task_id for task in plan.tasks] == ["3"]
assert not plan.tasks[0].is_finished # must be the new unfinished task
def test_add_tasks_to_existing_with_common_prefix(self):
plan = Plan(goal="")
tasks = [
Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"),
Task(task_id="2", instruction="First"),
Task(task_id="3", dependent_task_ids=["2"], instruction="Second")
] # 2 -> 3 -> 1
plan.add_tasks(tasks)
plan.finish_current_task() # finish 2
plan.finish_current_task() # finish 3
new_tasks = [
Task(task_id="4", dependent_task_ids=["3"], instruction="Third"),
Task(task_id="2", instruction="First"),
Task(task_id="3", dependent_task_ids=["2"], instruction="Second")
] # 2 -> 3 -> 4, so the common prefix is 2 -> 3, and these two should be obtained from the existing tasks
plan.add_tasks(new_tasks)
assert [task.task_id for task in plan.tasks] == ["2", "3", "4"]
assert plan.tasks[0].is_finished and plan.tasks[1].is_finished # "2" and "3" should be the original finished one
assert plan.current_task_id == "4"
def test_current_task(self):
plan = Plan(goal="")
tasks = [
Task(task_id="1", dependent_task_ids=["2"], instruction="Second"),
Task(task_id="2", instruction="First")
]
plan.add_tasks(tasks)
assert plan.current_task.task_id == "2"
def test_finish_task(self):
plan = Plan(goal="")
tasks = [
Task(task_id="1", instruction="First"),
Task(task_id="2", dependent_task_ids=["1"], instruction="Second")
]
plan.add_tasks(tasks)
plan.finish_current_task()
assert plan.current_task.task_id == "2"
def test_finished_tasks(self):
plan = Plan(goal="")
tasks = [
Task(task_id="1", instruction="First"),
Task(task_id="2", dependent_task_ids=["1"], instruction="Second")
]
plan.add_tasks(tasks)
plan.finish_current_task()
finished_tasks = plan.get_finished_tasks()
assert len(finished_tasks) == 1
assert finished_tasks[0].task_id == "1"