diff --git a/tests/metagpt/actions/test_write_analysis_code.py b/tests/metagpt/actions/test_write_analysis_code.py index cde5fa7ad..2319331d4 100644 --- a/tests/metagpt/actions/test_write_analysis_code.py +++ b/tests/metagpt/actions/test_write_analysis_code.py @@ -1,8 +1,8 @@ import pytest -from metagpt.actions.write_analysis_code import WriteCodeByGenerate +from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools from metagpt.actions.execute_code import ExecutePyCode -from metagpt.schema import Message +from metagpt.schema import Message, Plan, Task # @pytest.mark.asyncio @@ -37,3 +37,73 @@ async def test_write_code_by_list_plan(): output = await execute_code.run(code) print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n") messages.append(output[0]) + + +@pytest.mark.asyncio +async def test_tool_recommendation(): + task = "对已经读取的数据集进行数据清洗" + code_steps = """ + step 1: 对数据集进行去重 + step 2: 对数据集进行缺失值处理 + """ + available_tools = [ + { + "name": "fill_missing_value", + "description": "Completing missing values with simple strategies", + }, + { + "name": "split_bins", + "description": "Bin continuous data into intervals and return the bin identifier encoded as an integer value", + }, + ] + write_code = WriteCodeWithTools() + tools = await write_code._tool_recommendation(task, code_steps, available_tools) + + assert len(tools) == 2 + assert tools[0] == [] + assert tools[1] == ["fill_missing_value"] + + +@pytest.mark.asyncio +async def test_write_code_with_tools(): + write_code = WriteCodeWithTools() + messages = [] + task_map = { + "1": Task( + task_id="1", + instruction="随机生成一个pandas DataFrame数据集", + task_type="unknown", + dependent_task_ids=[], + code=""" + import pandas as pd + df = pd.DataFrame({ + 'a': [1, 2, 3, 4, 5], + 'b': [1.1, 2.2, 3.3, 4.4, np.nan], + 'c': ['aa', 'bb', 'cc', 'dd', 'ee'], + 'd': [1, 2, 3, 4, 5] + }) + """, + is_finished=True, + ), + "2": Task( + task_id="2", + instruction="对数据集进行数据清洗", + task_type="data_preprocess", + dependent_task_ids=["1"], + ), + } + plan = Plan( + goal="构造数据集并进行数据清洗", + tasks=list(task_map.values()), + task_map=task_map, + current_task_id="2", + ) + task_guide = """ + step 1: 对数据集进行去重 + step 2: 对数据集进行缺失值处理 + """ + data_desc = "None" + + code = await write_code.run(messages, plan, task_guide, data_desc) + assert len(code) > 0 + print(code)