From 2a3f23ec62ebca8329c2748179d731025a685d0a Mon Sep 17 00:00:00 2001 From: lidanyang Date: Thu, 14 Dec 2023 10:32:58 +0800 Subject: [PATCH] fix unittest --- .../actions/test_write_analysis_code.py | 33 ++++++++----------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/tests/metagpt/actions/test_write_analysis_code.py b/tests/metagpt/actions/test_write_analysis_code.py index 661202115..1a568cdcd 100644 --- a/tests/metagpt/actions/test_write_analysis_code.py +++ b/tests/metagpt/actions/test_write_analysis_code.py @@ -31,22 +31,15 @@ async def test_tool_recommendation(): step 1: 对数据集进行去重 step 2: 对数据集进行缺失值处理 """ - available_tools = [ - { - "name": "fill_missing_value", - "description": "Completing missing values with simple strategies", - }, - { - "name": "split_bins", - "description": "Bin continuous data into intervals and return the bin identifier encoded as an integer value", - }, - ] + available_tools = { + "fill_missing_value": "Completing missing values with simple strategies", + "split_bins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value", + } write_code = WriteCodeWithTools() tools = await write_code._tool_recommendation(task, code_steps, available_tools) - assert len(tools) == 2 - assert tools[0] == [] - assert tools[1] == ["fill_missing_value"] + assert len(tools) == 1 + assert tools[0] == ["fill_missing_value"] @pytest.mark.asyncio @@ -57,7 +50,7 @@ async def test_write_code_with_tools(): "1": Task( task_id="1", instruction="随机生成一个pandas DataFrame数据集", - task_type="unknown", + task_type="other", dependent_task_ids=[], code=""" import pandas as pd @@ -75,6 +68,10 @@ async def test_write_code_with_tools(): instruction="对数据集进行数据清洗", task_type="data_preprocess", dependent_task_ids=["1"], + code_steps=""" + {"Step 1": "对数据集进行去重", + "Step 2": "对数据集进行缺失值处理"} + """ ), } plan = Plan( @@ -83,13 +80,9 @@ async def test_write_code_with_tools(): task_map=task_map, current_task_id="2", ) - task_guide = """ - step 1: 对数据集进行去重 - step 2: 对数据集进行缺失值处理 - """ - data_desc = "None" + column_info = "" - code = await write_code.run(messages, plan, task_guide, data_desc) + code = await write_code.run(messages, plan, column_info) assert len(code) > 0 print(code)