Merge branch 'dev_ldy' into 'dev'

Dev ldy

See merge request agents/data_agents_opt!3
This commit is contained in:
林义章 2023-11-24 12:45:38 +00:00
commit 5d3f51b010
11 changed files with 546 additions and 0 deletions

View file

@ -0,0 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/16 16:32
# @Author : lidanyang
# @File : __init__.py
# @Desc :
from metagpt.tools.functions.register.register import registry
import metagpt.tools.functions.libs.feature_engineering
print(registry.functions)

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/16 16:32
# @Author : lidanyang
# @File : __init__.py
# @Desc :

View file

@ -0,0 +1,174 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/17 10:33
# @Author : lidanyang
# @File : feature_engineering.py
# @Desc : Feature Engineering Functions
import itertools
from dateutil.relativedelta import relativedelta
from pandas.api.types import is_numeric_dtype
from sklearn.preprocessing import PolynomialFeatures, OneHotEncoder
from metagpt.tools.functions import registry
from metagpt.tools.functions.schemas.feature_engineering import *
@registry.register("feature_engineering", PolynomialExpansion)
def polynomial_expansion(df, cols, degree=2):
for col in cols:
if not is_numeric_dtype(df[col]):
raise ValueError(f"Column '{col}' must be numeric.")
poly = PolynomialFeatures(degree=degree, include_bias=False)
ts_data = poly.fit_transform(df[cols].fillna(0))
new_columns = poly.get_feature_names_out(cols)
ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index)
ts_data = ts_data.drop(cols, axis=1)
df = pd.concat([df, ts_data], axis=1)
return df
@registry.register("feature_engineering", OneHotEncoding)
def one_hot_encoding(df, cols):
enc = OneHotEncoder(handle_unknown="ignore", sparse=False)
ts_data = enc.fit_transform(df[cols])
new_columns = enc.get_feature_names_out(cols)
ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index)
df.drop(cols, axis=1, inplace=True)
df = pd.concat([df, ts_data], axis=1)
return df
@registry.register("feature_engineering", FrequencyEncoding)
def frequency_encoding(df, cols):
for col in cols:
encoder_dict = df[col].value_counts().to_dict()
df[f"{col}_cnt"] = df[col].map(encoder_dict)
return df
@registry.register("feature_engineering", CatCross)
def cat_cross(df, cols, max_cat_num=100):
for col in cols:
if df[col].nunique() > max_cat_num:
cols.remove(col)
for col1, col2 in itertools.combinations(cols, 2):
cross_col = f"{col1}_cross_{col2}"
df[cross_col] = df[col1].astype(str) + "_" + df[col2].astype(str)
return df
@registry.register("feature_engineering", GroupStat)
def group_stat(df, group_col, agg_col, agg_funcs):
group_df = df.groupby(group_col)[agg_col].agg(agg_funcs).reset_index()
group_df.columns = group_col + [
f"{agg_col}_{agg_func}_by_{group_col}" for agg_func in agg_funcs
]
df = df.merge(group_df, on=group_col, how="left")
return df
@registry.register("feature_engineering", ExtractTimeComps)
def extract_time_comps(df, time_col, time_comps):
time_s = pd.to_datetime(df[time_col], errors="coerce")
time_comps_df = pd.DataFrame()
if "year" in time_comps:
time_comps_df["year"] = time_s.dt.year
if "month" in time_comps:
time_comps_df["month"] = time_s.dt.month
if "day" in time_comps:
time_comps_df["day"] = time_s.dt.day
if "hour" in time_comps:
time_comps_df["hour"] = time_s.dt.hour
if "dayofweek" in time_comps:
time_comps_df["dayofweek"] = time_s.dt.dayofweek + 1
if "is_weekend" in time_comps:
time_comps_df["is_weekend"] = time_s.dt.dayofweek.isin([5, 6]).astype(int)
df = pd.concat([df, time_comps_df], axis=1)
return df
@registry.register("feature_engineering", FeShiftByTime)
def fe_shift_by_time(df, time_col, group_col, shift_col, periods, freq):
df[time_col] = pd.to_datetime(df[time_col])
def shift_datetime(date, offset, unit):
if unit in ["year", "y", "Y"]:
return date + relativedelta(years=offset)
elif unit in ["month", "m", "M"]:
return date + relativedelta(months=offset)
elif unit in ["day", "d", "D"]:
return date + relativedelta(days=offset)
elif unit in ["week", "w", "W"]:
return date + relativedelta(weeks=offset)
elif unit in ["hour", "h", "H"]:
return date + relativedelta(hours=offset)
else:
return date
def shift_by_time_on_key(
inner_df, time_col, group_col, shift_col, offset, unit, col_name
):
inner_df = inner_df.drop_duplicates()
inner_df[time_col] = inner_df[time_col].map(
lambda x: shift_datetime(x, offset, unit)
)
inner_df = inner_df.groupby([time_col, group_col], as_index=False)[
shift_col
].mean()
inner_df.rename(columns={shift_col: col_name}, inplace=True)
return inner_df
shift_df = df[[time_col, group_col, shift_col]].copy()
for period in periods:
new_col_name = f"{group_col}_{shift_col}_lag_{period}_{freq}"
tmp = shift_by_time_on_key(
shift_df, time_col, group_col, shift_col, period, freq, new_col_name
)
df = df.merge(tmp, on=[time_col, group_col], how="left")
return df
@registry.register("feature_engineering", FeRollingByTime)
def fe_rolling_by_time(df, time_col, group_col, rolling_col, periods, freq, agg_funcs):
df[time_col] = pd.to_datetime(df[time_col])
def rolling_by_time_on_key(inner_df, offset, unit, agg_func, col_name):
time_freq = {
"Y": [365 * offset, "D"],
"M": [30 * offset, "D"],
"D": [offset, "D"],
"W": [7 * offset, "D"],
"H": [offset, "h"],
}
if agg_func not in ["mean", "std", "max", "min", "median", "sum", "count"]:
raise ValueError(f"Invalid agg function: {agg_func}")
rolling_feat = inner_df.rolling(
f"{time_freq[unit][0]}{time_freq[unit][1]}", closed="left"
)
rolling_feat = getattr(rolling_feat, agg_func)()
depth = df.columns.nlevels
rolling_feat = rolling_feat.stack(list(range(depth)))
rolling_feat.name = col_name
return rolling_feat
rolling_df = df[[time_col, group_col, rolling_col]].copy()
for period in periods:
for func in agg_funcs:
new_col_name = f"{group_col}_{rolling_col}_rolling_{period}_{freq}_{func}"
tmp = pd.pivot_table(
rolling_df,
index=time_col,
values=rolling_col,
columns=group_col,
)
tmp = rolling_by_time_on_key(tmp, period, freq, func, new_col_name)
df = df.merge(tmp, on=[time_col, group_col], how="left")
return df

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/16 16:37
# @Author : lidanyang
# @File : __init__.py
# @Desc :

View file

@ -0,0 +1,78 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/16 16:38
# @Author : lidanyang
# @File : register.py
# @Desc :
import inspect
from typing import Type, Optional, Callable, Dict, Union, List
from metagpt.tools.functions.schemas.base import ToolSchema
class FunctionRegistry:
def __init__(self):
self.functions: Dict[str, Dict[str, Dict]] = {}
@staticmethod
def _check_param_consistency(func_params, schema):
param_names = set(func_params.keys())
schema_names = set(schema["parameters"]["properties"].keys())
if param_names != schema_names:
raise ValueError("Function parameters do not match schema properties")
def register(self, module: str, tool_schema: Type[ToolSchema]) -> Callable:
def wrapper(func: Callable) -> Callable:
module_registry = self.functions.setdefault(module, {})
if func.__name__ in module_registry:
raise ValueError(f"Function {func.__name__} is already registered in {module}")
func_params = inspect.signature(func).parameters
schema = tool_schema.schema()
schema["name"] = func.__name__
self._check_param_consistency(func_params, schema)
module_registry[func.__name__] = {
"func": func,
"schema": schema,
}
return func
return wrapper
def get(self, module: str, name: str) -> Optional[Union[Callable, Dict]]:
"""Get function by module and name"""
module_registry = self.functions.get(module, {})
return module_registry.get(name)
def get_by_name(self, name: str) -> Optional[Dict]:
"""Get function by name"""
for module_registry in self.functions.values():
if name in module_registry:
return module_registry.get(name, {})
def get_all_by_module(self, module: str) -> Optional[Dict]:
"""Get all functions by module"""
return self.functions.get(module, {})
def get_schema(self, module: str, name: str) -> Optional[Dict]:
"""Get schema by module and name"""
module_registry = self.functions.get(module, {})
return module_registry.get(name, {}).get("schema")
def get_schemas(self, module: str, names: List[str]) -> List[Dict]:
"""Get schemas by module and names"""
module_registry = self.functions.get(module, {})
return [module_registry.get(name, {}).get("schema") for name in names]
def get_all_schema_by_module(self, module: str) -> List[Dict]:
"""Get all schemas by module"""
module_registry = self.functions.get(module, {})
return [v.get("schema") for v in module_registry.values()]
registry = FunctionRegistry()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/16 16:33
# @Author : lidanyang
# @File : __init__.py
# @Desc :

View file

@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/16 16:34
# @Author : lidanyang
# @File : base.py
# @Desc : Build base class to generate schema for tool
from typing import Any, List, Optional, get_type_hints
class NoDefault:
"""
A class to represent a missing default value.
This is used to distinguish between a default value of None and a missing default value.
"""
pass
def tool_field(
description: str, default: Any = NoDefault(), enum: Optional[List[Any]] = None, **kwargs
):
"""
Create a field for a tool parameter.
Args:
description (str): A description of the field.
default (Any, optional): The default value for the field. Defaults to None.
enum (Optional[List[Any]], optional): A list of possible values for the field. Defaults to None.
**kwargs: Additional keyword arguments.
Returns:
dict: A dictionary representing the field with provided attributes.
"""
field_info = {
"description": description,
"default": default,
"enum": enum,
}
field_info.update(kwargs)
return field_info
class ToolSchema:
@staticmethod
def format_type(type_hint):
"""
Format a type hint into a string representation.
Args:
type_hint (type): The type hint to format.
Returns:
str: A string representation of the type hint.
"""
if isinstance(type_hint, type):
# Handle built-in types separately
if type_hint.__module__ == "builtins":
return type_hint.__name__
else:
return f"{type_hint.__module__}.{type_hint.__name__}"
elif hasattr(type_hint, "__origin__") and hasattr(type_hint, "__args__"):
# Handle generic types (like List[int])
origin_type = ToolSchema.format_type(type_hint.__origin__)
args_type = ", ".join(
[ToolSchema.format_type(t) for t in type_hint.__args__]
)
return f"{origin_type}[{args_type}]"
else:
return str(type_hint)
@classmethod
def schema(cls):
"""
Generate a schema dictionary for the class.
The schema includes the class name, description, and information about
each class parameter based on type hints and field definitions.
Returns:
dict: A dictionary representing the schema of the class.
"""
schema = {
"name": cls.__name__,
"description": cls.__doc__,
"parameters": {"type": "object", "properties": {}, "required": []},
}
type_hints = get_type_hints(cls)
for attr, type_hint in type_hints.items():
value = getattr(cls, attr, None)
if isinstance(value, dict):
# Process each attribute that is defined using the field function
prop_info = {k: v for k, v in value.items() if v is not None or k == "default"}
if isinstance(prop_info["default"], NoDefault):
del prop_info["default"]
prop_info["type"] = ToolSchema.format_type(type_hint)
schema["parameters"]["properties"][attr] = prop_info
# Check for required fields
if "default" not in prop_info:
schema["parameters"]["required"].append(attr)
return schema

View file

@ -0,0 +1,100 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/17 10:34
# @Author : lidanyang
# @File : feature_engineering.py
# @Desc : Schema for feature engineering functions
from typing import List
import pandas as pd
from metagpt.tools.functions.schemas.base import ToolSchema, tool_field
class PolynomialExpansion(ToolSchema):
"""Generate polynomial and interaction features from selected columns, excluding the bias column."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Columns for polynomial expansion.")
degree: int = tool_field(description="Degree of polynomial features.", default=2)
class OneHotEncoding(ToolSchema):
"""Apply one-hot encoding to specified categorical columns in a DataFrame."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Categorical columns to be one-hot encoded.")
class FrequencyEncoding(ToolSchema):
"""Convert categorical columns to frequency encoding."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Categorical columns to be frequency encoded.")
class CatCross(ToolSchema):
"""Create pairwise crossed features from categorical columns, joining values with '_'."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
cols: list = tool_field(description="Columns to be pairwise crossed.")
max_cat_num: int = tool_field(
description="Maximum unique categories per crossed feature.", default=100
)
class GroupStat(ToolSchema):
"""Perform aggregation operations on a specified column grouped by certain categories."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
group_col: str = tool_field(description="Column used for grouping.")
agg_col: str = tool_field(description="Column on which aggregation is performed.")
agg_funcs: list = tool_field(
description="""List of aggregation functions to apply, such as ['mean', 'std'].
Each function must be supported by pandas."""
)
class ExtractTimeComps(ToolSchema):
"""Extract specific time components from a designated time column in a DataFrame."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
time_col: str = tool_field(
description="The name of the column containing time data."
)
time_comps: List[str] = tool_field(
description="""List of time components to extract.
Each component must be in ['year', 'month', 'day', 'hour', 'dayofweek', 'is_weekend']."""
)
class FeShiftByTime(ToolSchema):
"""Shift column values in a DataFrame based on specified time intervals."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
time_col: str = tool_field(description="Column for time-based shifting.")
group_col: str = tool_field(description="Column for grouping before shifting.")
shift_col: str = tool_field(description="Column to shift.")
periods: list = tool_field(description="Time intervals for shifting.")
freq: str = tool_field(
description="Frequency unit for time intervals (e.g., 'D', 'M').",
enum=["D", "M", "Y", "W", "H"],
)
class FeRollingByTime(ToolSchema):
"""Calculate rolling statistics for a DataFrame column over time intervals."""
df: pd.DataFrame = tool_field(description="DataFrame to process.")
time_col: str = tool_field(description="Column for time-based rolling.")
group_col: str = tool_field(description="Column for grouping before rolling.")
rolling_col: str = tool_field(description="Column for rolling calculations.")
periods: list = tool_field(description="Window sizes for rolling.")
freq: str = tool_field(
description="Frequency unit for time windows (e.g., 'D', 'M').",
enum=["D", "M", "Y", "W", "H"],
)
agg_funcs: list = tool_field(
description="""List of aggregation functions for rolling, like ['mean', 'std'].
Each function must be in ['mean', 'std', 'min', 'max', 'median', 'sum', 'count']."""
)

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/17 10:24
# @Author : lidanyang
# @File : __init__.py
# @Desc :

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/17 10:24
# @Author : lidanyang
# @File : __init__.py
# @Desc :

View file

@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/11/17 10:24
# @Author : lidanyang
# @File : test_register.py
# @Desc :
import pytest
from metagpt.tools.functions.register.register import FunctionRegistry
from metagpt.tools.functions.schemas.base import ToolSchema, tool_field
@pytest.fixture
def registry():
return FunctionRegistry()
class AddNumbers(ToolSchema):
"""Add two numbers"""
num1: int = tool_field(description="First number")
num2: int = tool_field(description="Second number")
def test_register(registry):
@registry.register("module1", AddNumbers)
def add_numbers(num1, num2):
return num1 + num2
assert len(registry.functions["module1"]) == 1
assert "add_numbers" in registry.functions["module1"]
with pytest.raises(ValueError):
@registry.register("module1", AddNumbers)
def add_numbers(num1, num2):
return num1 + num2
func = registry.get("module1", "add_numbers")
assert func["func"](1, 2) == 3
assert func["schema"] == {
"name": "add_numbers",
"description": "Add two numbers",
"parameters": {
"type": "object",
"properties": {
"num1": {"description": "First number", "type": "int"},
"num2": {"description": "Second number", "type": "int"},
},
"required": ["num1", "num2"],
},
}
module1_funcs = registry.get_all_by_module("module1")
assert len(module1_funcs) == 1