diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index 402f56ccc..5cea9fe51 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -77,8 +77,8 @@ class WriteCodeByGenerate(BaseWriteAnalysisCode): ) -> dict: # context.append(Message(content=self.REUSE_CODE_INSTRUCTION, role="user")) prompt = self.process_msg(context, system_msg) - code_content = await self.llm.aask_code(prompt, **kwargs) - return code_content + rsp = await self.llm.aask_code(prompt, **kwargs) + return rsp class WriteCodeWithTools(BaseWriteAnalysisCode): diff --git a/tests/metagpt/actions/test_ml_action.py b/tests/metagpt/actions/test_ml_action.py new file mode 100644 index 000000000..2c8d34da8 --- /dev/null +++ b/tests/metagpt/actions/test_ml_action.py @@ -0,0 +1,46 @@ +import pytest + +from metagpt.actions.ml_action import WriteCodeWithToolsML +from metagpt.schema import Plan, Task + + +@pytest.mark.asyncio +async def test_write_code_with_tools(): + write_code_ml = WriteCodeWithToolsML() + + task_map = { + "1": Task( + task_id="1", + instruction="随机生成一个pandas DataFrame数据集", + task_type="other", + dependent_task_ids=[], + code=""" + import pandas as pd + df = pd.DataFrame({ + 'a': [1, 2, 3, 4, 5], + 'b': [1.1, 2.2, 3.3, 4.4, np.nan], + 'c': ['aa', 'bb', 'cc', 'dd', 'ee'], + 'd': [1, 2, 3, 4, 5] + }) + """, + is_finished=True, + ), + "2": Task( + task_id="2", + instruction="对数据集进行数据清洗", + task_type="data_preprocess", + dependent_task_ids=["1"], + ), + } + plan = Plan( + goal="构造数据集并进行数据清洗", + tasks=list(task_map.values()), + task_map=task_map, + current_task_id="2", + ) + column_info = "" + + _, code_with_ml = await write_code_ml.run([], plan, column_info) + code_with_ml = code_with_ml["code"] + assert len(code_with_ml) > 0 + print(code_with_ml) diff --git a/tests/metagpt/actions/test_write_analysis_code.py b/tests/metagpt/actions/test_write_analysis_code.py index 3e20a8bfb..43f23848d 100644 --- a/tests/metagpt/actions/test_write_analysis_code.py +++ b/tests/metagpt/actions/test_write_analysis_code.py @@ -3,13 +3,13 @@ import asyncio import pytest from metagpt.actions.execute_code import ExecutePyCode -from metagpt.actions.ml_action import WriteCodeWithToolsML from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools from metagpt.logs import logger from metagpt.plan.planner import STRUCTURAL_CONTEXT from metagpt.schema import Message, Plan, Task +@pytest.mark.skip @pytest.mark.asyncio async def test_write_code_by_list_plan(): write_code = WriteCodeByGenerate() @@ -20,35 +20,31 @@ async def test_write_code_by_list_plan(): print(f"\n任务: {task}\n\n") messages.append(Message(task, role="assistant")) code = await write_code.run(messages) - messages.append(Message(code, role="assistant")) + messages.append(Message(code["code"], role="assistant")) assert len(code) > 0 - output = await execute_code.run(code) + output = await execute_code.run(code["code"]) print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n") messages.append(output[0]) @pytest.mark.asyncio async def test_tool_recommendation(): - task = "对已经读取的数据集进行数据清洗" - code_steps = """ - step 1: 对数据集进行去重 - step 2: 对数据集进行缺失值处理 - """ + task = "clean and preprocess the data" + code_steps = "" available_tools = { - "fill_missing_value": "Completing missing values with simple strategies", - "split_bins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value", + "FillMissingValue": "Filling missing values", + "SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value", } write_code = WriteCodeWithTools() - tools = await write_code._tool_recommendation(task, code_steps, available_tools) + tools = await write_code._recommend_tool(task, code_steps, available_tools) assert len(tools) == 1 - assert tools[0] == "fill_missing_value" + assert "FillMissingValue" in tools @pytest.mark.asyncio async def test_write_code_with_tools(): write_code = WriteCodeWithTools() - write_code_ml = WriteCodeWithToolsML() requirement = "构造数据集并进行数据清洗" task_map = { @@ -81,7 +77,6 @@ async def test_write_code_with_tools(): task_map=task_map, current_task_id="2", ) - column_info = "" context = STRUCTURAL_CONTEXT.format( user_requirement=requirement, @@ -92,13 +87,10 @@ async def test_write_code_with_tools(): context_msg = [Message(content=context, role="user")] code = await write_code.run(context_msg, plan) + code = code["code"] assert len(code) > 0 print(code) - code_with_ml = await write_code_ml.run([], plan, column_info) - assert len(code_with_ml) > 0 - print(code_with_ml) - @pytest.mark.asyncio async def test_write_code_to_correct_error(): @@ -147,6 +139,7 @@ async def test_write_code_to_correct_error(): Message(content=error, role="user"), ] new_code = await WriteCodeByGenerate().run(context=context) + new_code = new_code["code"] print(new_code) assert "read_csv" in new_code # should correct read_excel to read_csv @@ -186,10 +179,12 @@ async def test_write_code_reuse_code_simple(): Message(content=structural_context, role="user"), ] code = await WriteCodeByGenerate().run(context=context) + code = code["code"] print(code) assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one +@pytest.mark.skip @pytest.mark.asyncio async def test_write_code_reuse_code_long(): """test code reuse for long context""" @@ -242,13 +237,14 @@ async def test_write_code_reuse_code_long(): trial_results = await asyncio.gather(*trials) print(*trial_results, sep="\n\n***\n\n") success = [ - "load_iris" not in result and "iris_data" in result for result in trial_results + "load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results ] # should reuse iris_data from previous tasks success_rate = sum(success) / trials_num logger.info(f"success rate: {success_rate :.2f}") assert success_rate >= 0.8 +@pytest.mark.skip @pytest.mark.asyncio async def test_write_code_reuse_code_long_for_wine(): """test code reuse for long context""" @@ -315,7 +311,7 @@ async def test_write_code_reuse_code_long_for_wine(): trial_results = await asyncio.gather(*trials) print(*trial_results, sep="\n\n***\n\n") success = [ - "load_wine" not in result and "wine_data" in result for result in trial_results + "load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results ] # should reuse iris_data from previous tasks success_rate = sum(success) / trials_num logger.info(f"success rate: {success_rate :.2f}")