fix unit tests for tool module

This commit is contained in:
yzlin 2024-03-11 16:18:28 +08:00
parent bff3ef02bc
commit b5af9ccde6
7 changed files with 91 additions and 141 deletions

View file

@ -1,44 +1,8 @@
from typing import Literal, Union
import pandas as pd
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
def test_docstring_to_schema():
    """Check that docstring_to_schema parses an Args/Returns docstring into the expected dict.

    The expected schema carries a description, per-parameter properties
    (type, description, and where present default/enum), the list of
    required parameters, and the documented return entries.
    """
    # The literal below is the parser input; its text is kept exactly as-is.
    doc_text = """
Some test desc.
Args:
features (list): Columns to be processed.
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
Defaults to None.
Returns:
pd.DataFrame: The transformed DataFrame.
"""

    # Optional parameters: both expose a "default"; only `strategy` has an "enum".
    strategy_schema = {
        "type": "str",
        "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
        "default": "'mean'",
        "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
    }
    fill_value_schema = {
        "type": "int",
        "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
        "default": "None",
    }

    expected = {
        "description": "Some test desc.",
        "parameters": {
            "properties": {
                "features": {"type": "list", "description": "Columns to be processed."},
                "strategy": strategy_schema,
                "fill_value": fill_value_schema,
            },
            # `features` is the only parameter not marked optional in the docstring.
            "required": ["features"],
        },
        "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
    }

    assert docstring_to_schema(doc_text) == expected
from metagpt.tools.tool_convert import convert_code_to_tool_schema
class DummyClass:
@ -81,12 +45,26 @@ class DummyClass:
pass
def dummy_fn(df: pd.DataFrame) -> dict:
# def dummy_fn(df: pd.DataFrame, s: str, k: int = 5, type: Literal["a", "b", "c"] = "a") -> dict:
def dummy_fn(
df: pd.DataFrame,
s: str,
k: int = 5,
type: Literal["a", "b", "c"] = "a",
test_dict: dict[str, int] = None,
test_union: Union[str, list[str]] = "",
) -> dict:
"""
Analyzes a DataFrame and categorizes its columns based on data types.
Args:
df (pd.DataFrame): The DataFrame to be analyzed.
df: The DataFrame to be analyzed.
Another line for df.
s (str): Some test string param.
Another line for s.
k (int, optional): Some test integer param. Defaults to 5.
type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'.
more_args: will be omitted here for testing
Returns:
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
@ -115,41 +93,21 @@ def test_convert_code_to_tool_schema_class():
"methods": {
"__init__": {
"type": "function",
"description": "Initialize self.",
"parameters": {
"properties": {
"features": {"type": "list", "description": "Columns to be processed."},
"strategy": {
"type": "str",
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
"default": "'mean'",
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
},
"fill_value": {
"type": "int",
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
"default": "None",
},
},
"required": ["features"],
},
"description": "Initialize self. ",
"signature": "(self, features: list, strategy: str = 'mean', fill_value=None)",
"parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
},
"fit": {
"type": "function",
"description": "Fit the FillMissingValue model.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
"description": "Fit the FillMissingValue model. ",
"signature": "(self, df: pandas.core.frame.DataFrame)",
"parameters": "Args: df (pd.DataFrame): The input DataFrame.",
},
"transform": {
"type": "function",
"description": "Transform the input DataFrame with the fitted model.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
"description": "Transform the input DataFrame with the fitted model. ",
"signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame",
"parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.",
},
},
}
@ -160,11 +118,9 @@ def test_convert_code_to_tool_schema_class():
def test_convert_code_to_tool_schema_function():
expected = {
"type": "function",
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
"required": ["df"],
},
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
"signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict",
"parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.",
}
schema = convert_code_to_tool_schema(dummy_fn)
assert schema == expected

View file

@ -23,9 +23,15 @@ def mock_plan(mocker):
return plan
@pytest.fixture
def mock_bm25_tr(mocker):
    """Provide a BM25ToolRecommender built from a fixed list of three tool entries.

    NOTE(review): "web scraping" appears to be a tool tag rather than a tool
    name (see the comment in test_tr_init) -- confirm against the registry.
    """
    return BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
def test_tr_init():
tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web_scraping", "non-existing tool"])
# web_scraping is a tool type, it has one tool scrape_web_playwright
tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"])
# "web scraping" is a tool tag; it has one tool, scrape_web_playwright
assert list(tr.tools.keys()) == [
"FillMissingValue",
"PolynomialExpansion",
@ -39,28 +45,34 @@ def test_tr_init_default_tools_value():
def test_tr_init_tools_all():
tr = ToolRecommender(tools="<all>")
tr = ToolRecommender(tools=["<all>"])
assert list(tr.tools.keys()) == list(TOOL_REGISTRY.get_all_tools().keys())
@pytest.mark.asyncio
async def test_tr_recall_with_plan(mock_plan):
tr = ToolRecommender(
tools=[
"FillMissingValue",
"PolynomialExpansion",
"web_scraping",
]
)
result = await tr.recall_tools(plan=mock_plan)
assert len(result) == 1
async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr):
    """Recalling tools from a plan yields three tools, with PolynomialExpansion first."""
    recalled = await mock_bm25_tr.recall_tools(plan=mock_plan)

    # Check the count before indexing so an empty recall fails on the count.
    assert len(recalled) == 3
    first_tool = recalled[0]
    assert first_tool.name == "PolynomialExpansion"
@pytest.mark.asyncio
async def test_bm25_tr_recall(mock_plan):
tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web_scraping"])
result = await tr.recall_tools(plan=mock_plan)
# print(result)
async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr):
    """Recalling tools from a free-text context (plan=None) yields three tools.

    PolynomialExpansion is expected to be the first recalled tool.
    """
    query = "conduct feature engineering, add new features on the dataset"
    recalled = await mock_bm25_tr.recall_tools(context=query, plan=None)

    # Check the count before indexing so an empty recall fails on the count.
    assert len(recalled) == 3
    first_tool = recalled[0]
    assert first_tool.name == "PolynomialExpansion"
@pytest.mark.asyncio
async def test_bm25_recommend_tools(mock_bm25_tr):
    """Recommending tools for a feature-engineering context yields two tools.

    PolynomialExpansion is expected to be the first recommended tool.
    """
    query = "conduct feature engineering, add new features on the dataset"
    recommended = await mock_bm25_tr.recommend_tools(context=query)

    # web scraping tool should be filtered out at rank stage
    assert len(recommended) == 2
    first_tool = recommended[0]
    assert first_tool.name == "PolynomialExpansion"
@pytest.mark.asyncio
async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr):
    """The recommended-tool info produced for a plan must be a string."""
    tool_info = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan)
    assert isinstance(tool_info, str)

View file

@ -1,7 +1,6 @@
import pytest
from metagpt.tools.tool_registry import ToolRegistry
from metagpt.tools.tool_type import ToolType
@pytest.fixture
@ -9,25 +8,11 @@ def tool_registry():
return ToolRegistry()
@pytest.fixture
def tool_registry_full():
return ToolRegistry(tool_types=ToolType)
# Test Initialization
def test_initialization(tool_registry):
assert isinstance(tool_registry, ToolRegistry)
assert tool_registry.tools == {}
assert tool_registry.tool_types == {}
assert tool_registry.tools_by_types == {}
# Test Initialization with tool types
def test_initialize_with_tool_types(tool_registry_full):
assert isinstance(tool_registry_full, ToolRegistry)
assert tool_registry_full.tools == {}
assert tool_registry_full.tools_by_types == {}
assert "data_preprocess" in tool_registry_full.tool_types
assert tool_registry.tools_by_tags == {}
class TestClassTool:
@ -72,31 +57,24 @@ def test_get_tool(tool_registry):
assert "description" in tool.schemas
# Similar tests for has_tool_tag and get_tools_by_tag
def test_has_tool_type(tool_registry_full):
assert tool_registry_full.has_tool_type("data_preprocess")
assert not tool_registry_full.has_tool_type("NonexistentType")
def test_has_tool_tag(tool_registry):
    """A tag given at registration is reported by has_tool_tag; an unused tag is not."""
    registered_tags = ["machine learning", "test"]
    tool_registry.register_tool(
        "TestClassTool",
        "/path/to/tool",
        tool_source_object=TestClassTool,
        tags=registered_tags,
    )

    assert tool_registry.has_tool_tag("test")
    assert not tool_registry.has_tool_tag("Non-existent tag")
def test_get_tool_type(tool_registry_full):
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
assert retrieved_type is not None
assert retrieved_type.name == "data_preprocess"
def test_get_tools_by_type(tool_registry):
tool_type_name = "TestType"
def test_get_tools_by_tag(tool_registry):
tool_tag_name = "Test Tag"
tool_name = "TestTool"
tool_path = "/path/to/tool"
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool)
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
assert tools_by_type is not None
assert tool_name in tools_by_type
tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name)
assert tools_by_tag is not None
assert tool_name in tools_by_tag
# Test case for when the tool type does not exist
def test_get_tools_by_nonexistent_type(tool_registry):
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
assert not tools_by_type
tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag")
assert not tools_by_tag_non_existent