diff --git a/tests/metagpt/tools/test_tool_recommend.py b/tests/metagpt/tools/test_tool_recommend.py index 2fb3f9348..fafe0a638 100644 --- a/tests/metagpt/tools/test_tool_recommend.py +++ b/tests/metagpt/tools/test_tool_recommend.py @@ -2,7 +2,11 @@ import pytest from metagpt.schema import Plan, Task from metagpt.tools import TOOL_REGISTRY -from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender +from metagpt.tools.tool_recommend import ( + BM25ToolRecommender, + ToolRecommender, + TypeMatchToolRecommender, +) @pytest.fixture @@ -11,7 +15,7 @@ def mock_plan(mocker): "1": Task( task_id="1", instruction="conduct feature engineering, add new features on the dataset", - task_type="feature_engineering", + task_type="feature engineering", ) } plan = Plan( @@ -76,3 +80,11 @@ async def test_bm25_recommend_tools(mock_bm25_tr): async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr): result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan) assert isinstance(result, str) + + +@pytest.mark.asyncio +async def test_tm_tr_recall_with_plan(mock_plan, mock_bm25_tr): + tr = TypeMatchToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"]) + result = await tr.recall_tools(plan=mock_plan) + assert len(result) == 1 + assert result[0].name == "PolynomialExpansion"