fix unittest

This commit is contained in:
lidanyang 2023-12-14 10:32:58 +08:00
parent 3a3b31badb
commit 2a3f23ec62

View file

@ -31,22 +31,15 @@ async def test_tool_recommendation():
step 1: 对数据集进行去重
step 2: 对数据集进行缺失值处理
"""
available_tools = [
{
"name": "fill_missing_value",
"description": "Completing missing values with simple strategies",
},
{
"name": "split_bins",
"description": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
},
]
available_tools = {
"fill_missing_value": "Completing missing values with simple strategies",
"split_bins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
}
write_code = WriteCodeWithTools()
tools = await write_code._tool_recommendation(task, code_steps, available_tools)
assert len(tools) == 2
assert tools[0] == []
assert tools[1] == ["fill_missing_value"]
assert len(tools) == 1
assert tools[0] == ["fill_missing_value"]
@pytest.mark.asyncio
@ -57,7 +50,7 @@ async def test_write_code_with_tools():
"1": Task(
task_id="1",
instruction="随机生成一个pandas DataFrame数据集",
task_type="unknown",
task_type="other",
dependent_task_ids=[],
code="""
import pandas as pd
@ -75,6 +68,10 @@ async def test_write_code_with_tools():
instruction="对数据集进行数据清洗",
task_type="data_preprocess",
dependent_task_ids=["1"],
code_steps="""
{"Step 1": "对数据集进行去重",
"Step 2": "对数据集进行缺失值处理"}
"""
),
}
plan = Plan(
@ -83,13 +80,9 @@ async def test_write_code_with_tools():
task_map=task_map,
current_task_id="2",
)
task_guide = """
step 1: 对数据集进行去重
step 2: 对数据集进行缺失值处理
"""
data_desc = "None"
column_info = ""
code = await write_code.run(messages, plan, task_guide, data_desc)
code = await write_code.run(messages, plan, column_info)
assert len(code) > 0
print(code)