fix ml_engineer test

This commit is contained in:
yzlin 2024-01-31 18:09:01 +08:00
parent b585064edc
commit a1b16b7e99
3 changed files with 84 additions and 3 deletions

View file

@ -1,7 +1,11 @@
import pytest
from metagpt.actions.execute_code import ExecutePyCode
from metagpt.logs import logger
from metagpt.roles.ml_engineer import MLEngineer
from metagpt.schema import Message, Plan, Task
from metagpt.tools.tool_data_type import ToolTypeEnum
from tests.metagpt.actions.test_debug_code import CODE, DebugContext, ErrorStr
def test_mle_init():
@ -9,13 +13,80 @@ def test_mle_init():
assert ci.tools == []
MockPlan = Plan(
goal="This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: 'tests/data/ml_datasets/titanic/split_train.csv', eval data path: 'tests/data/ml_datasets/titanic/split_eval.csv'.",
context="",
tasks=[
Task(
task_id="1",
dependent_task_ids=[],
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
task_type="eda",
code_steps="",
code="",
result="",
is_success=False,
is_finished=False,
)
],
task_map={
"1": Task(
task_id="1",
dependent_task_ids=[],
instruction="Perform exploratory data analysis on the train dataset to understand the features and target variable.",
task_type="eda",
code_steps="",
code="",
result="",
is_success=False,
is_finished=False,
)
},
current_task_id="1",
)
@pytest.mark.asyncio
async def test_mle_write_code(mocker):
data_path = "tests/data/ml_datasets/titanic"
mle = MLEngineer(auto_run=True, use_tools=True)
mle.planner.plan = MockPlan
code, _ = await mle._write_code()
assert data_path in code["code"]
@pytest.mark.asyncio
async def test_mle_update_data_columns(mocker):
mle = MLEngineer(auto_run=True, use_tools=True)
mle.planner.plan = MockPlan
# manually update task type to test update
mle.planner.plan.current_task.task_type = ToolTypeEnum.DATA_PREPROCESS.value
result = await mle._update_data_columns()
assert result is not None
@pytest.mark.asyncio
async def test_mle_debug_code(mocker):
mle = MLEngineer(auto_run=True, use_tools=True)
mle.working_memory.add(Message(content=ErrorStr, cause_by=ExecutePyCode))
mle.latest_code = CODE
mle.debug_context = DebugContext
code, _ = await mle._write_code()
assert len(code) > 0
@pytest.mark.skip
@pytest.mark.asyncio
async def test_ml_engineer():
data_path = "tests/data/ml_datasets/titanic"
requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
tools = ["FillMissingValue", "CatCross", "dummy_tool"]
mle = MLEngineer(goal=requirement, auto_run=True, use_tools=True, tools=tools)
mle = MLEngineer(auto_run=True, use_tools=True, tools=tools)
rsp = await mle.run(requirement)
logger.info(rsp)
assert len(rsp.content) > 0