mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 20:03:28 +02:00
fix unit tests for tool module
This commit is contained in:
parent
bff3ef02bc
commit
b5af9ccde6
7 changed files with 91 additions and 141 deletions
|
|
@ -1,44 +1,8 @@
|
|||
from typing import Literal, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
|
||||
|
||||
|
||||
def test_docstring_to_schema():
|
||||
docstring = """
|
||||
Some test desc.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
|
||||
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
expected = {
|
||||
"description": "Some test desc.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
}
|
||||
schema = docstring_to_schema(docstring)
|
||||
assert schema == expected
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
|
||||
|
||||
class DummyClass:
|
||||
|
|
@ -81,12 +45,26 @@ class DummyClass:
|
|||
pass
|
||||
|
||||
|
||||
def dummy_fn(df: pd.DataFrame) -> dict:
|
||||
# def dummy_fn(df: pd.DataFrame, s: str, k: int = 5, type: Literal["a", "b", "c"] = "a") -> dict:
|
||||
def dummy_fn(
|
||||
df: pd.DataFrame,
|
||||
s: str,
|
||||
k: int = 5,
|
||||
type: Literal["a", "b", "c"] = "a",
|
||||
test_dict: dict[str, int] = None,
|
||||
test_union: Union[str, list[str]] = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Analyzes a DataFrame and categorizes its columns based on data types.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The DataFrame to be analyzed.
|
||||
df: The DataFrame to be analyzed.
|
||||
Another line for df.
|
||||
s (str): Some test string param.
|
||||
Another line for s.
|
||||
k (int, optional): Some test integer param. Defaults to 5.
|
||||
type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'.
|
||||
more_args: will be omitted here for testing
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
|
||||
|
|
@ -115,41 +93,21 @@ def test_convert_code_to_tool_schema_class():
|
|||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "Initialize self.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"description": "Initialize self. ",
|
||||
"signature": "(self, features: list, strategy: str = 'mean', fill_value=None)",
|
||||
"parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
},
|
||||
"fit": {
|
||||
"type": "function",
|
||||
"description": "Fit the FillMissingValue model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Fit the FillMissingValue model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame)",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame.",
|
||||
},
|
||||
"transform": {
|
||||
"type": "function",
|
||||
"description": "Transform the input DataFrame with the fitted model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
"description": "Transform the input DataFrame with the fitted model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -160,11 +118,9 @@ def test_convert_code_to_tool_schema_class():
|
|||
def test_convert_code_to_tool_schema_function():
|
||||
expected = {
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
|
||||
"signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict",
|
||||
"parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.",
|
||||
}
|
||||
schema = convert_code_to_tool_schema(dummy_fn)
|
||||
assert schema == expected
|
||||
|
|
|
|||
|
|
@ -23,9 +23,15 @@ def mock_plan(mocker):
|
|||
return plan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bm25_tr(mocker):
|
||||
tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
|
||||
return tr
|
||||
|
||||
|
||||
def test_tr_init():
|
||||
tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web_scraping", "non-existing tool"])
|
||||
# web_scraping is a tool type, it has one tool scrape_web_playwright
|
||||
tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"])
|
||||
# web_scraping is a tool tag, it has one tool scrape_web_playwright
|
||||
assert list(tr.tools.keys()) == [
|
||||
"FillMissingValue",
|
||||
"PolynomialExpansion",
|
||||
|
|
@ -39,28 +45,34 @@ def test_tr_init_default_tools_value():
|
|||
|
||||
|
||||
def test_tr_init_tools_all():
|
||||
tr = ToolRecommender(tools="<all>")
|
||||
tr = ToolRecommender(tools=["<all>"])
|
||||
assert list(tr.tools.keys()) == list(TOOL_REGISTRY.get_all_tools().keys())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tr_recall_with_plan(mock_plan):
|
||||
tr = ToolRecommender(
|
||||
tools=[
|
||||
"FillMissingValue",
|
||||
"PolynomialExpansion",
|
||||
"web_scraping",
|
||||
]
|
||||
)
|
||||
result = await tr.recall_tools(plan=mock_plan)
|
||||
assert len(result) == 1
|
||||
async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(plan=mock_plan)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_tr_recall(mock_plan):
|
||||
tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web_scraping"])
|
||||
result = await tr.recall_tools(plan=mock_plan)
|
||||
# print(result)
|
||||
async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(
|
||||
context="conduct feature engineering, add new features on the dataset", plan=None
|
||||
)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_recommend_tools(mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recommend_tools(context="conduct feature engineering, add new features on the dataset")
|
||||
assert len(result) == 2 # web scraping tool should be filtered out at rank stage
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan)
|
||||
assert isinstance(result, str)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.tool_registry import ToolRegistry
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -9,25 +8,11 @@ def tool_registry():
|
|||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry_full():
|
||||
return ToolRegistry(tool_types=ToolType)
|
||||
|
||||
|
||||
# Test Initialization
|
||||
def test_initialization(tool_registry):
|
||||
assert isinstance(tool_registry, ToolRegistry)
|
||||
assert tool_registry.tools == {}
|
||||
assert tool_registry.tool_types == {}
|
||||
assert tool_registry.tools_by_types == {}
|
||||
|
||||
|
||||
# Test Initialization with tool types
|
||||
def test_initialize_with_tool_types(tool_registry_full):
|
||||
assert isinstance(tool_registry_full, ToolRegistry)
|
||||
assert tool_registry_full.tools == {}
|
||||
assert tool_registry_full.tools_by_types == {}
|
||||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
assert tool_registry.tools_by_tags == {}
|
||||
|
||||
|
||||
class TestClassTool:
|
||||
|
|
@ -72,31 +57,24 @@ def test_get_tool(tool_registry):
|
|||
assert "description" in tool.schemas
|
||||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
def test_has_tool_type(tool_registry_full):
|
||||
assert tool_registry_full.has_tool_type("data_preprocess")
|
||||
assert not tool_registry_full.has_tool_type("NonexistentType")
|
||||
def test_has_tool_tag(tool_registry):
|
||||
tool_registry.register_tool(
|
||||
"TestClassTool", "/path/to/tool", tool_source_object=TestClassTool, tags=["machine learning", "test"]
|
||||
)
|
||||
assert tool_registry.has_tool_tag("test")
|
||||
assert not tool_registry.has_tool_tag("Non-existent tag")
|
||||
|
||||
|
||||
def test_get_tool_type(tool_registry_full):
|
||||
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
|
||||
assert retrieved_type is not None
|
||||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry):
|
||||
tool_type_name = "TestType"
|
||||
def test_get_tools_by_tag(tool_registry):
|
||||
tool_tag_name = "Test Tag"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
|
||||
tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
assert tool_name in tools_by_type
|
||||
tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name)
|
||||
assert tools_by_tag is not None
|
||||
assert tool_name in tools_by_tag
|
||||
|
||||
|
||||
# Test case for when the tool type does not exist
|
||||
def test_get_tools_by_nonexistent_type(tool_registry):
|
||||
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
|
||||
assert not tools_by_type
|
||||
tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag")
|
||||
assert not tools_by_tag_non_existent
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue