add unit tests for reuse code

This commit is contained in:
yzlin 2023-11-30 14:40:51 +08:00
parent dc2247010e
commit 3461b1b4c0
2 changed files with 150 additions and 21 deletions

View file

@ -25,14 +25,13 @@ class BaseWriteAnalysisCode(Action):
class WriteCodeByGenerate(BaseWriteAnalysisCode):
"""Write code fully by generation"""
DEFAULT_SYSTEM_MSG = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: Use !pip install to install missing packages.**"""
REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous steps in your current code block, include new codes only, DONT repeat codes!"""
DEFAULT_SYSTEM_MSG = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: Use !pip install in a standalone block to install missing packages.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous tasks in your current code block, include new codes only, DONT repeat codes!"""
def __init__(self, name: str = "", context=None, llm=None) -> str:
super().__init__(name, context, llm)
def process_msg(self, prompt: Union[str, List[Dict], Message, List[Message]], system_msg: str = None):
# Reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
default_system_msg = system_msg or self.DEFAULT_SYSTEM_MSG
# 全部转成list
if not isinstance(prompt, list):

View file

@ -1,26 +1,10 @@
import asyncio
import pytest
from metagpt.actions.write_analysis_code import WriteCodeByGenerate
from metagpt.actions.execute_code import ExecutePyCode
from metagpt.schema import Message
# @pytest.mark.asyncio
# async def test_write_code():
# write_code = WriteCodeFunction()
# code = await write_code.run("Write a hello world code.")
# assert len(code) > 0
# print(code)
# @pytest.mark.asyncio
# async def test_write_code_by_list_prompt():
# write_code = WriteCodeFunction()
# msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
# code = await write_code.run(msg)
# assert len(code) > 0
# print(code)
from metagpt.logs import logger
@pytest.mark.asyncio
async def test_write_code_by_list_plan():
@ -37,3 +21,149 @@ async def test_write_code_by_list_plan():
output = await execute_code.run(code)
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
messages.append(output[0])
@pytest.mark.asyncio
async def test_write_code_to_correct_error():
structural_context = """
## User Requirement
read a dataset test.csv and print its head
## Current Plan
[
{
"task_id": "1",
"dependent_task_ids": [],
"instruction": "import pandas and load the dataset from 'test.csv'.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
},
{
"task_id": "2",
"dependent_task_ids": [
"1"
],
"instruction": "Print the head of the dataset to display the first few rows.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
}
]
## Current Task
{"task_id": "1", "dependent_task_ids": [], "instruction": "import pandas and load the dataset from 'test.csv'.", "task_type": "", "code": "", "result": "", "is_finished": false}
"""
wrong_code = """import pandas as pd\ndata = pd.read_excel('test.csv')\ndata""" # use read_excel to read a csv
error = """
Traceback (most recent call last):
File "<stdin>", line 2, in <module>
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 478, in read_excel
io = ExcelFile(io, storage_options=storage_options, engine=engine)
File "/Users/gary/miniconda3/envs/py39_scratch/lib/python3.9/site-packages/pandas/io/excel/_base.py", line 1500, in __init__
raise ValueError(
ValueError: Excel file format cannot be determined, you must specify an engine manually.
"""
context = [
Message(content=structural_context, role="user"),
Message(content=wrong_code, role="assistant"),
Message(content=error, role="user"),
]
new_code = await WriteCodeByGenerate().run(context=context)
print(new_code)
assert "read_csv" in new_code # should correct read_excel to read_csv
@pytest.mark.asyncio
async def test_write_code_reuse_code_simple():
structural_context = """
## User Requirement
read a dataset test.csv and print its head
## Current Plan
[
{
"task_id": "1",
"dependent_task_ids": [],
"instruction": "import pandas and load the dataset from 'test.csv'.",
"task_type": "",
"code": "import pandas as pd\ndata = pd.read_csv('test.csv')",
"result": "",
"is_finished": true
},
{
"task_id": "2",
"dependent_task_ids": [
"1"
],
"instruction": "Print the head of the dataset to display the first few rows.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
}
]
## Current Task
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Print the head of the dataset to display the first few rows.", "task_type": "", "code": "", "result": "", "is_finished": false}
"""
context = [
Message(content=structural_context, role="user"),
]
code = await WriteCodeByGenerate().run(context=context)
print(code)
assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
@pytest.mark.asyncio
async def test_write_code_reuse_code_long():
"""test code reuse for long context"""
structural_context = """
## User Requirement
Run data analysis on sklearn Iris dataset, include a plot
## Current Plan
[
{
"task_id": "1",
"dependent_task_ids": [],
"instruction": "Load the Iris dataset from sklearn.",
"task_type": "",
"code": "from sklearn.datasets import load_iris\niris_data = load_iris()\niris_data['data'][0:5], iris_data['target'][0:5]",
"result": "(array([[5.1, 3.5, 1.4, 0.2],\n [4.9, 3. , 1.4, 0.2],\n [4.7, 3.2, 1.3, 0.2],\n [4.6, 3.1, 1.5, 0.2],\n [5. , 3.6, 1.4, 0.2]]),\n array([0, 0, 0, 0, 0]))",
"is_finished": true
},
{
"task_id": "2",
"dependent_task_ids": [
"1"
],
"instruction": "Perform exploratory data analysis on the Iris dataset.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
},
{
"task_id": "3",
"dependent_task_ids": [
"2"
],
"instruction": "Create a plot visualizing the Iris dataset features.",
"task_type": "",
"code": "",
"result": "",
"is_finished": false
}
]
## Current Task
{"task_id": "2", "dependent_task_ids": ["1"], "instruction": "Perform exploratory data analysis on the Iris dataset.", "task_type": "", "code": "", "result": "", "is_finished": false}
"""
context = [
Message(content=structural_context, role="user"),
]
trials_num = 5
trials = [WriteCodeByGenerate().run(context=context) for _ in range(trials_num)]
trial_results = await asyncio.gather(*trials)
print(*trial_results, sep="\n\n***\n\n")
success = ["load_iris" not in result and "iris_data" in result \
for result in trial_results] # should reuse iris_data from previous tasks
success_rate = sum(success) / trials_num
logger.info(f"success rate: {success_rate :.2f}")
assert success_rate >= 0.8