From 3461b1b4c02e3891935c854448857d3c3d888a3c Mon Sep 17 00:00:00 2001 From: yzlin Date: Thu, 30 Nov 2023 14:40:51 +0800 Subject: [PATCH] add unit tests for reuse code --- metagpt/actions/write_analysis_code.py | 5 +- .../actions/test_write_analysis_code.py | 166 ++++++++++++++++-- 2 files changed, 150 insertions(+), 21 deletions(-) diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index 7e282b5a2..51cfa6d49 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -25,14 +25,13 @@ class BaseWriteAnalysisCode(Action): class WriteCodeByGenerate(BaseWriteAnalysisCode): """Write code fully by generation""" - DEFAULT_SYSTEM_MSG = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: Use !pip install to install missing packages.**""" - REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous steps in your current code block, include new codes only, DONT repeat codes!""" + DEFAULT_SYSTEM_MSG = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: Use !pip install in a standalone block to install missing packages.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt + REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous tasks in your current code block, include new codes only, DONT repeat codes!""" def __init__(self, name: str = "", context=None, llm=None) -> str: super().__init__(name, context, llm) def process_msg(self, prompt: Union[str, List[Dict], Message, List[Message]], system_msg: str = None): - # Reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt default_system_msg = system_msg or self.DEFAULT_SYSTEM_MSG # 全部转成list if not isinstance(prompt, list): diff --git a/tests/metagpt/actions/test_write_analysis_code.py b/tests/metagpt/actions/test_write_analysis_code.py index cde5fa7ad..80d9438af 100644 --- a/tests/metagpt/actions/test_write_analysis_code.py +++ b/tests/metagpt/actions/test_write_analysis_code.py @@ -1,26 +1,10 @@ +import asyncio import pytest from metagpt.actions.write_analysis_code import WriteCodeByGenerate from metagpt.actions.execute_code import ExecutePyCode from metagpt.schema import Message - - -# @pytest.mark.asyncio -# async def test_write_code(): -# write_code = WriteCodeFunction() -# code = await write_code.run("Write a hello world code.") -# assert len(code) > 0 -# print(code) - - -# @pytest.mark.asyncio -# async def test_write_code_by_list_prompt(): -# write_code = WriteCodeFunction() -# msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] -# code = await write_code.run(msg) -# assert len(code) > 0 -# print(code) - +from metagpt.logs import logger @pytest.mark.asyncio async def test_write_code_by_list_plan(): @@ -37,3 +21,149 @@ async def test_write_code_by_list_plan(): output = await execute_code.run(code) print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n") messages.append(output[0]) + +@pytest.mark.asyncio +async def test_write_code_to_correct_error(): + + structural_context = """ + ## User Requirement + read a dataset test.csv and print its head + ## Current Plan + [ + { + "task_id": "1", + "dependent_task_ids": [], + "instruction": "import pandas and load the dataset from 'test.csv'.", + "task_type": "", + "code": "", + "result": "", + "is_finished": false + }, + { + "task_id": "2", + "dependent_task_ids": [ + "1" + ], + "instruction": "Print the head of the dataset to display the first few rows.", + "task_type": "", + "code": "", + "result": "", + "is_finished": false + } + ] + ## Current Task + {"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false} + """ + wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv + error = """ + Traceback (most recent call last): + File "", line 2, in + File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel + io = ExcelFile(io, storage_options=storage_options, engine=engine) + File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__ + raise ValueError( + ValueError: Excel file format cannot be determined, you must specify an engine manually. + """ + context = [ + Message(content=structural_context, role="user"), + Message(content=wrong_code, role="assistant"), + Message(content=error, role="user"), + ] + new_code = await WriteCodeByGenerate().run(context=context) + print(new_code) + assert "read_csv" in new_code # should correct read_excel to read_csv + +@pytest.mark.asyncio +async def test_write_code_reuse_code_simple(): + structural_context = """ + ## User Requirement + read a dataset test.csv and print its head + ## Current Plan + [ + { + "task_id": "1", + "dependent_task_ids": [], + "instruction": "import pandas and load the dataset from 'test.csv'.", + "task_type": "", + "code": "import pandas as pd\ndata = pd.read_csv('test.csv')", + "result": "", + "is_finished": true + }, + { + "task_id": "2", + "dependent_task_ids": [ + "1" + ], + "instruction": "Print the head of the dataset to display the first few rows.", + "task_type": "", + "code": "", + "result": "", + "is_finished": false + } + ] + ## Current Task + {"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false} + """ + context = [ + Message(content=structural_context, role="user"), + ] + code = await WriteCodeByGenerate().run(context=context) + print(code) + assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one + +@pytest.mark.asyncio +async def test_write_code_reuse_code_long(): + """test code reuse for long context""" + + structural_context = """ + ## User Requirement + Run data analysis on sklearn Iris dataset, include a plot + ## Current Plan + [ + { + "task_id": "1", + "dependent_task_ids": [], + "instruction": "Load the Iris dataset from sklearn.", + "task_type": "", + "code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]", + "result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))", + "is_finished": true + }, + { + "task_id": "2", + "dependent_task_ids": [ + "1" + ], + "instruction": "Perform exploratory data analysis on the Iris dataset.", + "task_type": "", + "code": "", + "result": "", + "is_finished": false + }, + { + "task_id": "3", + "dependent_task_ids": [ + "2" + ], + "instruction": "Create a plot visualizing the Iris dataset features.", + "task_type": "", + "code": "", + "result": "", + "is_finished": false + } + ] + ## Current Task + {"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false} + """ + context = [ + Message(content=structural_context, role="user"), + ] + trials_num = 5 + trials = [WriteCodeByGenerate().run(context=context) for _ in range(trials_num)] + trial_results = await asyncio.gather(*trials) + print(*trial_results, sep="\n\n***\n\n") + success = ["load_iris" not in result and "iris_data" in result \ + for result in trial_results] # should reuse iris_data from previous tasks + success_rate = sum(success) / trials_num + logger.info(f"success rate: {success_rate :.2f}") + assert success_rate >= 0.8