update test for write_code_with_tools

This commit is contained in:
lidanyang 2024-01-12 16:04:44 +08:00
parent 45e2d6c5f5
commit a9575380b5

View file

@ -3,8 +3,13 @@ import asyncio
import pytest
from metagpt.actions.execute_code import ExecutePyCode
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
from metagpt.actions.write_analysis_code import (
WriteCodeByGenerate,
WriteCodeWithTools,
WriteCodeWithToolsML,
)
from metagpt.logs import logger
from metagpt.plan.planner import STRUCTURAL_CONTEXT
from metagpt.schema import Message, Plan, Task
@ -40,13 +45,15 @@ async def test_tool_recommendation():
tools = await write_code._tool_recommendation(task, code_steps, available_tools)
assert len(tools) == 1
assert tools[0] == ["fill_missing_value"]
assert tools[0] == "fill_missing_value"
@pytest.mark.asyncio
async def test_write_code_with_tools():
write_code = WriteCodeWithTools()
messages = []
write_code_ml = WriteCodeWithToolsML()
requirement = "构造数据集并进行数据清洗"
task_map = {
"1": Task(
task_id="1",
@ -69,10 +76,6 @@ async def test_write_code_with_tools():
instruction="对数据集进行数据清洗",
task_type="data_preprocess",
dependent_task_ids=["1"],
code_steps="""
{"Step 1": "对数据集进行去重",
"Step 2": "对数据集进行缺失值处理"}
""",
),
}
plan = Plan(
@ -83,10 +86,22 @@ async def test_write_code_with_tools():
)
column_info = ""
code = await write_code.run(messages, plan, column_info)
context = STRUCTURAL_CONTEXT.format(
user_requirement=requirement,
context=plan.context,
tasks=list(task_map.values()),
current_task=plan.current_task.json(),
)
context_msg = [Message(content=context, role="user")]
code = await write_code.run(context_msg, plan)
assert len(code) > 0
print(code)
code_with_ml = await write_code_ml.run([], plan, column_info)
assert len(code_with_ml) > 0
print(code_with_ml)
@pytest.mark.asyncio
async def test_write_code_to_correct_error():