diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index b004bd58e..5055ce276 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -23,6 +23,9 @@ from metagpt.actions.write_code_review import WriteCodeReview from metagpt.actions.write_prd import WritePRD from metagpt.actions.write_prd_review import WritePRDReview from metagpt.actions.write_test import WriteTest +from metagpt.actions.execute_code import ExecutePyCode +from metagpt.actions.write_analysis_code import WriteCodeByGenerate +from metagpt.actions.write_plan import WritePlan class ActionType(Enum): @@ -45,6 +48,9 @@ class ActionType(Enum): COLLECT_LINKS = CollectLinks WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize CONDUCT_RESEARCH = ConductResearch + EXECUTE_PYCODE = ExecutePyCode + WRITE_CODE_BY_GENERATE = WriteCodeByGenerate + WRITE_PLAN = WritePlan __all__ = [ diff --git a/metagpt/actions/execute_code.py b/metagpt/actions/execute_code.py new file mode 100644 index 000000000..7b16d559a --- /dev/null +++ b/metagpt/actions/execute_code.py @@ -0,0 +1,178 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/17 14:22:15 +@Author : orange-crow +@File : code_executor.py +""" +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Tuple, Union +import traceback + +import nbformat +from nbclient import NotebookClient +from nbformat.v4 import new_code_cell, new_output +from rich.console import Console +from rich.syntax import Syntax + +from metagpt.actions import Action +from metagpt.schema import Message + + +class ExecuteCode(ABC): + @abstractmethod + async def build(self): + """build code executor""" + ... + + @abstractmethod + async def run(self, code: str): + """run code""" + ... + + @abstractmethod + async def terminate(self): + """terminate executor""" + ... + + @abstractmethod + async def reset(self): + """reset executor""" + ... + + +class ExecutePyCode(ExecuteCode, Action): + """execute code, return result to llm, and display it.""" + + def __init__(self, name: str = "python_executor", context=None, llm=None): + super().__init__(name, context, llm) + self.nb = nbformat.v4.new_notebook() + self.nb_client = NotebookClient(self.nb) + self.console = Console() + self.interaction = "ipython" if self.is_ipython() else "terminal" + + async def build(self): + if self.nb_client.kc is None or not await self.nb_client.kc.is_alive(): + self.nb_client.create_kernel_manager() + self.nb_client.start_new_kernel() + self.nb_client.start_new_kernel_client() + + async def terminate(self): + """kill NotebookClient""" + await self.nb_client._async_cleanup_kernel() + + async def reset(self): + """reset NotebookClient""" + await self.terminate() + self.nb_client = NotebookClient(self.nb) + + def add_code_cell(self, code): + self.nb.cells.append(new_code_cell(source=code)) + + def _display(self, code, language: str = "python"): + if language == "python": + code = Syntax(code, "python", theme="paraiso-dark", line_numbers=True) + self.console.print("\n") + self.console.print(code) + + def add_output_to_cell(self, cell, output): + if "outputs" not in cell: + cell["outputs"] = [] + # TODO: show figures + else: + cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output))) + + def parse_outputs(self, outputs: List) -> str: + assert isinstance(outputs, list) + parsed_output = "" + + # empty outputs: such as 'x=1\ny=2' + if not outputs: + return parsed_output + + for output in outputs: + if output["output_type"] == "stream": + parsed_output += output["text"] + elif output["output_type"] == "display_data": + self.show_bytes_figure(output["data"]["image/png"], self.interaction) + elif output["output_type"] == "execute_result": + parsed_output += output["data"]["text/plain"] + return parsed_output + + def show_bytes_figure(self, image_base64: str, interaction_type: str = "ipython"): + import base64 + + image_bytes = base64.b64decode(image_base64) + if interaction_type == "ipython": + from IPython.display import Image, display + + display(Image(data=image_bytes)) + else: + import io + + from PIL import Image + + image = Image.open(io.BytesIO(image_bytes)) + image.show() + + def is_ipython(self) -> bool: + try: + # 如果在Jupyter Notebook中运行,__file__ 变量不存在 + from IPython import get_ipython + + if get_ipython() is not None and "IPKernelApp" in get_ipython().config: + return True + else: + return False + except NameError: + # 如果在Python脚本中运行,__file__ 变量存在 + return False + + def _process_code(self, code: Union[str, Dict, Message], language: str = None) -> Tuple: + language = language or 'python' + if isinstance(code, str) and Path(code).suffix in (".py", ".txt"): + code = Path(code).read_text(encoding="utf-8") + return code, language + + if isinstance(code, str): + return code, language + + if isinstance(code, dict): + assert "code" in code + if "language" not in code: + code['language'] = 'python' + code, language = code["code"], code["language"] + elif isinstance(code, Message): + if isinstance(code.content, dict) and "language" not in code.content: + code.content["language"] = 'python' + code, language = code.content["code"], code.content["language"] + elif isinstance(code.content, str): + code, language = code.content, language + else: + raise ValueError(f"Not support code type {type(code).__name__}.") + + return code, language + + async def run(self, code: Union[str, Dict, Message], language: str = "python") -> Tuple[str, bool]: + code, language = self._process_code(code, language) + + self._display(code, language) + + if language == "python": + # add code to the notebook + self.add_code_cell(code=code) + try: + # build code executor + await self.build() + # run code + # TODO: add max_tries for run code. + cell_index = len(self.nb.cells) - 1 + await self.nb_client.async_execute_cell(self.nb.cells[-1], cell_index) + return self.parse_outputs(self.nb.cells[-1].outputs), True + except Exception as e: + # FIXME: CellExecutionError is hard to read. for example `1\0` raise ZeroDivisionError: + # CellExecutionError('An error occurred while executing the following cell:\n------------------\nz=1/0\n------------------\n\n\n\x1b[0;31m---------------------------------------------------------------------------\x1b[0m\n\x1b[0;31mZeroDivisionError\x1b[0m Traceback (most recent call last)\nCell \x1b[0;32mIn[1], line 1\x1b[0m\n\x1b[0;32m----> 1\x1b[0m z\x1b[38;5;241m=\x1b[39m\x1b[38;5;241;43m1\x1b[39;49m\x1b[38;5;241;43m/\x1b[39;49m\x1b[38;5;241;43m0\x1b[39;49m\n\n\x1b[0;31mZeroDivisionError\x1b[0m: division by zero\n') + return traceback.format_exc(), False + else: + # TODO: markdown + raise NotImplementedError(f"Not support this code type : {language}, Only support code!") diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py new file mode 100644 index 000000000..84922ada4 --- /dev/null +++ b/metagpt/actions/write_analysis_code.py @@ -0,0 +1,71 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/20 13:19:39 +@Author : orange-crow +@File : write_code_v2.py +""" +from typing import Dict, List, Union + +from metagpt.actions import Action +from metagpt.schema import Message, Plan + +class BaseWriteAnalysisCode(Action): + + async def run(self, context: List[Message], plan: Plan = None, task_guide: str = "") -> str: + """Run of a code writing action, used in data analysis or modeling + + Args: + context (List[Message]): Action output history, source action denoted by Message.cause_by + plan (Plan, optional): Overall plan. Defaults to None. + task_guide (str, optional): suggested step breakdown for the current task. Defaults to "". + + Returns: + str: The code string. + """ + +class WriteCodeByGenerate(BaseWriteAnalysisCode): + """Write code fully by generation""" + + def __init__(self, name: str = "", context=None, llm=None) -> str: + super().__init__(name, context, llm) + + def process_msg(self, prompt: Union[str, List[Dict], Message, List[Message]], system_msg: str = None): + if isinstance(prompt, str): + return system_msg + prompt if system_msg else prompt + + if isinstance(prompt, Message): + if isinstance(prompt.content, dict): + prompt.content = system_msg + str([(k, v) for k, v in prompt.content.items()])\ + if system_msg else prompt.content + else: + prompt.content = system_msg + prompt.content if system_msg else prompt.content + return prompt + + if isinstance(prompt, list): + _prompt = [] + for msg in prompt: + if isinstance(msg, Message) and isinstance(msg.content, dict): + msg.content = str([(k, v) for k, v in msg.content.items()]) + if isinstance(msg, Message): + msg = msg.to_dict() + _prompt.append(msg) + prompt = _prompt + + if isinstance(prompt, list) and system_msg: + if system_msg not in prompt[0]['content']: + prompt[0]['content'] = system_msg + prompt[0]['content'] + return prompt + + async def run( + self, context: [List[Message]], plan: Plan = None, task_guide: str = "", system_msg: str = None, **kwargs + ) -> str: + prompt = self.process_msg(context, system_msg) + code_content = await self.llm.aask_code(prompt, **kwargs) + return code_content["code"] + + +class WriteCodeWithTools(BaseWriteAnalysisCode): + """Write code with help of local available tools. Choose tools first, then generate code to use the tools""" + + async def run(self, context: List[Message], plan: Plan = None, task_guide: str = "") -> str: + return "print('abc')" diff --git a/metagpt/actions/write_plan.py b/metagpt/actions/write_plan.py new file mode 100644 index 000000000..e35ba7a92 --- /dev/null +++ b/metagpt/actions/write_plan.py @@ -0,0 +1,48 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/20 11:24:03 +@Author : orange-crow +@File : plan.py +""" +from typing import List +import json + +from metagpt.actions import Action +from metagpt.schema import Message, Task +from metagpt.utils.common import CodeParser + +class WritePlan(Action): + PROMPT_TEMPLATE = """ + # Context: + __context__ + # Current Plan: + __current_plan__ + # Task: + Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks. + If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. + Output a list of jsons following the format: + ```json + [ + { + "task_id": str = "unique identifier for a task in plan, can be an ordinal", + "dependent_task_ids": list[str] = "ids of tasks prerequisite to this task", + "instruction": "what you should do in this task, one short phrase or sentence", + }, + ... + ] + ``` + """ + async def run(self, context: List[Message], current_plan: str = "", max_tasks: int = 5) -> str: + prompt = ( + self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context])) + .replace("__current_plan__", current_plan).replace("__max_tasks__", str(max_tasks)) + ) + rsp = await self._aask(prompt) + rsp = CodeParser.parse_code(block=None, text=rsp) + return rsp + + @staticmethod + def rsp_to_tasks(rsp: str) -> List[Task]: + rsp = json.loads(rsp) + tasks = [Task(**task_config) for task_config in rsp] + return tasks diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py new file mode 100644 index 000000000..480f6cecf --- /dev/null +++ b/metagpt/roles/ml_engineer.py @@ -0,0 +1,136 @@ +from typing import Dict, List, Union +import json +import subprocess + +import fire + +from metagpt.roles import Role +from metagpt.actions import Action +from metagpt.schema import Message, Task, Plan +from metagpt.logs import logger +from metagpt.actions.write_plan import WritePlan +from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools +from metagpt.actions.execute_code import ExecutePyCode + +class AskReview(Action): + + async def run(self, context: List[Message], plan: Plan = None): + logger.info("Current overall plan:") + logger.info("\n".join([f"{task.task_id}: {task.instruction}" for task in plan.tasks])) + + logger.info("most recent context:") + # prompt = "\n".join( + # [f"{msg.cause_by.__name__ if msg.cause_by else 'Main Requirement'}: {msg.content}" for msg in context] + # ) + prompt = "" + latest_action = context[-1].cause_by.__name__ + prompt += f"\nPlease review output from {latest_action}:\n" \ + "If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \ + "If you confirm the output and wish to continue with the current process, type CONFIRM:\n" + rsp = input(prompt) + confirmed = "confirm" in rsp.lower() + + return rsp, confirmed + +class WriteTaskGuide(Action): + + async def run(self, task_instruction: str, data_desc: str = "") -> str: + return "" + +class MLEngineer(Role): + def __init__(self, name="ABC", profile="MLEngineer", goal=""): + super().__init__(name=name, profile=profile, goal=goal) + self._set_react_mode(react_mode="plan_and_act") + self.plan = Plan(goal=goal) + self.use_tools = False + self.use_task_guide = False + + async def _plan_and_act(self): + + # create initial plan and update until confirmation + await self._update_plan() + + while self.plan.current_task: + task = self.plan.current_task + logger.info(f"ready to take on task {task}") + + # take on current task + code, result, success = await self._write_and_exec_code() + + # ask for acceptance, users can other refuse and change tasks in the plan + task_result_confirmed = await self._ask_review() + + if success and task_result_confirmed: + # tick off this task and record progress + task.code = code + task.result = result + self.plan.finish_current_task() + + else: + # update plan according to user's feedback and to take on changed tasks + await self._update_plan() + + async def _write_and_exec_code(self, max_retry: int = 3): + + task_guide = await WriteTaskGuide().run(self.plan.current_task.instruction) if self.use_task_guide else "" + + counter = 0 + success = False + while not success and counter < max_retry: + context = self.get_useful_memories() + + if not self.use_tools: + # code = "print('abc')" + code = await WriteCodeByGenerate().run(context=context, plan=self.plan, task_guide=task_guide) + cause_by = WriteCodeByGenerate + + else: + code = await WriteCodeWithTools().run(context=context, plan=self.plan, task_guide=task_guide) + cause_by = WriteCodeWithTools + + self._rc.memory.add(Message(content=code, role="assistant", cause_by=cause_by)) + + result, success = await ExecutePyCode().run(code) + print(result) + self._rc.memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) + + # if not success: + # await self._ask_review() + + counter += 1 + + return code, result, success + + async def _ask_review(self): + context = self.get_useful_memories() + review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan) + self._rc.memory.add(Message(content=review, role="user", cause_by=AskReview)) + return confirmed + + async def _update_plan(self, max_tasks: int = 3): + current_plan = str([task.json() for task in self.plan.tasks]) + plan_confirmed = False + while not plan_confirmed: + context = self.get_useful_memories() + rsp = await WritePlan().run(context, current_plan=current_plan, max_tasks=max_tasks) + self._rc.memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan)) + plan_confirmed = await self._ask_review() + + tasks = WritePlan.rsp_to_tasks(rsp) + self.plan.add_tasks(tasks) + + def get_useful_memories(self, current_task_memories: List[str] = []) -> List[Message]: + """find useful memories only to reduce context length and improve performance""" + memories = super().get_memories() + return memories + + +if __name__ == "__main__": + # requirement = "create a normal distribution and visualize it" + requirement = "run some analysis on iris dataset" + + async def main(requirement: str = requirement): + role = MLEngineer(goal=requirement) + await role.run(requirement) + + fire.Fire(main) diff --git a/metagpt/schema.py b/metagpt/schema.py index bdca093c2..e39f54a0c 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -30,6 +30,7 @@ class Message: sent_from: str = field(default="") send_to: str = field(default="") restricted_to: str = field(default="") + state: str = None # None, done, todo, doing, error def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -72,6 +73,116 @@ class AIMessage(Message): super().__init__(content, 'assistant') +class Task(BaseModel): + task_id: str = "" + dependent_task_ids: list[str] = [] # Tasks prerequisite to this Task + instruction: str = "" + task_type: str = "" + code: str = "" + result: str = "" + is_finished: bool = False + + +class Plan(BaseModel): + goal: str + tasks: list[Task] = [] + task_map: dict[str, Task] = {} + current_task_id = "" + + def _topological_sort(self, tasks: list[Task]): + task_map = {task.task_id: task for task in tasks} + dependencies = {task.task_id: set(task.dependent_task_ids) for task in tasks} + sorted_tasks = [] + visited = set() + + def visit(task_id): + if task_id in visited: + return + visited.add(task_id) + for dependent_id in dependencies.get(task_id, []): + visit(dependent_id) + sorted_tasks.append(task_map[task_id]) + + for task in tasks: + visit(task.task_id) + + return sorted_tasks + + def add_tasks(self, tasks: list[Task]): + """ + Integrates new tasks into the existing plan, ensuring dependency order is maintained. + + This method performs two primary functions based on the current state of the task list: + 1. If there are no existing tasks, it topologically sorts the provided tasks to ensure + correct execution order based on dependencies, and sets these as the current tasks. + 2. If there are existing tasks, it merges the new tasks with the existing ones. It maintains + any common prefix of tasks (based on task_id and instruction) and appends the remainder + of the new tasks. The current task is updated to the first unfinished task in this merged list. + + Args: + tasks (list[Task]): A list of tasks (may be unordered) to add to the plan. + + Returns: + None: The method updates the internal state of the plan but does not return anything. + """ + if not tasks: + return + + # Topologically sort the new tasks to ensure correct dependency order + new_tasks = self._topological_sort(tasks) + + if not self.tasks: + # If there are no existing tasks, set the new tasks as the current tasks + self.tasks = new_tasks + + else: + # Find the length of the common prefix between existing and new tasks + prefix_length = 0 + for old_task, new_task in zip(self.tasks, new_tasks): + if old_task.task_id != new_task.task_id or old_task.instruction != new_task.instruction: + break + prefix_length += 1 + + # Combine the common prefix with the remainder of the new tasks + final_tasks = self.tasks[:prefix_length] + new_tasks[prefix_length:] + self.tasks = final_tasks + + # Update current_task_id to the first unfinished task in the merged list + for task in self.tasks: + if not task.is_finished: + self.current_task_id = task.task_id + break + + # Update the task map for quick access to tasks by ID + self.task_map = {task.task_id: task for task in self.tasks} + + @property + def current_task(self) -> Task: + """Find current task to execute + + Returns: + Task: the current task to be executed + """ + return self.task_map.get(self.current_task_id, None) + + def finish_current_task(self): + """Finish current task, set Task.is_finished=True, set current task to next task + """ + if self.current_task_id: + current_task = self.current_task + current_task.is_finished = True + next_task_index = self.tasks.index(current_task) + 1 + self.current_task_id = self.tasks[next_task_index].task_id if next_task_index < len(self.tasks) else None + + def get_finished_tasks(self) -> list[Task]: + """return all finished tasks in correct linearized order + + Returns: + list[Task]: list of finished tasks + """ + return [task for task in self.tasks if task.is_finished] + + if __name__ == '__main__': test_content = 'test_message' msgs = [ diff --git a/requirements.txt b/requirements.txt index f0169d7fa..c0f466457 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,3 +45,8 @@ semantic-kernel==0.3.13.dev0 wrapt==1.15.0 websocket-client==0.58.0 zhipuai==1.0.7 +rich==13.6.0 +nbclient==0.9.0 +nbformat==5.9.2 +ipython==8.17.2 +ipykernel==6.27.0 \ No newline at end of file diff --git a/tests/metagpt/actions/test_execute_code.py b/tests/metagpt/actions/test_execute_code.py new file mode 100644 index 000000000..88c5adf18 --- /dev/null +++ b/tests/metagpt/actions/test_execute_code.py @@ -0,0 +1,57 @@ +import pytest + +from metagpt.actions import ExecutePyCode +from metagpt.schema import Message + + +@pytest.mark.asyncio +async def test_code_running(): + pi = ExecutePyCode() + output = await pi.run("print('hello world!')") + assert output.state == "done" + output = await pi.run({"code": "print('hello world!')", "language": "python"}) + assert output.state == "done" + code_msg = Message("print('hello world!')") + output = await pi.run(code_msg) + assert output.state == "done" + + +@pytest.mark.asyncio +async def test_split_code_running(): + pi = ExecutePyCode() + output = await pi.run("x=1\ny=2") + output = await pi.run("z=x+y") + output = await pi.run("assert z==3") + assert output.state == "done" + + +@pytest.mark.asyncio +async def test_execute_error(): + pi = ExecutePyCode() + output = await pi.run("z=1/0") + assert output.state == "error" + + +@pytest.mark.asyncio +async def test_plotting_code(): + pi = ExecutePyCode() + code = """ + import numpy as np + import matplotlib.pyplot as plt + + # 生成随机数据 + random_data = np.random.randn(1000) # 生成1000个符合标准正态分布的随机数 + + # 绘制直方图 + plt.hist(random_data, bins=30, density=True, alpha=0.7, color='blue', edgecolor='black') + + # 添加标题和标签 + plt.title('Histogram of Random Data') + plt.xlabel('Value') + plt.ylabel('Frequency') + + # 显示图形 + plt.show() + """ + output = await pi.run(code) + assert output.state == "done" diff --git a/tests/metagpt/actions/test_write_analysis_code.py b/tests/metagpt/actions/test_write_analysis_code.py new file mode 100644 index 000000000..41c0479a9 --- /dev/null +++ b/tests/metagpt/actions/test_write_analysis_code.py @@ -0,0 +1,41 @@ +import pytest + +from metagpt.actions.write_analysis_code import WriteCodeByGenerate +from metagpt.actions.execute_code import ExecutePyCode + + +@pytest.mark.asyncio +async def test_write_code(): + write_code = WriteCodeByGenerate() + code = await write_code.run("Write a hello world code.") + assert "language" in code.content + assert "code" in code.content + print(code) + + +@pytest.mark.asyncio +async def test_write_code_by_list_prompt(): + write_code = WriteCodeByGenerate() + msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] + code = await write_code.run(msg) + assert "language" in code.content + assert "code" in code.content + print(code) + + +@pytest.mark.asyncio +async def test_write_code_by_list_plan(): + write_code = WriteCodeByGenerate() + execute_code = ExecutePyCode() + messages = [] + plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "求均值"] + for task in plan: + print(f"\n任务: {task}\n\n") + messages.append(task) + code = await write_code.run(messages) + messages.append(code) + assert "language" in code.content + assert "code" in code.content + output = await execute_code.run(code) + print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n") + messages.append(output) diff --git a/tests/metagpt/actions/test_write_plan.py b/tests/metagpt/actions/test_write_plan.py new file mode 100644 index 000000000..2bf200ab3 --- /dev/null +++ b/tests/metagpt/actions/test_write_plan.py @@ -0,0 +1,13 @@ +import pytest + +from metagpt.actions.write_plan import WritePlan + + +@pytest.mark.asyncio +async def test_plan(): + p = WritePlan() + task_desc = """Here’s some background information on Cyclistic, a bike-sharing company designing a marketing strategy aimed at converting casual riders into annual members: So far, Cyclistic’s marketing strategy has relied on building general awareness and engaging a wide range of consumers. group. One way to help achieve these goals is the flexibility of its pricing plans: one-way passes, full-day passes, and annual memberships. Customers who purchase a one-way or full-day pass are known as recreational riders. Customers purchasing an annual membership are Cyclistic members. I will provide you with a data sheet that records user behavior: '/Users/vicis/Downloads/202103-divvy-tripdata.csv""" + rsp = await p.run(task_desc, role="data analyst") + assert len(rsp.content) > 0 + assert rsp.sent_from == "WritePlan" + print(rsp) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 12666e0d3..8f65d3785 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -6,6 +6,7 @@ @File : test_schema.py """ from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage +from metagpt.schema import Task, Plan def test_messages(): @@ -19,3 +20,87 @@ def test_messages(): text = str(msgs) roles = ['user', 'system', 'assistant', 'QA'] assert all([i in text for i in roles]) + + +class TestPlan: + def test_add_tasks_ordering(self): + plan = Plan(goal="") + + tasks = [ + Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"), + Task(task_id="2", instruction="First"), + Task(task_id="3", dependent_task_ids=["2"], instruction="Second") + ] # 2 -> 3 -> 1 + plan.add_tasks(tasks) + + assert [task.task_id for task in plan.tasks] == ["2", "3", "1"] + + def test_add_tasks_to_existing_no_common_prefix(self): + plan = Plan(goal="") + + tasks = [ + Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"), + Task(task_id="2", instruction="First"), + Task(task_id="3", dependent_task_ids=["2"], instruction="Second", is_finished=True) + ] # 2 -> 3 -> 1 + plan.add_tasks(tasks) + + new_tasks = [Task(task_id="3", instruction="")] + plan.add_tasks(new_tasks) + + assert [task.task_id for task in plan.tasks] == ["3"] + assert not plan.tasks[0].is_finished # must be the new unfinished task + + def test_add_tasks_to_existing_with_common_prefix(self): + plan = Plan(goal="") + + tasks = [ + Task(task_id="1", dependent_task_ids=["2", "3"], instruction="Third"), + Task(task_id="2", instruction="First"), + Task(task_id="3", dependent_task_ids=["2"], instruction="Second") + ] # 2 -> 3 -> 1 + plan.add_tasks(tasks) + plan.finish_current_task() # finish 2 + plan.finish_current_task() # finish 3 + + new_tasks = [ + Task(task_id="4", dependent_task_ids=["3"], instruction="Third"), + Task(task_id="2", instruction="First"), + Task(task_id="3", dependent_task_ids=["2"], instruction="Second") + ] # 2 -> 3 -> 4, so the common prefix is 2 -> 3, and these two should be obtained from the existing tasks + plan.add_tasks(new_tasks) + + assert [task.task_id for task in plan.tasks] == ["2", "3", "4"] + assert plan.tasks[0].is_finished and plan.tasks[1].is_finished # "2" and "3" should be the original finished one + assert plan.current_task_id == "4" + + def test_current_task(self): + plan = Plan(goal="") + tasks = [ + Task(task_id="1", dependent_task_ids=["2"], instruction="Second"), + Task(task_id="2", instruction="First") + ] + plan.add_tasks(tasks) + assert plan.current_task.task_id == "2" + + def test_finish_task(self): + plan = Plan(goal="") + tasks = [ + Task(task_id="1", instruction="First"), + Task(task_id="2", dependent_task_ids=["1"], instruction="Second") + ] + plan.add_tasks(tasks) + plan.finish_current_task() + assert plan.current_task.task_id == "2" + + def test_finished_tasks(self): + plan = Plan(goal="") + tasks = [ + Task(task_id="1", instruction="First"), + Task(task_id="2", dependent_task_ids=["1"], instruction="Second") + ] + plan.add_tasks(tasks) + plan.finish_current_task() + finished_tasks = plan.get_finished_tasks() + assert len(finished_tasks) == 1 + assert finished_tasks[0].task_id == "1"