From a9575380b540382bebb1d7367e1dd5aa581f394a Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 12 Jan 2024 16:04:44 +0800 Subject: [PATCH] update test for write_code_with_tools --- .../actions/test_write_analysis_code.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/metagpt/actions/test_write_analysis_code.py b/tests/metagpt/actions/test_write_analysis_code.py index df1d39603..f5b22c327 100644 --- a/tests/metagpt/actions/test_write_analysis_code.py +++ b/tests/metagpt/actions/test_write_analysis_code.py @@ -3,8 +3,13 @@ import asyncio import pytest from metagpt.actions.execute_code import ExecutePyCode -from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools +from metagpt.actions.write_analysis_code import ( + WriteCodeByGenerate, + WriteCodeWithTools, + WriteCodeWithToolsML, +) from metagpt.logs import logger +from metagpt.plan.planner import STRUCTURAL_CONTEXT from metagpt.schema import Message, Plan, Task @@ -40,13 +45,15 @@ async def test_tool_recommendation(): tools = await write_code._tool_recommendation(task, code_steps, available_tools) assert len(tools) == 1 - assert tools[0] == ["fill_missing_value"] + assert tools[0] == "fill_missing_value" @pytest.mark.asyncio async def test_write_code_with_tools(): write_code = WriteCodeWithTools() - messages = [] + write_code_ml = WriteCodeWithToolsML() + + requirement = "构造数据集并进行数据清洗" task_map = { "1": Task( task_id="1", @@ -69,10 +76,6 @@ async def test_write_code_with_tools(): instruction="对数据集进行数据清洗", task_type="data_preprocess", dependent_task_ids=["1"], - code_steps=""" - {"Step 1": "对数据集进行去重", - "Step 2": "对数据集进行缺失值处理"} - """, ), } plan = Plan( @@ -83,10 +86,22 @@ async def test_write_code_with_tools(): ) column_info = "" - code = await write_code.run(messages, plan, column_info) + context = STRUCTURAL_CONTEXT.format( + user_requirement=requirement, + context=plan.context, + tasks=list(task_map.values()), + current_task=plan.current_task.json(), + ) + context_msg = [Message(content=context, role="user")] + + code = await write_code.run(context_msg, plan) assert len(code) > 0 print(code) + code_with_ml = await write_code_ml.run([], plan, column_info) + assert len(code_with_ml) > 0 + print(code_with_ml) + @pytest.mark.asyncio async def test_write_code_to_correct_error():