mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
rm redundant function and docstring in libs
This commit is contained in:
parent
b7d0379fae
commit
321a4c0d75
9 changed files with 176 additions and 508 deletions
|
|
@ -7,11 +7,10 @@
|
|||
"""
|
||||
|
||||
from enum import Enum
|
||||
from metagpt.tools import tool_types # this registers all tool types
|
||||
from metagpt.tools import libs # this registers all tools
|
||||
from metagpt.tools.tool_registry import TOOL_REGISTRY
|
||||
|
||||
_ = tool_types, libs, TOOL_REGISTRY # Avoid pre-commit error
|
||||
_ = libs, TOOL_REGISTRY # Avoid pre-commit error
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
|
|
|
|||
|
|
@ -19,14 +19,29 @@ from metagpt.tools.tool_types import ToolTypes
|
|||
TOOL_TYPE = ToolTypes.DATA_PREPROCESS.type_name
|
||||
|
||||
|
||||
class MLProcess(object):
|
||||
def fit(self, df):
|
||||
class MLProcess:
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit a model to be used in subsequent transform.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def transform(self, df):
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def fit_transform(self, df) -> pd.DataFrame:
|
||||
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Fit and transform the input DataFrame.
|
||||
|
||||
|
|
@ -40,6 +55,49 @@ class MLProcess(object):
|
|||
return self.transform(df)
|
||||
|
||||
|
||||
class DataPreprocessTool(MLProcess):
|
||||
"""
|
||||
Completing a data preprocessing operation.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
"""
|
||||
self.features = features
|
||||
self.model = None # to be filled by specific subclass Tool
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit a model to be used in subsequent transform.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
if len(self.features) == 0:
|
||||
return
|
||||
self.model.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
if len(self.features) == 0:
|
||||
return df
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.model.transform(new_df[self.features])
|
||||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class FillMissingValue(MLProcess):
|
||||
"""
|
||||
|
|
@ -58,282 +116,77 @@ class FillMissingValue(MLProcess):
|
|||
Defaults to None.
|
||||
"""
|
||||
self.features = features
|
||||
self.strategy = strategy
|
||||
self.fill_value = fill_value
|
||||
self.si = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the FillMissingValue model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
if len(self.features) == 0:
|
||||
return
|
||||
self.si = SimpleImputer(strategy=self.strategy, fill_value=self.fill_value)
|
||||
self.si.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
if len(self.features) == 0:
|
||||
return df
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.si.transform(new_df[self.features])
|
||||
return new_df
|
||||
self.model = SimpleImputer(strategy=strategy, fill_value=fill_value)
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MinMaxScale(MLProcess):
|
||||
class MinMaxScale(DataPreprocessTool):
|
||||
"""
|
||||
Transform features by scaling each feature to a range, which is (0, 1).
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
"""
|
||||
self.features = features
|
||||
self.mms = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the MinMaxScale model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
self.mms = MinMaxScaler()
|
||||
self.mms.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.mms.transform(new_df[self.features])
|
||||
return new_df
|
||||
self.model = MinMaxScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class StandardScale(MLProcess):
|
||||
class StandardScale(DataPreprocessTool):
|
||||
"""
|
||||
Standardize features by removing the mean and scaling to unit variance.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
"""
|
||||
self.features = features
|
||||
self.ss = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the StandardScale model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
self.ss = StandardScaler()
|
||||
self.ss.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.ss.transform(new_df[self.features])
|
||||
return new_df
|
||||
self.model = StandardScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MaxAbsScale(MLProcess):
|
||||
class MaxAbsScale(DataPreprocessTool):
|
||||
"""
|
||||
Scale each feature by its maximum absolute value.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
"""
|
||||
self.features = features
|
||||
self.mas = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the MaxAbsScale model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
self.mas = MaxAbsScaler()
|
||||
self.mas.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.mas.transform(new_df[self.features])
|
||||
return new_df
|
||||
self.model = MaxAbsScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class RobustScale(MLProcess):
|
||||
class RobustScale(DataPreprocessTool):
|
||||
"""
|
||||
Apply the RobustScaler to scale features using statistics that are robust to outliers.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize the RobustScale instance with feature names.
|
||||
|
||||
Args:
|
||||
features (list): List of feature names to be scaled.
|
||||
"""
|
||||
self.features = features
|
||||
self.rs = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Compute the median and IQR for scaling.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe containing the features.
|
||||
"""
|
||||
self.rs = RobustScaler()
|
||||
self.rs.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame):
|
||||
"""
|
||||
Scale features using the previously computed median and IQR.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe containing the features to be scaled.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: A new dataframe with scaled features.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.rs.transform(new_df[self.features])
|
||||
return new_df
|
||||
self.model = RobustScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OrdinalEncode(MLProcess):
|
||||
class OrdinalEncode(DataPreprocessTool):
|
||||
"""
|
||||
Encode categorical features as ordinal integers.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize the OrdinalEncode instance with feature names.
|
||||
|
||||
Args:
|
||||
features (list): List of categorical feature names to be encoded.
|
||||
"""
|
||||
self.features = features
|
||||
self.oe = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Learn the ordinal encodings for the features.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe containing the categorical features.
|
||||
"""
|
||||
self.oe = OrdinalEncoder()
|
||||
self.oe.fit(df[self.features])
|
||||
|
||||
def transform(self, df: pd.DataFrame):
|
||||
"""
|
||||
Convert the categorical features to ordinal integers.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): Dataframe containing the categorical features to be encoded.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: A new dataframe with the encoded features.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[self.features] = self.oe.transform(new_df[self.features])
|
||||
return new_df
|
||||
self.model = OrdinalEncoder()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OneHotEncode(MLProcess):
|
||||
class OneHotEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply one-hot encoding to specified categorical columns, the original columns will be dropped.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Categorical columns to be one-hot encoded and dropped.
|
||||
"""
|
||||
self.features = features
|
||||
self.ohe = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the OneHotEncoding model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
self.ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)
|
||||
self.ohe.fit(df[self.features])
|
||||
self.model = OneHotEncoder(handle_unknown="ignore", sparse=False)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
ts_data = self.ohe.transform(df[self.features])
|
||||
new_columns = self.ohe.get_feature_names_out(self.features)
|
||||
ts_data = self.model.transform(df[self.features])
|
||||
new_columns = self.model.get_feature_names_out(self.features)
|
||||
ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index)
|
||||
new_df = df.drop(self.features, axis=1)
|
||||
new_df = pd.concat([new_df, ts_data], axis=1)
|
||||
|
|
@ -341,7 +194,7 @@ class OneHotEncode(MLProcess):
|
|||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class LabelEncode(MLProcess):
|
||||
class LabelEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply label encoding to specified categorical columns in-place.
|
||||
"""
|
||||
|
|
@ -357,12 +210,6 @@ class LabelEncode(MLProcess):
|
|||
self.le_encoders = []
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the LabelEncode model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
if len(self.features) == 0:
|
||||
return
|
||||
for col in self.features:
|
||||
|
|
@ -370,15 +217,6 @@ class LabelEncode(MLProcess):
|
|||
self.le_encoders.append(le)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
if len(self.features) == 0:
|
||||
return df
|
||||
new_df = df.copy()
|
||||
|
|
|
|||
|
|
@ -45,12 +45,6 @@ class PolynomialExpansion(MLProcess):
|
|||
self.poly = PolynomialFeatures(degree=degree, include_bias=False)
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the PolynomialExpansion model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
if len(self.cols) == 0:
|
||||
return
|
||||
if len(self.cols) > 10:
|
||||
|
|
@ -61,15 +55,6 @@ class PolynomialExpansion(MLProcess):
|
|||
self.poly.fit(df[self.cols].fillna(0))
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame without duplicated columns.
|
||||
"""
|
||||
if len(self.cols) == 0:
|
||||
return df
|
||||
ts_data = self.poly.transform(df[self.cols].fillna(0))
|
||||
|
|
@ -97,24 +82,9 @@ class CatCount(MLProcess):
|
|||
self.encoder_dict = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the CatCount model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
self.encoder_dict = df[self.col].value_counts().to_dict()
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[f"{self.col}_cnt"] = new_df[self.col].map(self.encoder_dict)
|
||||
return new_df
|
||||
|
|
@ -139,24 +109,9 @@ class TargetMeanEncoder(MLProcess):
|
|||
self.encoder_dict = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the TargetMeanEncoder model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
self.encoder_dict = df.groupby(self.col)[self.label].mean().to_dict()
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[f"{self.col}_target_mean"] = new_df[self.col].map(self.encoder_dict)
|
||||
return new_df
|
||||
|
|
@ -185,12 +140,6 @@ class KFoldTargetMeanEncoder(MLProcess):
|
|||
self.encoder_dict = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the KFoldTargetMeanEncoder model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
tmp = df.copy()
|
||||
kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
|
||||
|
||||
|
|
@ -203,15 +152,6 @@ class KFoldTargetMeanEncoder(MLProcess):
|
|||
self.encoder_dict = tmp.groupby(self.col)[col_name].mean().to_dict()
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[f"{self.col}_kf_target_mean"] = new_df[self.col].map(self.encoder_dict)
|
||||
return new_df
|
||||
|
|
@ -255,12 +195,6 @@ class CatCross(MLProcess):
|
|||
return new_col, comb_map
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the CatCross model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
for col in self.cols:
|
||||
if df[col].nunique() > self.max_cat_num:
|
||||
self.cols.remove(col)
|
||||
|
|
@ -269,15 +203,6 @@ class CatCross(MLProcess):
|
|||
self.combs_map = dict(res)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
for comb in self.combs:
|
||||
new_col = f"{comb[0]}_{comb[1]}"
|
||||
|
|
@ -310,12 +235,6 @@ class GroupStat(MLProcess):
|
|||
self.group_df = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the GroupStat model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
group_df = df.groupby(self.group_col)[self.agg_col].agg(self.agg_funcs).reset_index()
|
||||
group_df.columns = [self.group_col] + [
|
||||
f"{self.agg_col}_{agg_func}_by_{self.group_col}" for agg_func in self.agg_funcs
|
||||
|
|
@ -323,15 +242,6 @@ class GroupStat(MLProcess):
|
|||
self.group_df = group_df
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.merge(self.group_df, on=self.group_col, how="left")
|
||||
return new_df
|
||||
|
||||
|
|
@ -355,25 +265,10 @@ class SplitBins(MLProcess):
|
|||
self.encoder = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the SplitBins model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
self.encoder = KBinsDiscretizer(strategy=self.strategy, encode="ordinal")
|
||||
self.encoder.fit(df[self.cols].fillna(0))
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
new_df = df.copy()
|
||||
new_df[self.cols] = self.encoder.transform(new_df[self.cols].fillna(0))
|
||||
return new_df
|
||||
|
|
@ -397,24 +292,9 @@ class ExtractTimeComps(MLProcess):
|
|||
self.time_comps = time_comps
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the ExtractTimeComps model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
pass
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
time_s = pd.to_datetime(df[self.time_col], errors="coerce")
|
||||
time_comps_df = pd.DataFrame()
|
||||
|
||||
|
|
@ -445,12 +325,6 @@ class GeneralSelection(MLProcess):
|
|||
self.feats = []
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the GeneralSelection model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
feats = [f for f in df.columns if f != self.label_col]
|
||||
for col in df.columns:
|
||||
if df[col].isnull().sum() / df.shape[0] == 1:
|
||||
|
|
@ -468,15 +342,6 @@ class GeneralSelection(MLProcess):
|
|||
self.feats = feats
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame contain label_col.
|
||||
"""
|
||||
new_df = df[self.feats + [self.label_col]]
|
||||
return new_df
|
||||
|
||||
|
|
@ -501,12 +366,6 @@ class TreeBasedSelection(MLProcess):
|
|||
self.feats = None
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the TreeBasedSelection model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
params = {
|
||||
"boosting_type": "gbdt",
|
||||
"objective": "binary",
|
||||
|
|
@ -538,15 +397,6 @@ class TreeBasedSelection(MLProcess):
|
|||
self.feats.append(self.label_col)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame contain label_col.
|
||||
"""
|
||||
new_df = df[self.feats]
|
||||
return new_df
|
||||
|
||||
|
|
@ -571,12 +421,6 @@ class VarianceBasedSelection(MLProcess):
|
|||
self.selector = VarianceThreshold(threshold=self.threshold)
|
||||
|
||||
def fit(self, df: pd.DataFrame):
|
||||
"""
|
||||
Fit the VarianceBasedSelection model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
"""
|
||||
num_cols = df.select_dtypes(include=np.number).columns.tolist()
|
||||
cols = [f for f in num_cols if f not in [self.label_col]]
|
||||
|
||||
|
|
@ -585,14 +429,5 @@ class VarianceBasedSelection(MLProcess):
|
|||
self.feats.append(self.label_col)
|
||||
|
||||
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input DataFrame with the fitted model.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The input DataFrame.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame contain label_col.
|
||||
"""
|
||||
new_df = df[self.feats]
|
||||
return new_df
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ def convert_code_to_tool_schema(obj, include: list[str] = []):
|
|||
for name, method in inspect.getmembers(obj, inspect.isfunction):
|
||||
if include and name not in include:
|
||||
continue
|
||||
method_doc = inspect.getdoc(method)
|
||||
# method_doc = inspect.getdoc(method)
|
||||
method_doc = get_class_method_docstring(obj, name)
|
||||
if method_doc:
|
||||
schema["methods"][name] = docstring_to_schema(method_doc)
|
||||
|
||||
|
|
@ -22,8 +23,6 @@ def convert_code_to_tool_schema(obj, include: list[str] = []):
|
|||
**docstring_to_schema(docstring),
|
||||
}
|
||||
|
||||
schema = {obj.__name__: schema}
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
|
|
@ -70,3 +69,13 @@ def docstring_to_schema(docstring: str):
|
|||
schema["returns"] = [{"type": ret[0], "description": remove_spaces(ret[1])} for ret in returns]
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def get_class_method_docstring(cls, method_name):
|
||||
"""Retrieve a method's docstring, searching the class hierarchy if necessary."""
|
||||
for base_class in cls.__mro__:
|
||||
if method_name in base_class.__dict__:
|
||||
method = base_class.__dict__[method_name]
|
||||
if method.__doc__:
|
||||
return method.__doc__
|
||||
return None # No docstring found in the class hierarchy
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ class ToolRegistry(BaseModel):
|
|||
tool_type="other",
|
||||
tool_source_object=None,
|
||||
include_functions=[],
|
||||
make_schema_if_not_exists=True,
|
||||
verbose=False,
|
||||
):
|
||||
if self.has_tool(tool_name):
|
||||
|
|
@ -57,19 +56,11 @@ class ToolRegistry(BaseModel):
|
|||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
|
||||
|
||||
if not os.path.exists(schema_path):
|
||||
if make_schema_if_not_exists:
|
||||
logger.warning(f"no schema found, will make schema at {schema_path}")
|
||||
schema_dict = make_schema(tool_source_object, include_functions, schema_path)
|
||||
else:
|
||||
logger.warning(f"no schema found at assumed schema_path {schema_path}, skip registering {tool_name}")
|
||||
return
|
||||
else:
|
||||
with open(schema_path, "r", encoding="utf-8") as f:
|
||||
schema_dict = yaml.safe_load(f)
|
||||
if not schema_dict:
|
||||
schemas = make_schema(tool_source_object, include_functions, schema_path)
|
||||
|
||||
if not schemas:
|
||||
return
|
||||
schemas = schema_dict.get(tool_name) or list(schema_dict.values())[0]
|
||||
|
||||
schemas["tool_path"] = tool_path # corresponding code file path of the tool
|
||||
try:
|
||||
ToolSchema(**schemas) # validation
|
||||
|
|
@ -78,11 +69,13 @@ class ToolRegistry(BaseModel):
|
|||
# logger.warning(
|
||||
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
|
||||
# )
|
||||
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
|
||||
self.tools[tool_name] = tool
|
||||
self.tools_by_types[tool_type][tool_name] = tool
|
||||
if verbose:
|
||||
logger.info(f"{tool_name} registered")
|
||||
logger.info(f"schema made at {str(schema_path)}, can be used for checking")
|
||||
|
||||
def has_tool(self, key: str) -> Tool:
|
||||
return key in self.tools
|
||||
|
|
@ -107,12 +100,10 @@ class ToolRegistry(BaseModel):
|
|||
TOOL_REGISTRY = ToolRegistry(tool_types=ToolTypes)
|
||||
|
||||
|
||||
def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: str = "", **kwargs):
|
||||
def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
|
||||
"""register a tool to registry"""
|
||||
|
||||
def decorator(cls, tool_name=tool_name):
|
||||
tool_name = tool_name or cls.__name__
|
||||
|
||||
def decorator(cls):
|
||||
# Get the file path where the function / class is defined and the source code
|
||||
file_path = inspect.getfile(cls)
|
||||
if "metagpt" in file_path:
|
||||
|
|
@ -120,7 +111,7 @@ def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: st
|
|||
source_code = inspect.getsource(cls)
|
||||
|
||||
TOOL_REGISTRY.register_tool(
|
||||
tool_name=tool_name,
|
||||
tool_name=cls.__name__,
|
||||
tool_path=file_path,
|
||||
schema_path=schema_path,
|
||||
tool_code=source_code,
|
||||
|
|
@ -142,7 +133,6 @@ def make_schema(tool_source_object, include, path):
|
|||
# import json
|
||||
# with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f:
|
||||
# json.dump(schema, f, ensure_ascii=False, indent=4)
|
||||
logger.info(f"schema made at {path}")
|
||||
except Exception as e:
|
||||
schema = {}
|
||||
logger.error(f"Fail to make schema: {e}")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
def remove_spaces(text):
|
||||
return re.sub(r"\s+", " ", text)
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
class DocstringParser(BaseModel):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ def test_docstring_to_schema():
|
|||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
expected = {
|
||||
"description": " Some test desc. ",
|
||||
"description": "Some test desc.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
|
|
@ -97,47 +97,45 @@ def dummy_fn(df: pd.DataFrame) -> dict:
|
|||
|
||||
def test_convert_code_to_tool_schema_class():
|
||||
expected = {
|
||||
"DummyClass": {
|
||||
"type": "class",
|
||||
"description": "Completing missing values with simple strategies.",
|
||||
"methods": {
|
||||
"__init__": {
|
||||
"description": "Initialize self. ",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
"type": "class",
|
||||
"description": "Completing missing values with simple strategies.",
|
||||
"methods": {
|
||||
"__init__": {
|
||||
"description": "Initialize self.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
},
|
||||
"fit": {
|
||||
"description": "Fit the FillMissingValue model. ",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
},
|
||||
"transform": {
|
||||
"description": "Transform the input DataFrame with the fitted model. ",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
"required": ["features"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"fit": {
|
||||
"description": "Fit the FillMissingValue model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
},
|
||||
"transform": {
|
||||
"description": "Transform the input DataFrame with the fitted model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
},
|
||||
},
|
||||
}
|
||||
schema = convert_code_to_tool_schema(DummyClass)
|
||||
assert schema == expected
|
||||
|
|
@ -145,14 +143,12 @@ def test_convert_code_to_tool_schema_class():
|
|||
|
||||
def test_convert_code_to_tool_schema_function():
|
||||
expected = {
|
||||
"dummy_fn": {
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
}
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
}
|
||||
schema = convert_code_to_tool_schema(dummy_fn)
|
||||
assert schema == expected
|
||||
|
|
|
|||
|
|
@ -14,18 +14,6 @@ def tool_registry_full():
|
|||
return ToolRegistry(tool_types=ToolTypes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def schema_yaml(mocker):
|
||||
mock_yaml_content = """
|
||||
tool_name:
|
||||
key1: value1
|
||||
key2: value2
|
||||
"""
|
||||
mocker.patch("os.path.exists", return_value=True)
|
||||
mocker.patch("builtins.open", mocker.mock_open(read_data=mock_yaml_content))
|
||||
return mocker
|
||||
|
||||
|
||||
# Test Initialization
|
||||
def test_initialization(tool_registry):
|
||||
assert isinstance(tool_registry, ToolRegistry)
|
||||
|
|
@ -42,33 +30,46 @@ def test_initialize_with_tool_types(tool_registry_full):
|
|||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
|
||||
|
||||
# Test Tool Registration
|
||||
def test_register_tool(tool_registry, schema_yaml):
|
||||
tool_registry.register_tool("TestTool", "/path/to/tool")
|
||||
assert "TestTool" in tool_registry.tools
|
||||
class TestClassTool:
|
||||
"""test class"""
|
||||
|
||||
def test_class_fn(self):
|
||||
"""test class fn"""
|
||||
pass
|
||||
|
||||
|
||||
# Test Tool Registration with Non-existing Schema
|
||||
def test_register_tool_no_schema(tool_registry, mocker):
|
||||
mocker.patch("os.path.exists", return_value=False)
|
||||
tool_registry.register_tool("TestTool", "/path/to/tool")
|
||||
assert "TestTool" not in tool_registry.tools
|
||||
def test_fn():
|
||||
"""test function"""
|
||||
pass
|
||||
|
||||
|
||||
# Test Tool Registration Class
|
||||
def test_register_tool_class(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
assert "TestClassTool" in tool_registry.tools
|
||||
|
||||
|
||||
# Test Tool Registration Function
|
||||
def test_register_tool_fn(tool_registry):
|
||||
tool_registry.register_tool("test_fn", "/path/to/tool", tool_source_object=test_fn)
|
||||
assert "test_fn" in tool_registry.tools
|
||||
|
||||
|
||||
# Test Tool Existence Checks
|
||||
def test_has_tool(tool_registry, schema_yaml):
|
||||
tool_registry.register_tool("TestTool", "/path/to/tool")
|
||||
assert tool_registry.has_tool("TestTool")
|
||||
def test_has_tool(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
assert tool_registry.has_tool("TestClassTool")
|
||||
assert not tool_registry.has_tool("NonexistentTool")
|
||||
|
||||
|
||||
# Test Tool Retrieval
|
||||
def test_get_tool(tool_registry, schema_yaml):
|
||||
tool_registry.register_tool("TestTool", "/path/to/tool")
|
||||
tool = tool_registry.get_tool("TestTool")
|
||||
def test_get_tool(tool_registry):
|
||||
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
|
||||
tool = tool_registry.get_tool("TestClassTool")
|
||||
assert tool is not None
|
||||
assert tool.name == "TestTool"
|
||||
assert tool.name == "TestClassTool"
|
||||
assert tool.path == "/path/to/tool"
|
||||
assert "description" in tool.schemas
|
||||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
|
|
@ -83,12 +84,12 @@ def test_get_tool_type(tool_registry_full):
|
|||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry, schema_yaml):
|
||||
def test_get_tools_by_type(tool_registry):
|
||||
tool_type_name = "TestType"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name)
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.utils.save_code import DATA_PATH, save_code_file
|
|||
def test_save_code_file_python():
|
||||
save_code_file("example", "print('Hello, World!')")
|
||||
file_path = DATA_PATH / "output" / "example" / "code.py"
|
||||
assert file_path.exists, f"File does not exist: {file_path}"
|
||||
assert file_path.exists(), f"File does not exist: {file_path}"
|
||||
content = file_path.read_text()
|
||||
assert "print('Hello, World!')" in content, "File content does not match"
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ async def test_save_code_file_notebook():
|
|||
# Save as a Notebook file
|
||||
save_code_file("example_nb", executor.nb, file_format="ipynb")
|
||||
file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb"
|
||||
assert file_path.exists, f"Notebook file does not exist: {file_path}"
|
||||
assert file_path.exists(), f"Notebook file does not exist: {file_path}"
|
||||
|
||||
# Additional checks specific to notebook format
|
||||
notebook = nbformat.read(file_path, as_version=4)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue