mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
rename field to tool_field
This commit is contained in:
parent
fdc49775e6
commit
f19003b413
4 changed files with 38 additions and 35 deletions
|
|
@ -6,3 +6,4 @@
|
|||
# @Desc :
|
||||
from metagpt.tools.functions.register.register import registry
|
||||
import metagpt.tools.functions.libs.feature_engineering
|
||||
print(registry.functions)
|
||||
|
|
@ -16,7 +16,7 @@ class NoDefault:
|
|||
pass
|
||||
|
||||
|
||||
def field(
|
||||
def tool_field(
|
||||
description: str, default: Any = NoDefault(), enum: Optional[List[Any]] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,37 +8,37 @@ from typing import List
|
|||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.functions.schemas.base import field, ToolSchema
|
||||
from metagpt.tools.functions.schemas.base import ToolSchema, tool_field
|
||||
|
||||
|
||||
class PolynomialExpansion(ToolSchema):
|
||||
"""Generate polynomial and interaction features from selected columns, excluding the bias column."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
cols: list = field(description="Columns for polynomial expansion.")
|
||||
degree: int = field(description="Degree of polynomial features.", default=2)
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
cols: list = tool_field(description="Columns for polynomial expansion.")
|
||||
degree: int = tool_field(description="Degree of polynomial features.", default=2)
|
||||
|
||||
|
||||
class OneHotEncoding(ToolSchema):
|
||||
"""Apply one-hot encoding to specified categorical columns in a DataFrame."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
cols: list = field(description="Categorical columns to be one-hot encoded.")
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
cols: list = tool_field(description="Categorical columns to be one-hot encoded.")
|
||||
|
||||
|
||||
class FrequencyEncoding(ToolSchema):
|
||||
"""Convert categorical columns to frequency encoding."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
cols: list = field(description="Categorical columns to be frequency encoded.")
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
cols: list = tool_field(description="Categorical columns to be frequency encoded.")
|
||||
|
||||
|
||||
class CatCross(ToolSchema):
|
||||
"""Create pairwise crossed features from categorical columns, joining values with '_'."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
cols: list = field(description="Columns to be pairwise crossed.")
|
||||
max_cat_num: int = field(
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
cols: list = tool_field(description="Columns to be pairwise crossed.")
|
||||
max_cat_num: int = tool_field(
|
||||
description="Maximum unique categories per crossed feature.", default=100
|
||||
)
|
||||
|
||||
|
|
@ -46,10 +46,10 @@ class CatCross(ToolSchema):
|
|||
class GroupStat(ToolSchema):
|
||||
"""Perform aggregation operations on a specified column grouped by certain categories."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
group_col: str = field(description="Column used for grouping.")
|
||||
agg_col: str = field(description="Column on which aggregation is performed.")
|
||||
agg_funcs: list = field(
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
group_col: str = tool_field(description="Column used for grouping.")
|
||||
agg_col: str = tool_field(description="Column on which aggregation is performed.")
|
||||
agg_funcs: list = tool_field(
|
||||
description="""List of aggregation functions to apply, such as ['mean', 'std'].
|
||||
Each function must be supported by pandas."""
|
||||
)
|
||||
|
|
@ -58,9 +58,11 @@ class GroupStat(ToolSchema):
|
|||
class ExtractTimeComps(ToolSchema):
|
||||
"""Extract specific time components from a designated time column in a DataFrame."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
time_col: str = field(description="The name of the column containing time data.")
|
||||
time_comps: List[str] = field(
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
time_col: str = tool_field(
|
||||
description="The name of the column containing time data."
|
||||
)
|
||||
time_comps: List[str] = tool_field(
|
||||
description="""List of time components to extract.
|
||||
Each component must be in ['year', 'month', 'day', 'hour', 'dayofweek', 'is_weekend']."""
|
||||
)
|
||||
|
|
@ -69,12 +71,12 @@ class ExtractTimeComps(ToolSchema):
|
|||
class FeShiftByTime(ToolSchema):
|
||||
"""Shift column values in a DataFrame based on specified time intervals."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
time_col: str = field(description="Column for time-based shifting.")
|
||||
group_col: str = field(description="Column for grouping before shifting.")
|
||||
shift_col: str = field(description="Column to shift.")
|
||||
periods: list = field(description="Time intervals for shifting.")
|
||||
freq: str = field(
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
time_col: str = tool_field(description="Column for time-based shifting.")
|
||||
group_col: str = tool_field(description="Column for grouping before shifting.")
|
||||
shift_col: str = tool_field(description="Column to shift.")
|
||||
periods: list = tool_field(description="Time intervals for shifting.")
|
||||
freq: str = tool_field(
|
||||
description="Frequency unit for time intervals (e.g., 'D', 'M').",
|
||||
enum=["D", "M", "Y", "W", "H"],
|
||||
)
|
||||
|
|
@ -83,16 +85,16 @@ class FeShiftByTime(ToolSchema):
|
|||
class FeRollingByTime(ToolSchema):
|
||||
"""Calculate rolling statistics for a DataFrame column over time intervals."""
|
||||
|
||||
df: pd.DataFrame = field(description="DataFrame to process.")
|
||||
time_col: str = field(description="Column for time-based rolling.")
|
||||
group_col: str = field(description="Column for grouping before rolling.")
|
||||
rolling_col: str = field(description="Column for rolling calculations.")
|
||||
periods: list = field(description="Window sizes for rolling.")
|
||||
freq: str = field(
|
||||
df: pd.DataFrame = tool_field(description="DataFrame to process.")
|
||||
time_col: str = tool_field(description="Column for time-based rolling.")
|
||||
group_col: str = tool_field(description="Column for grouping before rolling.")
|
||||
rolling_col: str = tool_field(description="Column for rolling calculations.")
|
||||
periods: list = tool_field(description="Window sizes for rolling.")
|
||||
freq: str = tool_field(
|
||||
description="Frequency unit for time windows (e.g., 'D', 'M').",
|
||||
enum=["D", "M", "Y", "W", "H"],
|
||||
)
|
||||
agg_funcs: list = field(
|
||||
agg_funcs: list = tool_field(
|
||||
description="""List of aggregation functions for rolling, like ['mean', 'std'].
|
||||
Each function must be in ['mean', 'std', 'min', 'max', 'median', 'sum', 'count']."""
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.functions.register.register import FunctionRegistry
|
||||
from metagpt.tools.functions.schemas.base import ToolSchema, field
|
||||
from metagpt.tools.functions.schemas.base import ToolSchema, tool_field
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -18,8 +18,8 @@ def registry():
|
|||
class AddNumbers(ToolSchema):
|
||||
"""Add two numbers"""
|
||||
|
||||
num1: int = field(description="First number")
|
||||
num2: int = field(description="Second number")
|
||||
num1: int = tool_field(description="First number")
|
||||
num2: int = tool_field(description="Second number")
|
||||
|
||||
|
||||
def test_register(registry):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue