diff --git a/metagpt/tools/functions/__init__.py b/metagpt/tools/functions/__init__.py index b81e85833..7ab850667 100644 --- a/metagpt/tools/functions/__init__.py +++ b/metagpt/tools/functions/__init__.py @@ -6,3 +6,4 @@ # @Desc : from metagpt.tools.functions.register.register import registry import metagpt.tools.functions.libs.feature_engineering +print(registry.functions) \ No newline at end of file diff --git a/metagpt/tools/functions/schemas/base.py b/metagpt/tools/functions/schemas/base.py index 35b9f77b7..aef604c8d 100644 --- a/metagpt/tools/functions/schemas/base.py +++ b/metagpt/tools/functions/schemas/base.py @@ -16,7 +16,7 @@ class NoDefault: pass -def field( +def tool_field( description: str, default: Any = NoDefault(), enum: Optional[List[Any]] = None, **kwargs ): """ diff --git a/metagpt/tools/functions/schemas/feature_engineering.py b/metagpt/tools/functions/schemas/feature_engineering.py index 8237c83f4..c14bb933e 100644 --- a/metagpt/tools/functions/schemas/feature_engineering.py +++ b/metagpt/tools/functions/schemas/feature_engineering.py @@ -8,37 +8,37 @@ from typing import List import pandas as pd -from metagpt.tools.functions.schemas.base import field, ToolSchema +from metagpt.tools.functions.schemas.base import ToolSchema, tool_field class PolynomialExpansion(ToolSchema): """Generate polynomial and interaction features from selected columns, excluding the bias column.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Columns for polynomial expansion.") - degree: int = field(description="Degree of polynomial features.", default=2) + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Columns for polynomial expansion.") + degree: int = tool_field(description="Degree of polynomial features.", default=2) class OneHotEncoding(ToolSchema): """Apply one-hot encoding to specified categorical columns in a DataFrame.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Categorical columns to be one-hot encoded.") + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Categorical columns to be one-hot encoded.") class FrequencyEncoding(ToolSchema): """Convert categorical columns to frequency encoding.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Categorical columns to be frequency encoded.") + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Categorical columns to be frequency encoded.") class CatCross(ToolSchema): """Create pairwise crossed features from categorical columns, joining values with '_'.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Columns to be pairwise crossed.") - max_cat_num: int = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Columns to be pairwise crossed.") + max_cat_num: int = tool_field( description="Maximum unique categories per crossed feature.", default=100 ) @@ -46,10 +46,10 @@ class CatCross(ToolSchema): class GroupStat(ToolSchema): """Perform aggregation operations on a specified column grouped by certain categories.""" - df: pd.DataFrame = field(description="DataFrame to process.") - group_col: str = field(description="Column used for grouping.") - agg_col: str = field(description="Column on which aggregation is performed.") - agg_funcs: list = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + group_col: str = tool_field(description="Column used for grouping.") + agg_col: str = tool_field(description="Column on which aggregation is performed.") + agg_funcs: list = tool_field( description="""List of aggregation functions to apply, such as ['mean', 'std']. Each function must be supported by pandas.""" ) @@ -58,9 +58,11 @@ class GroupStat(ToolSchema): class ExtractTimeComps(ToolSchema): """Extract specific time components from a designated time column in a DataFrame.""" - df: pd.DataFrame = field(description="DataFrame to process.") - time_col: str = field(description="The name of the column containing time data.") - time_comps: List[str] = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + time_col: str = tool_field( + description="The name of the column containing time data." + ) + time_comps: List[str] = tool_field( description="""List of time components to extract. Each component must be in ['year', 'month', 'day', 'hour', 'dayofweek', 'is_weekend'].""" ) @@ -69,12 +71,12 @@ class ExtractTimeComps(ToolSchema): class FeShiftByTime(ToolSchema): """Shift column values in a DataFrame based on specified time intervals.""" - df: pd.DataFrame = field(description="DataFrame to process.") - time_col: str = field(description="Column for time-based shifting.") - group_col: str = field(description="Column for grouping before shifting.") - shift_col: str = field(description="Column to shift.") - periods: list = field(description="Time intervals for shifting.") - freq: str = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + time_col: str = tool_field(description="Column for time-based shifting.") + group_col: str = tool_field(description="Column for grouping before shifting.") + shift_col: str = tool_field(description="Column to shift.") + periods: list = tool_field(description="Time intervals for shifting.") + freq: str = tool_field( description="Frequency unit for time intervals (e.g., 'D', 'M').", enum=["D", "M", "Y", "W", "H"], ) @@ -83,16 +85,16 @@ class FeShiftByTime(ToolSchema): class FeRollingByTime(ToolSchema): """Calculate rolling statistics for a DataFrame column over time intervals.""" - df: pd.DataFrame = field(description="DataFrame to process.") - time_col: str = field(description="Column for time-based rolling.") - group_col: str = field(description="Column for grouping before rolling.") - rolling_col: str = field(description="Column for rolling calculations.") - periods: list = field(description="Window sizes for rolling.") - freq: str = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + time_col: str = tool_field(description="Column for time-based rolling.") + group_col: str = tool_field(description="Column for grouping before rolling.") + rolling_col: str = tool_field(description="Column for rolling calculations.") + periods: list = tool_field(description="Window sizes for rolling.") + freq: str = tool_field( description="Frequency unit for time windows (e.g., 'D', 'M').", enum=["D", "M", "Y", "W", "H"], ) - agg_funcs: list = field( + agg_funcs: list = tool_field( description="""List of aggregation functions for rolling, like ['mean', 'std']. Each function must be in ['mean', 'std', 'min', 'max', 'median', 'sum', 'count'].""" ) diff --git a/tests/metagpt/tools/functions/register/test_register.py b/tests/metagpt/tools/functions/register/test_register.py index a71f7d01c..8c9821268 100644 --- a/tests/metagpt/tools/functions/register/test_register.py +++ b/tests/metagpt/tools/functions/register/test_register.py @@ -7,7 +7,7 @@ import pytest from metagpt.tools.functions.register.register import FunctionRegistry -from metagpt.tools.functions.schemas.base import ToolSchema, field +from metagpt.tools.functions.schemas.base import ToolSchema, tool_field @pytest.fixture @@ -18,8 +18,8 @@ def registry(): class AddNumbers(ToolSchema): """Add two numbers""" - num1: int = field(description="First number") - num2: int = field(description="Second number") + num1: int = tool_field(description="First number") + num2: int = tool_field(description="Second number") def test_register(registry):