Rename the `field` function to `tool_field`

This commit is contained in:
lidanyang 2023-11-24 17:46:43 +08:00
parent fdc49775e6
commit f19003b413
4 changed files with 38 additions and 35 deletions

View file

@ -6,3 +6,4 @@
# @Desc :
# Bring the shared `registry` object into this module.
from metagpt.tools.functions.register.register import registry
# Nothing below uses a name from this import — presumably it is here for an
# import-time side effect (e.g. populating `registry`); TODO confirm.
import metagpt.tools.functions.libs.feature_engineering
# Dump the registry's `functions` attribute to stdout when this module runs.
print(registry.functions)

View file

@ -16,7 +16,7 @@ class NoDefault:
pass
def field(
def tool_field(
description: str, default: Any = NoDefault(), enum: Optional[List[Any]] = None, **kwargs
):
"""

View file

@ -8,37 +8,37 @@ from typing import List
import pandas as pd
from metagpt.tools.functions.schemas.base import field, ToolSchema
from metagpt.tools.functions.schemas.base import ToolSchema, tool_field
# Diff view: every attribute is listed twice — the old `field(...)` line and
# then the `tool_field(...)` line that replaces it in this rename commit.
class PolynomialExpansion(ToolSchema):
"""Generate polynomial and interaction features from selected columns, excluding the bias column."""
# Old side of the diff (declared with `field`).
df: pd.DataFrame = field(description="DataFrame to process.")
cols: list = field(description="Columns for polynomial expansion.")
degree: int = field(description="Degree of polynomial features.", default=2)
# New side of the diff (declared with `tool_field`); the arguments are the
# same as above, including the default of 2 for `degree`.
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Columns for polynomial expansion.")
degree: int = tool_field(description="Degree of polynomial features.", default=2)
# Diff view: the `field(...)` pair is the old side, the `tool_field(...)` pair
# is its replacement in this rename commit.
class OneHotEncoding(ToolSchema):
"""Apply one-hot encoding to specified categorical columns in a DataFrame."""
# Old side of the diff (declared with `field`).
df: pd.DataFrame = field(description="DataFrame to process.")
cols: list = field(description="Categorical columns to be one-hot encoded.")
# New side of the diff (declared with `tool_field`); descriptions unchanged.
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Categorical columns to be one-hot encoded.")
# Diff view: the `field(...)` pair is the old side, the `tool_field(...)` pair
# is its replacement in this rename commit.
class FrequencyEncoding(ToolSchema):
"""Convert categorical columns to frequency encoding."""
# Old side of the diff (declared with `field`).
df: pd.DataFrame = field(description="DataFrame to process.")
cols: list = field(description="Categorical columns to be frequency encoded.")
# New side of the diff (declared with `tool_field`); descriptions unchanged.
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Categorical columns to be frequency encoded.")
# Diff view: old `field(...)` lines are followed by their `tool_field(...)`
# replacements from this rename commit.
class CatCross(ToolSchema):
"""Create pairwise crossed features from categorical columns, joining values with '_'."""
# Old side of the diff (declared with `field`). The `max_cat_num` call is
# left open here; its arguments are the shared lines at the bottom.
df: pd.DataFrame = field(description="DataFrame to process.")
cols: list = field(description="Columns to be pairwise crossed.")
max_cat_num: int = field(
# New side of the diff (declared with `tool_field`).
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Columns to be pairwise crossed.")
max_cat_num: int = tool_field(
# Unchanged context shared by both versions of the `max_cat_num` call:
# the description and a default of 100.
description="Maximum unique categories per crossed feature.", default=100
)
@ -46,10 +46,10 @@ class CatCross(ToolSchema):
# Diff view: old `field(...)` lines are followed by their `tool_field(...)`
# replacements from this rename commit.
class GroupStat(ToolSchema):
"""Perform aggregation operations on a specified column grouped by certain categories."""
# Old side of the diff (declared with `field`). The `agg_funcs` call is left
# open here; its description is the shared text at the bottom.
df: pd.DataFrame = field(description="DataFrame to process.")
group_col: str = field(description="Column used for grouping.")
agg_col: str = field(description="Column on which aggregation is performed.")
agg_funcs: list = field(
# New side of the diff (declared with `tool_field`).
df: pd.DataFrame = tool_field(description="DataFrame to process.")
group_col: str = tool_field(description="Column used for grouping.")
agg_col: str = tool_field(description="Column on which aggregation is performed.")
agg_funcs: list = tool_field(
# Unchanged context shared by both versions of the `agg_funcs` call.
description="""List of aggregation functions to apply, such as ['mean', 'std'].
Each function must be supported by pandas."""
)
@ -58,9 +58,11 @@ class GroupStat(ToolSchema):
# Diff view: old `field(...)` lines are followed by their `tool_field(...)`
# replacements from this rename commit.
class ExtractTimeComps(ToolSchema):
"""Extract specific time components from a designated time column in a DataFrame."""
# Old side of the diff (declared with `field`). The `time_comps` call is left
# open here; its description is the shared text at the bottom.
df: pd.DataFrame = field(description="DataFrame to process.")
time_col: str = field(description="The name of the column containing time data.")
time_comps: List[str] = field(
# New side of the diff (declared with `tool_field`). Besides the rename, the
# `time_col` call is now wrapped over three lines; its description text is
# the same as before.
df: pd.DataFrame = tool_field(description="DataFrame to process.")
time_col: str = tool_field(
description="The name of the column containing time data."
)
time_comps: List[str] = tool_field(
# Unchanged context shared by both versions of the `time_comps` call.
description="""List of time components to extract.
Each component must be in ['year', 'month', 'day', 'hour', 'dayofweek', 'is_weekend']."""
)
@ -69,12 +71,12 @@ class ExtractTimeComps(ToolSchema):
# Diff view: old `field(...)` lines are followed by their `tool_field(...)`
# replacements from this rename commit.
class FeShiftByTime(ToolSchema):
"""Shift column values in a DataFrame based on specified time intervals."""
# Old side of the diff (declared with `field`). The `freq` call is left open
# here; its arguments are the shared lines at the bottom.
df: pd.DataFrame = field(description="DataFrame to process.")
time_col: str = field(description="Column for time-based shifting.")
group_col: str = field(description="Column for grouping before shifting.")
shift_col: str = field(description="Column to shift.")
periods: list = field(description="Time intervals for shifting.")
freq: str = field(
# New side of the diff (declared with `tool_field`).
df: pd.DataFrame = tool_field(description="DataFrame to process.")
time_col: str = tool_field(description="Column for time-based shifting.")
group_col: str = tool_field(description="Column for grouping before shifting.")
shift_col: str = tool_field(description="Column to shift.")
periods: list = tool_field(description="Time intervals for shifting.")
freq: str = tool_field(
# Unchanged context shared by both versions of the `freq` call: the
# description plus an `enum` list of five frequency codes.
description="Frequency unit for time intervals (e.g., 'D', 'M').",
enum=["D", "M", "Y", "W", "H"],
)
@ -83,16 +85,16 @@ class FeShiftByTime(ToolSchema):
# Diff view: old `field(...)` lines are followed by their `tool_field(...)`
# replacements from this rename commit. Only the lines shown in this hunk are
# covered; the class may continue past it.
class FeRollingByTime(ToolSchema):
"""Calculate rolling statistics for a DataFrame column over time intervals."""
# Old side of the diff (declared with `field`). The `freq` call is left open
# here; its arguments are the shared lines that follow the new side.
df: pd.DataFrame = field(description="DataFrame to process.")
time_col: str = field(description="Column for time-based rolling.")
group_col: str = field(description="Column for grouping before rolling.")
rolling_col: str = field(description="Column for rolling calculations.")
periods: list = field(description="Window sizes for rolling.")
freq: str = field(
# New side of the diff (declared with `tool_field`).
df: pd.DataFrame = tool_field(description="DataFrame to process.")
time_col: str = tool_field(description="Column for time-based rolling.")
group_col: str = tool_field(description="Column for grouping before rolling.")
rolling_col: str = tool_field(description="Column for rolling calculations.")
periods: list = tool_field(description="Window sizes for rolling.")
freq: str = tool_field(
# Unchanged context shared by both versions of the `freq` call: the
# description plus an `enum` list of five frequency codes.
description="Frequency unit for time windows (e.g., 'D', 'M').",
enum=["D", "M", "Y", "W", "H"],
)
# `agg_funcs`: old opening line, then its renamed replacement; the description
# below is shared by both.
agg_funcs: list = field(
agg_funcs: list = tool_field(
description="""List of aggregation functions for rolling, like ['mean', 'std'].
Each function must be in ['mean', 'std', 'min', 'max', 'median', 'sum', 'count']."""
)

View file

@ -7,7 +7,7 @@
import pytest
from metagpt.tools.functions.register.register import FunctionRegistry
from metagpt.tools.functions.schemas.base import ToolSchema, field
from metagpt.tools.functions.schemas.base import ToolSchema, tool_field
@pytest.fixture
@ -18,8 +18,8 @@ def registry():
# Diff view of a schema used in the registry tests: the `field(...)` pair is
# the old side, the `tool_field(...)` pair its replacement in this rename commit.
class AddNumbers(ToolSchema):
"""Add two numbers"""
# Old side of the diff (declared with `field`).
num1: int = field(description="First number")
num2: int = field(description="Second number")
# New side of the diff (declared with `tool_field`); descriptions unchanged.
num1: int = tool_field(description="First number")
num2: int = tool_field(description="Second number")
def test_register(registry):