From f9683c0276dc0ca6d83e2f46f8f34eaf27654239 Mon Sep 17 00:00:00 2001 From: yzlin Date: Wed, 13 Mar 2024 21:17:24 +0800 Subject: [PATCH] test type match recommender --- tests/metagpt/tools/test_tool_recommend.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/metagpt/tools/test_tool_recommend.py b/tests/metagpt/tools/test_tool_recommend.py index 2fb3f9348..fafe0a638 100644 --- a/tests/metagpt/tools/test_tool_recommend.py +++ b/tests/metagpt/tools/test_tool_recommend.py @@ -2,7 +2,11 @@ import pytest from metagpt.schema import Plan, Task from metagpt.tools import TOOL_REGISTRY -from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender +from metagpt.tools.tool_recommend import ( + BM25ToolRecommender, + ToolRecommender, + TypeMatchToolRecommender, +) @pytest.fixture @@ -11,7 +15,7 @@ def mock_plan(mocker): "1": Task( task_id="1", instruction="conduct feature engineering, add new features on the dataset", - task_type="feature_engineering", + task_type="feature engineering", ) } plan = Plan( @@ -76,3 +80,11 @@ async def test_bm25_recommend_tools(mock_bm25_tr): async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr): result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan) assert isinstance(result, str) + + +@pytest.mark.asyncio +async def test_tm_tr_recall_with_plan(mock_plan, mock_bm25_tr): + tr = TypeMatchToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"]) + result = await tr.recall_tools(plan=mock_plan) + assert len(result) == 1 + assert result[0].name == "PolynomialExpansion"