From bb8c39a312c558d53d803832052a39854fe6aa60 Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 24 Nov 2023 15:01:52 +0800 Subject: [PATCH 1/7] init function tools and define tool schema --- metagpt/tools/functions/__init__.py | 8 ++ metagpt/tools/functions/libs/__init__.py | 6 ++ metagpt/tools/functions/schemas/__init__.py | 6 ++ metagpt/tools/functions/schemas/base.py | 100 ++++++++++++++++++++ 4 files changed, 120 insertions(+) create mode 100644 metagpt/tools/functions/__init__.py create mode 100644 metagpt/tools/functions/libs/__init__.py create mode 100644 metagpt/tools/functions/schemas/__init__.py create mode 100644 metagpt/tools/functions/schemas/base.py diff --git a/metagpt/tools/functions/__init__.py b/metagpt/tools/functions/__init__.py new file mode 100644 index 000000000..069e4297b --- /dev/null +++ b/metagpt/tools/functions/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:32 +# @Author : lidanyang +# @File : __init__.py +# @Desc : +from metagpt.tools.functions.register.register import registry +import metagpt.tools.functions.libs.machine_learning diff --git a/metagpt/tools/functions/libs/__init__.py b/metagpt/tools/functions/libs/__init__.py new file mode 100644 index 000000000..a0a43f507 --- /dev/null +++ b/metagpt/tools/functions/libs/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:32 +# @Author : lidanyang +# @File : __init__.py +# @Desc : diff --git a/metagpt/tools/functions/schemas/__init__.py b/metagpt/tools/functions/schemas/__init__.py new file mode 100644 index 000000000..e50f67d6f --- /dev/null +++ b/metagpt/tools/functions/schemas/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:33 +# @Author : lidanyang +# @File : __init__.py +# @Desc : diff --git a/metagpt/tools/functions/schemas/base.py b/metagpt/tools/functions/schemas/base.py new file mode 100644 index 000000000..35b9f77b7 --- /dev/null +++ b/metagpt/tools/functions/schemas/base.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:34 +# @Author : lidanyang +# @File : base.py +# @Desc : Build base class to generate schema for tool +from typing import Any, List, Optional, get_type_hints + + +class NoDefault: + """ + A class to represent a missing default value. + + This is used to distinguish between a default value of None and a missing default value. + """ + pass + + +def field( + description: str, default: Any = NoDefault(), enum: Optional[List[Any]] = None, **kwargs +): + """ + Create a field for a tool parameter. + + Args: + description (str): A description of the field. + default (Any, optional): The default value for the field. Defaults to None. + enum (Optional[List[Any]], optional): A list of possible values for the field. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + dict: A dictionary representing the field with provided attributes. + """ + field_info = { + "description": description, + "default": default, + "enum": enum, + } + field_info.update(kwargs) + return field_info + + +class ToolSchema: + @staticmethod + def format_type(type_hint): + """ + Format a type hint into a string representation. + + Args: + type_hint (type): The type hint to format. + + Returns: + str: A string representation of the type hint. + """ + if isinstance(type_hint, type): + # Handle built-in types separately + if type_hint.__module__ == "builtins": + return type_hint.__name__ + else: + return f"{type_hint.__module__}.{type_hint.__name__}" + elif hasattr(type_hint, "__origin__") and hasattr(type_hint, "__args__"): + # Handle generic types (like List[int]) + origin_type = ToolSchema.format_type(type_hint.__origin__) + args_type = ", ".join( + [ToolSchema.format_type(t) for t in type_hint.__args__] + ) + return f"{origin_type}[{args_type}]" + else: + return str(type_hint) + + @classmethod + def schema(cls): + """ + Generate a schema dictionary for the class. + + The schema includes the class name, description, and information about + each class parameter based on type hints and field definitions. + + Returns: + dict: A dictionary representing the schema of the class. + """ + schema = { + "name": cls.__name__, + "description": cls.__doc__, + "parameters": {"type": "object", "properties": {}, "required": []}, + } + type_hints = get_type_hints(cls) + for attr, type_hint in type_hints.items(): + value = getattr(cls, attr, None) + if isinstance(value, dict): + # Process each attribute that is defined using the field function + prop_info = {k: v for k, v in value.items() if v is not None or k == "default"} + if isinstance(prop_info["default"], NoDefault): + del prop_info["default"] + prop_info["type"] = ToolSchema.format_type(type_hint) + schema["parameters"]["properties"][attr] = prop_info + # Check for required fields + if "default" not in prop_info: + schema["parameters"]["required"].append(attr) + return schema From b0e28838e490db5577faa9092bc7055ff3d720ae Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 24 Nov 2023 15:02:40 +0800 Subject: [PATCH 2/7] add function register --- metagpt/tools/functions/register/__init__.py | 6 ++ metagpt/tools/functions/register/register.py | 65 ++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 metagpt/tools/functions/register/__init__.py create mode 100644 metagpt/tools/functions/register/register.py diff --git a/metagpt/tools/functions/register/__init__.py b/metagpt/tools/functions/register/__init__.py new file mode 100644 index 000000000..c80872750 --- /dev/null +++ b/metagpt/tools/functions/register/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:37 +# @Author : lidanyang +# @File : __init__.py +# @Desc : diff --git a/metagpt/tools/functions/register/register.py b/metagpt/tools/functions/register/register.py new file mode 100644 index 000000000..120c7c4a2 --- /dev/null +++ b/metagpt/tools/functions/register/register.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:38 +# @Author : lidanyang +# @File : register.py +# @Desc : +from typing import Type, Optional, Callable, Dict, Union, List + +from metagpt.tools.functions.schemas.base import ToolSchema + + +class FunctionRegistry: + def __init__(self): + self.functions: Dict[str, Dict[str, Dict]] = {} + + def register(self, module: str, tool_schema: Type[ToolSchema]) -> Callable: + + def wrapper(func: Callable) -> Callable: + module_registry = self.functions.setdefault(module, {}) + + if func.__name__ in module_registry: + raise ValueError(f"Function {func.__name__} is already registered in {module}") + + schema = tool_schema.schema() + schema["name"] = func.__name__ + module_registry[func.__name__] = { + "func": func, + "schema": schema, + } + return func + + return wrapper + + def get(self, module: str, name: str) -> Optional[Union[Callable, Dict]]: + """Get function by module and name""" + module_registry = self.functions.get(module, {}) + return module_registry.get(name) + + def get_by_name(self, name: str) -> Optional[Dict]: + """Get function by name""" + for module_registry in self.functions.values(): + if name in module_registry: + return module_registry.get(name, {}) + + def get_all_by_module(self, module: str) -> Optional[Dict]: + """Get all functions by module""" + return self.functions.get(module, {}) + + def get_schema(self, module: str, name: str) -> Optional[Dict]: + """Get schema by module and name""" + module_registry = self.functions.get(module, {}) + return module_registry.get(name, {}).get("schema") + + def get_schemas(self, module: str, names: List[str]) -> List[Dict]: + """Get schemas by module and names""" + module_registry = self.functions.get(module, {}) + return [module_registry.get(name, {}).get("schema") for name in names] + + def get_all_schema_by_module(self, module: str) -> List[Dict]: + """Get all schemas by module""" + module_registry = self.functions.get(module, {}) + return [v.get("schema") for v in module_registry.values()] + + +registry = FunctionRegistry() From a911f5649df85df5f1e41827a5ffebf120edba94 Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 24 Nov 2023 15:03:03 +0800 Subject: [PATCH 3/7] add feature engineering tools --- .../libs/machine_learning/__init__.py | 7 + .../machine_learning/feature_engineering.py | 174 ++++++++++++++++++ .../schemas/machine_learning/__init__.py | 6 + .../machine_learning/feature_engineering.py | 98 ++++++++++ 4 files changed, 285 insertions(+) create mode 100644 metagpt/tools/functions/libs/machine_learning/__init__.py create mode 100644 metagpt/tools/functions/libs/machine_learning/feature_engineering.py create mode 100644 metagpt/tools/functions/schemas/machine_learning/__init__.py create mode 100644 metagpt/tools/functions/schemas/machine_learning/feature_engineering.py diff --git a/metagpt/tools/functions/libs/machine_learning/__init__.py b/metagpt/tools/functions/libs/machine_learning/__init__.py new file mode 100644 index 000000000..5e9760c64 --- /dev/null +++ b/metagpt/tools/functions/libs/machine_learning/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:36 +# @Author : lidanyang +# @File : __init__.py +# @Desc : +from metagpt.tools.functions.libs.machine_learning.feature_engineering import * diff --git a/metagpt/tools/functions/libs/machine_learning/feature_engineering.py b/metagpt/tools/functions/libs/machine_learning/feature_engineering.py new file mode 100644 index 000000000..584bd125d --- /dev/null +++ b/metagpt/tools/functions/libs/machine_learning/feature_engineering.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/17 10:33 +# @Author : lidanyang +# @File : feature_engineering.py +# @Desc : Feature Engineering Functions +import itertools + +from dateutil.relativedelta import relativedelta +from pandas.api.types import is_numeric_dtype +from sklearn.preprocessing import PolynomialFeatures, OneHotEncoder + +from metagpt.tools.functions import registry +from metagpt.tools.functions.schemas.machine_learning.feature_engineering import * + + +@registry.register("feature_engineering", PolynomialExpansion) +def polynomial_expansion(df, cols, degree=2): + for col in cols: + if not is_numeric_dtype(df[col]): + raise ValueError(f"Column '{col}' must be numeric.") + + poly = PolynomialFeatures(degree=degree, include_bias=False) + ts_data = poly.fit_transform(df[cols].fillna(0)) + new_columns = poly.get_feature_names_out(cols) + ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index) + ts_data = ts_data.drop(cols, axis=1) + df = pd.concat([df, ts_data], axis=1) + return df + + +@registry.register("feature_engineering", OneHotEncoding) +def one_hot_encoding(df, cols): + enc = OneHotEncoder(handle_unknown="ignore", sparse=False) + ts_data = enc.fit_transform(df[cols]) + new_columns = enc.get_feature_names_out(cols) + ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index) + df.drop(cols, axis=1, inplace=True) + df = pd.concat([df, ts_data], axis=1) + return df + + +@registry.register("feature_engineering", FrequencyEncoding) +def frequency_encoding(df, cols): + for col in cols: + encoder_dict = df[col].value_counts().to_dict() + df[f"{col}_cnt"] = df[col].map(encoder_dict) + return df + + +@registry.register("feature_engineering", CatCross) +def cat_cross(df, cols, max_cat_num=100): + for col in cols: + if df[col].nunique() > max_cat_num: + cols.remove(col) + + for col1, col2 in itertools.combinations(cols, 2): + cross_col = f"{col1}_cross_{col2}" + df[cross_col] = df[col1].astype(str) + "_" + df[col2].astype(str) + return df + + +@registry.register("feature_engineering", GroupStat) +def group_stat(df, group_col, agg_col, agg_funcs): + group_df = df.groupby(group_col)[agg_col].agg(agg_funcs).reset_index() + group_df.columns = group_col + [ + f"{agg_col}_{agg_func}_by_{group_col}" for agg_func in agg_funcs + ] + df = df.merge(group_df, on=group_col, how="left") + return df + + +@registry.register("feature_engineering", ExtractTimeComps) +def extract_time_comps(df, time_col, time_comps): + time_s = pd.to_datetime(df[time_col], errors="coerce") + time_comps_df = pd.DataFrame() + + if "year" in time_comps: + time_comps_df["year"] = time_s.dt.year + if "month" in time_comps: + time_comps_df["month"] = time_s.dt.month + if "day" in time_comps: + time_comps_df["day"] = time_s.dt.day + if "hour" in time_comps: + time_comps_df["hour"] = time_s.dt.hour + if "dayofweek" in time_comps: + time_comps_df["dayofweek"] = time_s.dt.dayofweek + 1 + if "is_weekend" in time_comps: + time_comps_df["is_weekend"] = time_s.dt.dayofweek.isin([5, 6]).astype(int) + df = pd.concat([df, time_comps_df], axis=1) + return df + + +@registry.register("feature_engineering", FeShiftByTime) +def fe_shift_by_time(df, time_col, group_col, shift_col, periods, freq): + df[time_col] = pd.to_datetime(df[time_col]) + + def shift_datetime(date, offset, unit): + if unit in ["year", "y", "Y"]: + return date + relativedelta(years=offset) + elif unit in ["month", "m", "M"]: + return date + relativedelta(months=offset) + elif unit in ["day", "d", "D"]: + return date + relativedelta(days=offset) + elif unit in ["week", "w", "W"]: + return date + relativedelta(weeks=offset) + elif unit in ["hour", "h", "H"]: + return date + relativedelta(hours=offset) + else: + return date + + def shift_by_time_on_key( + inner_df, time_col, group_col, shift_col, offset, unit, col_name + ): + inner_df = inner_df.drop_duplicates() + inner_df[time_col] = inner_df[time_col].map( + lambda x: shift_datetime(x, offset, unit) + ) + inner_df = inner_df.groupby([time_col, group_col], as_index=False)[ + shift_col + ].mean() + inner_df.rename(columns={shift_col: col_name}, inplace=True) + return inner_df + + shift_df = df[[time_col, group_col, shift_col]].copy() + for period in periods: + new_col_name = f"{group_col}_{shift_col}_lag_{period}_{freq}" + tmp = shift_by_time_on_key( + shift_df, time_col, group_col, shift_col, period, freq, new_col_name + ) + df = df.merge(tmp, on=[time_col, group_col], how="left") + + return df + + +@registry.register("feature_engineering", FeRollingByTime) +def fe_rolling_by_time(df, time_col, group_col, rolling_col, periods, freq, agg_funcs): + df[time_col] = pd.to_datetime(df[time_col]) + + def rolling_by_time_on_key(inner_df, offset, unit, agg_func, col_name): + time_freq = { + "Y": [365 * offset, "D"], + "M": [30 * offset, "D"], + "D": [offset, "D"], + "W": [7 * offset, "D"], + "H": [offset, "h"], + } + + if agg_func not in ["mean", "std", "max", "min", "median", "sum", "count"]: + raise ValueError(f"Invalid agg function: {agg_func}") + + rolling_feat = inner_df.rolling( + f"{time_freq[unit][0]}{time_freq[unit][1]}", closed="left" + ) + rolling_feat = getattr(rolling_feat, agg_func)() + depth = df.columns.nlevels + rolling_feat = rolling_feat.stack(list(range(depth))) + rolling_feat.name = col_name + return rolling_feat + + rolling_df = df[[time_col, group_col, rolling_col]].copy() + for period in periods: + for func in agg_funcs: + new_col_name = f"{group_col}_{rolling_col}_rolling_{period}_{freq}_{func}" + tmp = pd.pivot_table( + rolling_df, + index=time_col, + values=rolling_col, + columns=group_col, + ) + tmp = rolling_by_time_on_key(tmp, period, freq, func, new_col_name) + df = df.merge(tmp, on=[time_col, group_col], how="left") + + return df diff --git a/metagpt/tools/functions/schemas/machine_learning/__init__.py b/metagpt/tools/functions/schemas/machine_learning/__init__.py new file mode 100644 index 000000000..c80872750 --- /dev/null +++ b/metagpt/tools/functions/schemas/machine_learning/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/16 16:37 +# @Author : lidanyang +# @File : __init__.py +# @Desc : diff --git a/metagpt/tools/functions/schemas/machine_learning/feature_engineering.py b/metagpt/tools/functions/schemas/machine_learning/feature_engineering.py new file mode 100644 index 000000000..8237c83f4 --- /dev/null +++ b/metagpt/tools/functions/schemas/machine_learning/feature_engineering.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/17 10:34 +# @Author : lidanyang +# @File : feature_engineering.py +# @Desc : Schema for feature engineering functions +from typing import List + +import pandas as pd + +from metagpt.tools.functions.schemas.base import field, ToolSchema + + +class PolynomialExpansion(ToolSchema): + """Generate polynomial and interaction features from selected columns, excluding the bias column.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + cols: list = field(description="Columns for polynomial expansion.") + degree: int = field(description="Degree of polynomial features.", default=2) + + +class OneHotEncoding(ToolSchema): + """Apply one-hot encoding to specified categorical columns in a DataFrame.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + cols: list = field(description="Categorical columns to be one-hot encoded.") + + +class FrequencyEncoding(ToolSchema): + """Convert categorical columns to frequency encoding.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + cols: list = field(description="Categorical columns to be frequency encoded.") + + +class CatCross(ToolSchema): + """Create pairwise crossed features from categorical columns, joining values with '_'.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + cols: list = field(description="Columns to be pairwise crossed.") + max_cat_num: int = field( + description="Maximum unique categories per crossed feature.", default=100 + ) + + +class GroupStat(ToolSchema): + """Perform aggregation operations on a specified column grouped by certain categories.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + group_col: str = field(description="Column used for grouping.") + agg_col: str = field(description="Column on which aggregation is performed.") + agg_funcs: list = field( + description="""List of aggregation functions to apply, such as ['mean', 'std']. + Each function must be supported by pandas.""" + ) + + +class ExtractTimeComps(ToolSchema): + """Extract specific time components from a designated time column in a DataFrame.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + time_col: str = field(description="The name of the column containing time data.") + time_comps: List[str] = field( + description="""List of time components to extract. + Each component must be in ['year', 'month', 'day', 'hour', 'dayofweek', 'is_weekend'].""" + ) + + +class FeShiftByTime(ToolSchema): + """Shift column values in a DataFrame based on specified time intervals.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + time_col: str = field(description="Column for time-based shifting.") + group_col: str = field(description="Column for grouping before shifting.") + shift_col: str = field(description="Column to shift.") + periods: list = field(description="Time intervals for shifting.") + freq: str = field( + description="Frequency unit for time intervals (e.g., 'D', 'M').", + enum=["D", "M", "Y", "W", "H"], + ) + + +class FeRollingByTime(ToolSchema): + """Calculate rolling statistics for a DataFrame column over time intervals.""" + + df: pd.DataFrame = field(description="DataFrame to process.") + time_col: str = field(description="Column for time-based rolling.") + group_col: str = field(description="Column for grouping before rolling.") + rolling_col: str = field(description="Column for rolling calculations.") + periods: list = field(description="Window sizes for rolling.") + freq: str = field( + description="Frequency unit for time windows (e.g., 'D', 'M').", + enum=["D", "M", "Y", "W", "H"], + ) + agg_funcs: list = field( + description="""List of aggregation functions for rolling, like ['mean', 'std']. + Each function must be in ['mean', 'std', 'min', 'max', 'median', 'sum', 'count'].""" + ) From 142b04fa760490062f8366b836784ee02206e491 Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 24 Nov 2023 15:04:14 +0800 Subject: [PATCH 4/7] test tool register --- tests/metagpt/tools/functions/__init__.py | 6 ++ .../tools/functions/register/__init__.py | 6 ++ .../tools/functions/register/test_register.py | 55 +++++++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 tests/metagpt/tools/functions/__init__.py create mode 100644 tests/metagpt/tools/functions/register/__init__.py create mode 100644 tests/metagpt/tools/functions/register/test_register.py diff --git a/tests/metagpt/tools/functions/__init__.py b/tests/metagpt/tools/functions/__init__.py new file mode 100644 index 000000000..7d36f3404 --- /dev/null +++ b/tests/metagpt/tools/functions/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/17 10:24 +# @Author : lidanyang +# @File : __init__.py +# @Desc : diff --git a/tests/metagpt/tools/functions/register/__init__.py b/tests/metagpt/tools/functions/register/__init__.py new file mode 100644 index 000000000..7d36f3404 --- /dev/null +++ b/tests/metagpt/tools/functions/register/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/17 10:24 +# @Author : lidanyang +# @File : __init__.py +# @Desc : diff --git a/tests/metagpt/tools/functions/register/test_register.py b/tests/metagpt/tools/functions/register/test_register.py new file mode 100644 index 000000000..a71f7d01c --- /dev/null +++ b/tests/metagpt/tools/functions/register/test_register.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/11/17 10:24 +# @Author : lidanyang +# @File : test_register.py +# @Desc : +import pytest + +from metagpt.tools.functions.register.register import FunctionRegistry +from metagpt.tools.functions.schemas.base import ToolSchema, field + + +@pytest.fixture +def registry(): + return FunctionRegistry() + + +class AddNumbers(ToolSchema): + """Add two numbers""" + + num1: int = field(description="First number") + num2: int = field(description="Second number") + + +def test_register(registry): + @registry.register("module1", AddNumbers) + def add_numbers(num1, num2): + return num1 + num2 + + assert len(registry.functions["module1"]) == 1 + assert "add_numbers" in registry.functions["module1"] + + with pytest.raises(ValueError): + + @registry.register("module1", AddNumbers) + def add_numbers(num1, num2): + return num1 + num2 + + func = registry.get("module1", "add_numbers") + assert func["func"](1, 2) == 3 + assert func["schema"] == { + "name": "add_numbers", + "description": "Add two numbers", + "parameters": { + "type": "object", + "properties": { + "num1": {"description": "First number", "type": "int"}, + "num2": {"description": "Second number", "type": "int"}, + }, + "required": ["num1", "num2"], + }, + } + + module1_funcs = registry.get_all_by_module("module1") + assert len(module1_funcs) == 1 From fdc49775e613036f6da3169a1298a28792aae018 Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 24 Nov 2023 17:23:39 +0800 Subject: [PATCH 5/7] reduce hierarchy of machine learning --- metagpt/tools/functions/__init__.py | 2 +- .../libs/{machine_learning => }/feature_engineering.py | 2 +- metagpt/tools/functions/libs/machine_learning/__init__.py | 7 ------- .../schemas/{machine_learning => }/feature_engineering.py | 0 .../tools/functions/schemas/machine_learning/__init__.py | 6 ------ 5 files changed, 2 insertions(+), 15 deletions(-) rename metagpt/tools/functions/libs/{machine_learning => }/feature_engineering.py (98%) delete mode 100644 metagpt/tools/functions/libs/machine_learning/__init__.py rename metagpt/tools/functions/schemas/{machine_learning => }/feature_engineering.py (100%) delete mode 100644 metagpt/tools/functions/schemas/machine_learning/__init__.py diff --git a/metagpt/tools/functions/__init__.py b/metagpt/tools/functions/__init__.py index 069e4297b..b81e85833 100644 --- a/metagpt/tools/functions/__init__.py +++ b/metagpt/tools/functions/__init__.py @@ -5,4 +5,4 @@ # @File : __init__.py # @Desc : from metagpt.tools.functions.register.register import registry -import metagpt.tools.functions.libs.machine_learning +import metagpt.tools.functions.libs.feature_engineering diff --git a/metagpt/tools/functions/libs/machine_learning/feature_engineering.py b/metagpt/tools/functions/libs/feature_engineering.py similarity index 98% rename from metagpt/tools/functions/libs/machine_learning/feature_engineering.py rename to metagpt/tools/functions/libs/feature_engineering.py index 584bd125d..0573f362d 100644 --- a/metagpt/tools/functions/libs/machine_learning/feature_engineering.py +++ b/metagpt/tools/functions/libs/feature_engineering.py @@ -11,7 +11,7 @@ from pandas.api.types import is_numeric_dtype from sklearn.preprocessing import PolynomialFeatures, OneHotEncoder from metagpt.tools.functions import registry -from metagpt.tools.functions.schemas.machine_learning.feature_engineering import * +from metagpt.tools.functions.schemas.feature_engineering import * @registry.register("feature_engineering", PolynomialExpansion) diff --git a/metagpt/tools/functions/libs/machine_learning/__init__.py b/metagpt/tools/functions/libs/machine_learning/__init__.py deleted file mode 100644 index 5e9760c64..000000000 --- a/metagpt/tools/functions/libs/machine_learning/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Time : 2023/11/16 16:36 -# @Author : lidanyang -# @File : __init__.py -# @Desc : -from metagpt.tools.functions.libs.machine_learning.feature_engineering import * diff --git a/metagpt/tools/functions/schemas/machine_learning/feature_engineering.py b/metagpt/tools/functions/schemas/feature_engineering.py similarity index 100% rename from metagpt/tools/functions/schemas/machine_learning/feature_engineering.py rename to metagpt/tools/functions/schemas/feature_engineering.py diff --git a/metagpt/tools/functions/schemas/machine_learning/__init__.py b/metagpt/tools/functions/schemas/machine_learning/__init__.py deleted file mode 100644 index c80872750..000000000 --- a/metagpt/tools/functions/schemas/machine_learning/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Time : 2023/11/16 16:37 -# @Author : lidanyang -# @File : __init__.py -# @Desc : From f19003b413ab216128f55f18d1679802308049cb Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 24 Nov 2023 17:46:43 +0800 Subject: [PATCH 6/7] rename field to tool_field --- metagpt/tools/functions/__init__.py | 1 + metagpt/tools/functions/schemas/base.py | 2 +- .../functions/schemas/feature_engineering.py | 64 ++++++++++--------- .../tools/functions/register/test_register.py | 6 +- 4 files changed, 38 insertions(+), 35 deletions(-) diff --git a/metagpt/tools/functions/__init__.py b/metagpt/tools/functions/__init__.py index b81e85833..7ab850667 100644 --- a/metagpt/tools/functions/__init__.py +++ b/metagpt/tools/functions/__init__.py @@ -6,3 +6,4 @@ # @Desc : from metagpt.tools.functions.register.register import registry import metagpt.tools.functions.libs.feature_engineering +print(registry.functions) \ No newline at end of file diff --git a/metagpt/tools/functions/schemas/base.py b/metagpt/tools/functions/schemas/base.py index 35b9f77b7..aef604c8d 100644 --- a/metagpt/tools/functions/schemas/base.py +++ b/metagpt/tools/functions/schemas/base.py @@ -16,7 +16,7 @@ class NoDefault: pass -def field( +def tool_field( description: str, default: Any = NoDefault(), enum: Optional[List[Any]] = None, **kwargs ): """ diff --git a/metagpt/tools/functions/schemas/feature_engineering.py b/metagpt/tools/functions/schemas/feature_engineering.py index 8237c83f4..c14bb933e 100644 --- a/metagpt/tools/functions/schemas/feature_engineering.py +++ b/metagpt/tools/functions/schemas/feature_engineering.py @@ -8,37 +8,37 @@ from typing import List import pandas as pd -from metagpt.tools.functions.schemas.base import field, ToolSchema +from metagpt.tools.functions.schemas.base import ToolSchema, tool_field class PolynomialExpansion(ToolSchema): """Generate polynomial and interaction features from selected columns, excluding the bias column.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Columns for polynomial expansion.") - degree: int = field(description="Degree of polynomial features.", default=2) + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Columns for polynomial expansion.") + degree: int = tool_field(description="Degree of polynomial features.", default=2) class OneHotEncoding(ToolSchema): """Apply one-hot encoding to specified categorical columns in a DataFrame.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Categorical columns to be one-hot encoded.") + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Categorical columns to be one-hot encoded.") class FrequencyEncoding(ToolSchema): """Convert categorical columns to frequency encoding.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Categorical columns to be frequency encoded.") + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Categorical columns to be frequency encoded.") class CatCross(ToolSchema): """Create pairwise crossed features from categorical columns, joining values with '_'.""" - df: pd.DataFrame = field(description="DataFrame to process.") - cols: list = field(description="Columns to be pairwise crossed.") - max_cat_num: int = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + cols: list = tool_field(description="Columns to be pairwise crossed.") + max_cat_num: int = tool_field( description="Maximum unique categories per crossed feature.", default=100 ) @@ -46,10 +46,10 @@ class CatCross(ToolSchema): class GroupStat(ToolSchema): """Perform aggregation operations on a specified column grouped by certain categories.""" - df: pd.DataFrame = field(description="DataFrame to process.") - group_col: str = field(description="Column used for grouping.") - agg_col: str = field(description="Column on which aggregation is performed.") - agg_funcs: list = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + group_col: str = tool_field(description="Column used for grouping.") + agg_col: str = tool_field(description="Column on which aggregation is performed.") + agg_funcs: list = tool_field( description="""List of aggregation functions to apply, such as ['mean', 'std']. Each function must be supported by pandas.""" ) @@ -58,9 +58,11 @@ class GroupStat(ToolSchema): class ExtractTimeComps(ToolSchema): """Extract specific time components from a designated time column in a DataFrame.""" - df: pd.DataFrame = field(description="DataFrame to process.") - time_col: str = field(description="The name of the column containing time data.") - time_comps: List[str] = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + time_col: str = tool_field( + description="The name of the column containing time data." + ) + time_comps: List[str] = tool_field( description="""List of time components to extract. Each component must be in ['year', 'month', 'day', 'hour', 'dayofweek', 'is_weekend'].""" ) @@ -69,12 +71,12 @@ class ExtractTimeComps(ToolSchema): class FeShiftByTime(ToolSchema): """Shift column values in a DataFrame based on specified time intervals.""" - df: pd.DataFrame = field(description="DataFrame to process.") - time_col: str = field(description="Column for time-based shifting.") - group_col: str = field(description="Column for grouping before shifting.") - shift_col: str = field(description="Column to shift.") - periods: list = field(description="Time intervals for shifting.") - freq: str = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + time_col: str = tool_field(description="Column for time-based shifting.") + group_col: str = tool_field(description="Column for grouping before shifting.") + shift_col: str = tool_field(description="Column to shift.") + periods: list = tool_field(description="Time intervals for shifting.") + freq: str = tool_field( description="Frequency unit for time intervals (e.g., 'D', 'M').", enum=["D", "M", "Y", "W", "H"], ) @@ -83,16 +85,16 @@ class FeShiftByTime(ToolSchema): class FeRollingByTime(ToolSchema): """Calculate rolling statistics for a DataFrame column over time intervals.""" - df: pd.DataFrame = field(description="DataFrame to process.") - time_col: str = field(description="Column for time-based rolling.") - group_col: str = field(description="Column for grouping before rolling.") - rolling_col: str = field(description="Column for rolling calculations.") - periods: list = field(description="Window sizes for rolling.") - freq: str = field( + df: pd.DataFrame = tool_field(description="DataFrame to process.") + time_col: str = tool_field(description="Column for time-based rolling.") + group_col: str = tool_field(description="Column for grouping before rolling.") + rolling_col: str = tool_field(description="Column for rolling calculations.") + periods: list = tool_field(description="Window sizes for rolling.") + freq: str = tool_field( description="Frequency unit for time windows (e.g., 'D', 'M').", enum=["D", "M", "Y", "W", "H"], ) - agg_funcs: list = field( + agg_funcs: list = tool_field( description="""List of aggregation functions for rolling, like ['mean', 'std']. Each function must be in ['mean', 'std', 'min', 'max', 'median', 'sum', 'count'].""" ) diff --git a/tests/metagpt/tools/functions/register/test_register.py b/tests/metagpt/tools/functions/register/test_register.py index a71f7d01c..8c9821268 100644 --- a/tests/metagpt/tools/functions/register/test_register.py +++ b/tests/metagpt/tools/functions/register/test_register.py @@ -7,7 +7,7 @@ import pytest from metagpt.tools.functions.register.register import FunctionRegistry -from metagpt.tools.functions.schemas.base import ToolSchema, field +from metagpt.tools.functions.schemas.base import ToolSchema, tool_field @pytest.fixture @@ -18,8 +18,8 @@ def registry(): class AddNumbers(ToolSchema): """Add two numbers""" - num1: int = field(description="First number") - num2: int = field(description="Second number") + num1: int = tool_field(description="First number") + num2: int = tool_field(description="Second number") def test_register(registry): From c159260717acb5f98c7ed3add259b5fe3db9c3d5 Mon Sep 17 00:00:00 2001 From: lidanyang Date: Fri, 24 Nov 2023 18:56:15 +0800 Subject: [PATCH 7/7] check_param_consistency --- metagpt/tools/functions/register/register.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/metagpt/tools/functions/register/register.py b/metagpt/tools/functions/register/register.py index 120c7c4a2..0731e31c0 100644 --- a/metagpt/tools/functions/register/register.py +++ b/metagpt/tools/functions/register/register.py @@ -4,6 +4,7 @@ # @Author : lidanyang # @File : register.py # @Desc : +import inspect from typing import Type, Optional, Callable, Dict, Union, List from metagpt.tools.functions.schemas.base import ToolSchema @@ -13,16 +14,28 @@ class FunctionRegistry: def __init__(self): self.functions: Dict[str, Dict[str, Dict]] = {} - def register(self, module: str, tool_schema: Type[ToolSchema]) -> Callable: + @staticmethod + def _check_param_consistency(func_params, schema): + param_names = set(func_params.keys()) + schema_names = set(schema["parameters"]["properties"].keys()) + if param_names != schema_names: + raise ValueError("Function parameters do not match schema properties") + + def register(self, module: str, tool_schema: Type[ToolSchema]) -> Callable: def wrapper(func: Callable) -> Callable: module_registry = self.functions.setdefault(module, {}) if func.__name__ in module_registry: raise ValueError(f"Function {func.__name__} is already registered in {module}") + func_params = inspect.signature(func).parameters + schema = tool_schema.schema() schema["name"] = func.__name__ + + self._check_param_consistency(func_params, schema) + module_registry[func.__name__] = { "func": func, "schema": schema,