From b5af9ccde6e0e87af40eca3b45fcb8569d588633 Mon Sep 17 00:00:00 2001 From: yzlin Date: Mon, 11 Mar 2024 16:18:28 +0800 Subject: [PATCH] fix unit tests for tool module --- metagpt/actions/mi/write_plan.py | 4 +- metagpt/roles/mi/interpreter.py | 15 ++- metagpt/tools/tool_recommend.py | 7 +- metagpt/tools/tool_registry.py | 2 +- tests/metagpt/tools/test_tool_convert.py | 106 ++++++--------------- tests/metagpt/tools/test_tool_recommend.py | 46 +++++---- tests/metagpt/tools/test_tool_registry.py | 52 +++------- 7 files changed, 91 insertions(+), 141 deletions(-) diff --git a/metagpt/actions/mi/write_plan.py b/metagpt/actions/mi/write_plan.py index 1839de0f9..b190733fc 100644 --- a/metagpt/actions/mi/write_plan.py +++ b/metagpt/actions/mi/write_plan.py @@ -42,9 +42,7 @@ class WritePlan(Action): """ async def run(self, context: list[Message], max_tasks: int = 5, use_tools: bool = False) -> str: - task_type_desc = "\n".join( - [f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType] - ) # task type are binded with tool type now, should be improved in the future + task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType]) prompt = self.PROMPT_TEMPLATE.format( context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc ) diff --git a/metagpt/roles/mi/interpreter.py b/metagpt/roles/mi/interpreter.py index 58b38ac43..e71514b62 100644 --- a/metagpt/roles/mi/interpreter.py +++ b/metagpt/roles/mi/interpreter.py @@ -156,11 +156,16 @@ class Interpreter(Role): return code, todo async def _check_data(self): - if not self.use_plan or self.planner.plan.current_task.task_type not in [ - TaskType.DATA_PREPROCESS.type_name, - TaskType.FEATURE_ENGINEERING.type_name, - TaskType.MODEL_TRAIN.type_name, - ]: + if ( + not self.use_plan + or not self.planner.plan.get_finished_tasks() + or self.planner.plan.current_task.task_type + not in [ + TaskType.DATA_PREPROCESS.type_name, + TaskType.FEATURE_ENGINEERING.type_name, + TaskType.MODEL_TRAIN.type_name, + ] + ): return logger.info("Check updated data") code = await CheckData().run(self.planner.plan) diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py index 9e06a67b4..fcdbc4254 100644 --- a/metagpt/tools/tool_recommend.py +++ b/metagpt/tools/tool_recommend.py @@ -174,9 +174,10 @@ class BM25ToolRecommender(ToolRecommender): doc_scores = self.bm25.get_scores(query_tokens) top_indexes = np.argsort(doc_scores)[::-1][:topk] recalled_tools = [list(self.tools.values())[index] for index in top_indexes] - print([doc_scores[index] for index in top_indexes]) - print([recalled_tools[i].name for i in range(len(recalled_tools))]) - print([recalled_tools[i].schemas["description"] for i in range(len(recalled_tools))]) + + logger.info( + f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}" + ) return recalled_tools diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index 24c286c26..11269cb0f 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -134,7 +134,7 @@ def validate_tool_names(tools: Union[list[str], str]) -> str: # one can define either tool names or tool type names, take union to get the whole set if TOOL_REGISTRY.has_tool(key): valid_tools.update({key: TOOL_REGISTRY.get_tool(key)}) - elif TOOL_REGISTRY.tool_tool_tag(key): + elif TOOL_REGISTRY.has_tool_tag(key): valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key)) else: logger.warning(f"invalid tool name or tool type name: {key}, skipped") diff --git a/tests/metagpt/tools/test_tool_convert.py b/tests/metagpt/tools/test_tool_convert.py index 8f26a211c..f85b84b71 100644 --- a/tests/metagpt/tools/test_tool_convert.py +++ b/tests/metagpt/tools/test_tool_convert.py @@ -1,44 +1,8 @@ +from typing import Literal, Union + import pandas as pd -from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema - - -def test_docstring_to_schema(): - docstring = """ - Some test desc. - - Args: - features (list): Columns to be processed. - strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be - used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. - fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. - Defaults to None. - Returns: - pd.DataFrame: The transformed DataFrame. - """ - expected = { - "description": "Some test desc.", - "parameters": { - "properties": { - "features": {"type": "list", "description": "Columns to be processed."}, - "strategy": { - "type": "str", - "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.", - "default": "'mean'", - "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"], - }, - "fill_value": { - "type": "int", - "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.", - "default": "None", - }, - }, - "required": ["features"], - }, - "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}], - } - schema = docstring_to_schema(docstring) - assert schema == expected +from metagpt.tools.tool_convert import convert_code_to_tool_schema class DummyClass: @@ -81,12 +45,26 @@ class DummyClass: pass -def dummy_fn(df: pd.DataFrame) -> dict: +# def dummy_fn(df: pd.DataFrame, s: str, k: int = 5, type: Literal["a", "b", "c"] = "a") -> dict: +def dummy_fn( + df: pd.DataFrame, + s: str, + k: int = 5, + type: Literal["a", "b", "c"] = "a", + test_dict: dict[str, int] = None, + test_union: Union[str, list[str]] = "", +) -> dict: """ Analyzes a DataFrame and categorizes its columns based on data types. Args: - df (pd.DataFrame): The DataFrame to be analyzed. + df: The DataFrame to be analyzed. + Another line for df. + s (str): Some test string param. + Another line for s. + k (int, optional): Some test integer param. Defaults to 5. + type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'. + more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). @@ -115,41 +93,21 @@ def test_convert_code_to_tool_schema_class(): "methods": { "__init__": { "type": "function", - "description": "Initialize self.", - "parameters": { - "properties": { - "features": {"type": "list", "description": "Columns to be processed."}, - "strategy": { - "type": "str", - "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.", - "default": "'mean'", - "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"], - }, - "fill_value": { - "type": "int", - "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.", - "default": "None", - }, - }, - "required": ["features"], - }, + "description": "Initialize self. ", + "signature": "(self, features: list, strategy: str = 'mean', fill_value=None)", + "parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.", }, "fit": { "type": "function", - "description": "Fit the FillMissingValue model.", - "parameters": { - "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}}, - "required": ["df"], - }, + "description": "Fit the FillMissingValue model. ", + "signature": "(self, df: pandas.core.frame.DataFrame)", + "parameters": "Args: df (pd.DataFrame): The input DataFrame.", }, "transform": { "type": "function", - "description": "Transform the input DataFrame with the fitted model.", - "parameters": { - "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}}, - "required": ["df"], - }, - "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}], + "description": "Transform the input DataFrame with the fitted model. ", + "signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame", + "parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.", }, }, } @@ -160,11 +118,9 @@ def test_convert_code_to_tool_schema_class(): def test_convert_code_to_tool_schema_function(): expected = { "type": "function", - "description": "Analyzes a DataFrame and categorizes its columns based on data types.", - "parameters": { - "properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}}, - "required": ["df"], - }, + "description": "Analyzes a DataFrame and categorizes its columns based on data types. ", + "signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict", + "parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.", } schema = convert_code_to_tool_schema(dummy_fn) assert schema == expected diff --git a/tests/metagpt/tools/test_tool_recommend.py b/tests/metagpt/tools/test_tool_recommend.py index 1359d5834..2fb3f9348 100644 --- a/tests/metagpt/tools/test_tool_recommend.py +++ b/tests/metagpt/tools/test_tool_recommend.py @@ -23,9 +23,15 @@ def mock_plan(mocker): return plan +@pytest.fixture +def mock_bm25_tr(mocker): + tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"]) + return tr + + def test_tr_init(): - tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web_scraping", "non-existing tool"]) - # web_scraping is a tool type, it has one tool scrape_web_playwright + tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"]) + # web_scraping is a tool tag, it has one tool scrape_web_playwright assert list(tr.tools.keys()) == [ "FillMissingValue", "PolynomialExpansion", @@ -39,28 +45,34 @@ def test_tr_init_default_tools_value(): def test_tr_init_tools_all(): - tr = ToolRecommender(tools="") + tr = ToolRecommender(tools=[""]) assert list(tr.tools.keys()) == list(TOOL_REGISTRY.get_all_tools().keys()) @pytest.mark.asyncio -async def test_tr_recall_with_plan(mock_plan): - tr = ToolRecommender( - tools=[ - "FillMissingValue", - "PolynomialExpansion", - "web_scraping", - ] - ) - result = await tr.recall_tools(plan=mock_plan) - assert len(result) == 1 +async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr): + result = await mock_bm25_tr.recall_tools(plan=mock_plan) + assert len(result) == 3 assert result[0].name == "PolynomialExpansion" @pytest.mark.asyncio -async def test_bm25_tr_recall(mock_plan): - tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web_scraping"]) - result = await tr.recall_tools(plan=mock_plan) - # print(result) +async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr): + result = await mock_bm25_tr.recall_tools( + context="conduct feature engineering, add new features on the dataset", plan=None + ) assert len(result) == 3 assert result[0].name == "PolynomialExpansion" + + +@pytest.mark.asyncio +async def test_bm25_recommend_tools(mock_bm25_tr): + result = await mock_bm25_tr.recommend_tools(context="conduct feature engineering, add new features on the dataset") + assert len(result) == 2 # web scraping tool should be filtered out at rank stage + assert result[0].name == "PolynomialExpansion" + + +@pytest.mark.asyncio +async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr): + result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan) + assert isinstance(result, str) diff --git a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py index 2fd487fb7..f44dfea0b 100644 --- a/tests/metagpt/tools/test_tool_registry.py +++ b/tests/metagpt/tools/test_tool_registry.py @@ -1,7 +1,6 @@ import pytest from metagpt.tools.tool_registry import ToolRegistry -from metagpt.tools.tool_type import ToolType @pytest.fixture @@ -9,25 +8,11 @@ def tool_registry(): return ToolRegistry() -@pytest.fixture -def tool_registry_full(): - return ToolRegistry(tool_types=ToolType) - - # Test Initialization def test_initialization(tool_registry): assert isinstance(tool_registry, ToolRegistry) assert tool_registry.tools == {} - assert tool_registry.tool_types == {} - assert tool_registry.tools_by_types == {} - - -# Test Initialization with tool types -def test_initialize_with_tool_types(tool_registry_full): - assert isinstance(tool_registry_full, ToolRegistry) - assert tool_registry_full.tools == {} - assert tool_registry_full.tools_by_types == {} - assert "data_preprocess" in tool_registry_full.tool_types + assert tool_registry.tools_by_tags == {} class TestClassTool: @@ -72,31 +57,24 @@ def test_get_tool(tool_registry): assert "description" in tool.schemas -# Similar tests for has_tool_type, get_tool_type, get_tools_by_type -def test_has_tool_type(tool_registry_full): - assert tool_registry_full.has_tool_type("data_preprocess") - assert not tool_registry_full.has_tool_type("NonexistentType") +def test_has_tool_tag(tool_registry): + tool_registry.register_tool( + "TestClassTool", "/path/to/tool", tool_source_object=TestClassTool, tags=["machine learning", "test"] + ) + assert tool_registry.has_tool_tag("test") + assert not tool_registry.has_tool_tag("Non-existent tag") -def test_get_tool_type(tool_registry_full): - retrieved_type = tool_registry_full.get_tool_type("data_preprocess") - assert retrieved_type is not None - assert retrieved_type.name == "data_preprocess" - - -def test_get_tools_by_type(tool_registry): - tool_type_name = "TestType" +def test_get_tools_by_tag(tool_registry): + tool_tag_name = "Test Tag" tool_name = "TestTool" tool_path = "/path/to/tool" - tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool) + tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool) - tools_by_type = tool_registry.get_tools_by_type(tool_type_name) - assert tools_by_type is not None - assert tool_name in tools_by_type + tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name) + assert tools_by_tag is not None + assert tool_name in tools_by_tag - -# Test case for when the tool type does not exist -def test_get_tools_by_nonexistent_type(tool_registry): - tools_by_type = tool_registry.get_tools_by_type("NonexistentType") - assert not tools_by_type + tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag") + assert not tools_by_tag_non_existent