rm redundant function and docstring in libs

This commit is contained in:
yzlin 2024-02-04 20:25:49 +08:00
parent b7d0379fae
commit 321a4c0d75
9 changed files with 176 additions and 508 deletions

View file

@ -7,11 +7,10 @@
"""
from enum import Enum
from metagpt.tools import tool_types # this registers all tool types
from metagpt.tools import libs # this registers all tools
from metagpt.tools.tool_registry import TOOL_REGISTRY
_ = tool_types, libs, TOOL_REGISTRY # Avoid pre-commit error
_ = libs, TOOL_REGISTRY # Avoid pre-commit error
class SearchEngineType(Enum):

View file

@ -19,14 +19,29 @@ from metagpt.tools.tool_types import ToolTypes
TOOL_TYPE = ToolTypes.DATA_PREPROCESS.type_name
class MLProcess(object):
def fit(self, df):
class MLProcess:
def fit(self, df: pd.DataFrame):
"""
Fit a model to be used in subsequent transform.
Args:
df (pd.DataFrame): The input DataFrame.
"""
raise NotImplementedError
def transform(self, df):
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
raise NotImplementedError
def fit_transform(self, df) -> pd.DataFrame:
def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Fit and transform the input DataFrame.
@ -40,6 +55,49 @@ class MLProcess(object):
return self.transform(df)
class DataPreprocessTool(MLProcess):
"""
Completing a data preprocessing operation.
"""
def __init__(self, features: list):
"""
Initialize self.
Args:
features (list): Columns to be processed.
"""
self.features = features
self.model = None # to be filled by specific subclass Tool
def fit(self, df: pd.DataFrame):
"""
Fit a model to be used in subsequent transform.
Args:
df (pd.DataFrame): The input DataFrame.
"""
if len(self.features) == 0:
return
self.model.fit(df[self.features])
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
if len(self.features) == 0:
return df
new_df = df.copy()
new_df[self.features] = self.model.transform(new_df[self.features])
return new_df
@register_tool(tool_type=TOOL_TYPE)
class FillMissingValue(MLProcess):
"""
@ -58,282 +116,77 @@ class FillMissingValue(MLProcess):
Defaults to None.
"""
self.features = features
self.strategy = strategy
self.fill_value = fill_value
self.si = None
def fit(self, df: pd.DataFrame):
"""
Fit the FillMissingValue model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
if len(self.features) == 0:
return
self.si = SimpleImputer(strategy=self.strategy, fill_value=self.fill_value)
self.si.fit(df[self.features])
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
if len(self.features) == 0:
return df
new_df = df.copy()
new_df[self.features] = self.si.transform(new_df[self.features])
return new_df
self.model = SimpleImputer(strategy=strategy, fill_value=fill_value)
@register_tool(tool_type=TOOL_TYPE)
class MinMaxScale(MLProcess):
class MinMaxScale(DataPreprocessTool):
"""
Transform features by scaling each feature to a range, which is (0, 1).
"""
def __init__(self, features: list):
"""
Initialize self.
Args:
features (list): Columns to be processed.
"""
self.features = features
self.mms = None
def fit(self, df: pd.DataFrame):
"""
Fit the MinMaxScale model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
self.mms = MinMaxScaler()
self.mms.fit(df[self.features])
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
new_df[self.features] = self.mms.transform(new_df[self.features])
return new_df
self.model = MinMaxScaler()
@register_tool(tool_type=TOOL_TYPE)
class StandardScale(MLProcess):
class StandardScale(DataPreprocessTool):
"""
Standardize features by removing the mean and scaling to unit variance.
"""
def __init__(self, features: list):
"""
Initialize self.
Args:
features (list): Columns to be processed.
"""
self.features = features
self.ss = None
def fit(self, df: pd.DataFrame):
"""
Fit the StandardScale model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
self.ss = StandardScaler()
self.ss.fit(df[self.features])
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
new_df[self.features] = self.ss.transform(new_df[self.features])
return new_df
self.model = StandardScaler()
@register_tool(tool_type=TOOL_TYPE)
class MaxAbsScale(MLProcess):
class MaxAbsScale(DataPreprocessTool):
"""
Scale each feature by its maximum absolute value.
"""
def __init__(self, features: list):
"""
Initialize self.
Args:
features (list): Columns to be processed.
"""
self.features = features
self.mas = None
def fit(self, df: pd.DataFrame):
"""
Fit the MaxAbsScale model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
self.mas = MaxAbsScaler()
self.mas.fit(df[self.features])
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
new_df[self.features] = self.mas.transform(new_df[self.features])
return new_df
self.model = MaxAbsScaler()
@register_tool(tool_type=TOOL_TYPE)
class RobustScale(MLProcess):
class RobustScale(DataPreprocessTool):
"""
Apply the RobustScaler to scale features using statistics that are robust to outliers.
"""
def __init__(self, features: list):
"""
Initialize the RobustScale instance with feature names.
Args:
features (list): List of feature names to be scaled.
"""
self.features = features
self.rs = None
def fit(self, df: pd.DataFrame):
"""
Compute the median and IQR for scaling.
Args:
df (pd.DataFrame): Dataframe containing the features.
"""
self.rs = RobustScaler()
self.rs.fit(df[self.features])
def transform(self, df: pd.DataFrame):
"""
Scale features using the previously computed median and IQR.
Args:
df (pd.DataFrame): Dataframe containing the features to be scaled.
Returns:
pd.DataFrame: A new dataframe with scaled features.
"""
new_df = df.copy()
new_df[self.features] = self.rs.transform(new_df[self.features])
return new_df
self.model = RobustScaler()
@register_tool(tool_type=TOOL_TYPE)
class OrdinalEncode(MLProcess):
class OrdinalEncode(DataPreprocessTool):
"""
Encode categorical features as ordinal integers.
"""
def __init__(self, features: list):
"""
Initialize the OrdinalEncode instance with feature names.
Args:
features (list): List of categorical feature names to be encoded.
"""
self.features = features
self.oe = None
def fit(self, df: pd.DataFrame):
"""
Learn the ordinal encodings for the features.
Args:
df (pd.DataFrame): Dataframe containing the categorical features.
"""
self.oe = OrdinalEncoder()
self.oe.fit(df[self.features])
def transform(self, df: pd.DataFrame):
"""
Convert the categorical features to ordinal integers.
Args:
df (pd.DataFrame): Dataframe containing the categorical features to be encoded.
Returns:
pd.DataFrame: A new dataframe with the encoded features.
"""
new_df = df.copy()
new_df[self.features] = self.oe.transform(new_df[self.features])
return new_df
self.model = OrdinalEncoder()
@register_tool(tool_type=TOOL_TYPE)
class OneHotEncode(MLProcess):
class OneHotEncode(DataPreprocessTool):
"""
Apply one-hot encoding to specified categorical columns, the original columns will be dropped.
"""
def __init__(self, features: list):
"""
Initialize self.
Args:
features (list): Categorical columns to be one-hot encoded and dropped.
"""
self.features = features
self.ohe = None
def fit(self, df: pd.DataFrame):
"""
Fit the OneHotEncoding model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
self.ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)
self.ohe.fit(df[self.features])
self.model = OneHotEncoder(handle_unknown="ignore", sparse=False)
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
ts_data = self.ohe.transform(df[self.features])
new_columns = self.ohe.get_feature_names_out(self.features)
ts_data = self.model.transform(df[self.features])
new_columns = self.model.get_feature_names_out(self.features)
ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index)
new_df = df.drop(self.features, axis=1)
new_df = pd.concat([new_df, ts_data], axis=1)
@ -341,7 +194,7 @@ class OneHotEncode(MLProcess):
@register_tool(tool_type=TOOL_TYPE)
class LabelEncode(MLProcess):
class LabelEncode(DataPreprocessTool):
"""
Apply label encoding to specified categorical columns in-place.
"""
@ -357,12 +210,6 @@ class LabelEncode(MLProcess):
self.le_encoders = []
def fit(self, df: pd.DataFrame):
"""
Fit the LabelEncode model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
if len(self.features) == 0:
return
for col in self.features:
@ -370,15 +217,6 @@ class LabelEncode(MLProcess):
self.le_encoders.append(le)
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
if len(self.features) == 0:
return df
new_df = df.copy()

View file

@ -45,12 +45,6 @@ class PolynomialExpansion(MLProcess):
self.poly = PolynomialFeatures(degree=degree, include_bias=False)
def fit(self, df: pd.DataFrame):
"""
Fit the PolynomialExpansion model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
if len(self.cols) == 0:
return
if len(self.cols) > 10:
@ -61,15 +55,6 @@ class PolynomialExpansion(MLProcess):
self.poly.fit(df[self.cols].fillna(0))
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame without duplicated columns.
"""
if len(self.cols) == 0:
return df
ts_data = self.poly.transform(df[self.cols].fillna(0))
@ -97,24 +82,9 @@ class CatCount(MLProcess):
self.encoder_dict = None
def fit(self, df: pd.DataFrame):
"""
Fit the CatCount model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
self.encoder_dict = df[self.col].value_counts().to_dict()
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
new_df[f"{self.col}_cnt"] = new_df[self.col].map(self.encoder_dict)
return new_df
@ -139,24 +109,9 @@ class TargetMeanEncoder(MLProcess):
self.encoder_dict = None
def fit(self, df: pd.DataFrame):
"""
Fit the TargetMeanEncoder model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
self.encoder_dict = df.groupby(self.col)[self.label].mean().to_dict()
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
new_df[f"{self.col}_target_mean"] = new_df[self.col].map(self.encoder_dict)
return new_df
@ -185,12 +140,6 @@ class KFoldTargetMeanEncoder(MLProcess):
self.encoder_dict = None
def fit(self, df: pd.DataFrame):
"""
Fit the KFoldTargetMeanEncoder model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
tmp = df.copy()
kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state)
@ -203,15 +152,6 @@ class KFoldTargetMeanEncoder(MLProcess):
self.encoder_dict = tmp.groupby(self.col)[col_name].mean().to_dict()
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
new_df[f"{self.col}_kf_target_mean"] = new_df[self.col].map(self.encoder_dict)
return new_df
@ -255,12 +195,6 @@ class CatCross(MLProcess):
return new_col, comb_map
def fit(self, df: pd.DataFrame):
"""
Fit the CatCross model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
for col in self.cols:
if df[col].nunique() > self.max_cat_num:
self.cols.remove(col)
@ -269,15 +203,6 @@ class CatCross(MLProcess):
self.combs_map = dict(res)
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
for comb in self.combs:
new_col = f"{comb[0]}_{comb[1]}"
@ -310,12 +235,6 @@ class GroupStat(MLProcess):
self.group_df = None
def fit(self, df: pd.DataFrame):
"""
Fit the GroupStat model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
group_df = df.groupby(self.group_col)[self.agg_col].agg(self.agg_funcs).reset_index()
group_df.columns = [self.group_col] + [
f"{self.agg_col}_{agg_func}_by_{self.group_col}" for agg_func in self.agg_funcs
@ -323,15 +242,6 @@ class GroupStat(MLProcess):
self.group_df = group_df
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.merge(self.group_df, on=self.group_col, how="left")
return new_df
@ -355,25 +265,10 @@ class SplitBins(MLProcess):
self.encoder = None
def fit(self, df: pd.DataFrame):
"""
Fit the SplitBins model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
self.encoder = KBinsDiscretizer(strategy=self.strategy, encode="ordinal")
self.encoder.fit(df[self.cols].fillna(0))
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
new_df = df.copy()
new_df[self.cols] = self.encoder.transform(new_df[self.cols].fillna(0))
return new_df
@ -397,24 +292,9 @@ class ExtractTimeComps(MLProcess):
self.time_comps = time_comps
def fit(self, df: pd.DataFrame):
"""
Fit the ExtractTimeComps model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
pass
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame.
"""
time_s = pd.to_datetime(df[self.time_col], errors="coerce")
time_comps_df = pd.DataFrame()
@ -445,12 +325,6 @@ class GeneralSelection(MLProcess):
self.feats = []
def fit(self, df: pd.DataFrame):
"""
Fit the GeneralSelection model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
feats = [f for f in df.columns if f != self.label_col]
for col in df.columns:
if df[col].isnull().sum() / df.shape[0] == 1:
@ -468,15 +342,6 @@ class GeneralSelection(MLProcess):
self.feats = feats
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame contain label_col.
"""
new_df = df[self.feats + [self.label_col]]
return new_df
@ -501,12 +366,6 @@ class TreeBasedSelection(MLProcess):
self.feats = None
def fit(self, df: pd.DataFrame):
"""
Fit the TreeBasedSelection model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
params = {
"boosting_type": "gbdt",
"objective": "binary",
@ -538,15 +397,6 @@ class TreeBasedSelection(MLProcess):
self.feats.append(self.label_col)
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame contain label_col.
"""
new_df = df[self.feats]
return new_df
@ -571,12 +421,6 @@ class VarianceBasedSelection(MLProcess):
self.selector = VarianceThreshold(threshold=self.threshold)
def fit(self, df: pd.DataFrame):
"""
Fit the VarianceBasedSelection model.
Args:
df (pd.DataFrame): The input DataFrame.
"""
num_cols = df.select_dtypes(include=np.number).columns.tolist()
cols = [f for f in num_cols if f not in [self.label_col]]
@ -585,14 +429,5 @@ class VarianceBasedSelection(MLProcess):
self.feats.append(self.label_col)
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Transform the input DataFrame with the fitted model.
Args:
df (pd.DataFrame): The input DataFrame.
Returns:
pd.DataFrame: The transformed DataFrame contain label_col.
"""
new_df = df[self.feats]
return new_df

View file

@ -12,7 +12,8 @@ def convert_code_to_tool_schema(obj, include: list[str] = []):
for name, method in inspect.getmembers(obj, inspect.isfunction):
if include and name not in include:
continue
method_doc = inspect.getdoc(method)
# method_doc = inspect.getdoc(method)
method_doc = get_class_method_docstring(obj, name)
if method_doc:
schema["methods"][name] = docstring_to_schema(method_doc)
@ -22,8 +23,6 @@ def convert_code_to_tool_schema(obj, include: list[str] = []):
**docstring_to_schema(docstring),
}
schema = {obj.__name__: schema}
return schema
@ -70,3 +69,13 @@ def docstring_to_schema(docstring: str):
schema["returns"] = [{"type": ret[0], "description": remove_spaces(ret[1])} for ret in returns]
return schema
def get_class_method_docstring(cls, method_name):
"""Retrieve a method's docstring, searching the class hierarchy if necessary."""
for base_class in cls.__mro__:
if method_name in base_class.__dict__:
method = base_class.__dict__[method_name]
if method.__doc__:
return method.__doc__
return None # No docstring found in the class hierarchy

View file

@ -39,7 +39,6 @@ class ToolRegistry(BaseModel):
tool_type="other",
tool_source_object=None,
include_functions=[],
make_schema_if_not_exists=True,
verbose=False,
):
if self.has_tool(tool_name):
@ -57,19 +56,11 @@ class ToolRegistry(BaseModel):
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
if not os.path.exists(schema_path):
if make_schema_if_not_exists:
logger.warning(f"no schema found, will make schema at {schema_path}")
schema_dict = make_schema(tool_source_object, include_functions, schema_path)
else:
logger.warning(f"no schema found at assumed schema_path {schema_path}, skip registering {tool_name}")
return
else:
with open(schema_path, "r", encoding="utf-8") as f:
schema_dict = yaml.safe_load(f)
if not schema_dict:
schemas = make_schema(tool_source_object, include_functions, schema_path)
if not schemas:
return
schemas = schema_dict.get(tool_name) or list(schema_dict.values())[0]
schemas["tool_path"] = tool_path # corresponding code file path of the tool
try:
ToolSchema(**schemas) # validation
@ -78,11 +69,13 @@ class ToolRegistry(BaseModel):
# logger.warning(
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
# )
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
self.tools[tool_name] = tool
self.tools_by_types[tool_type][tool_name] = tool
if verbose:
logger.info(f"{tool_name} registered")
logger.info(f"schema made at {str(schema_path)}, can be used for checking")
def has_tool(self, key: str) -> Tool:
return key in self.tools
@ -107,12 +100,10 @@ class ToolRegistry(BaseModel):
TOOL_REGISTRY = ToolRegistry(tool_types=ToolTypes)
def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: str = "", **kwargs):
def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
"""register a tool to registry"""
def decorator(cls, tool_name=tool_name):
tool_name = tool_name or cls.__name__
def decorator(cls):
# Get the file path where the function / class is defined and the source code
file_path = inspect.getfile(cls)
if "metagpt" in file_path:
@ -120,7 +111,7 @@ def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: st
source_code = inspect.getsource(cls)
TOOL_REGISTRY.register_tool(
tool_name=tool_name,
tool_name=cls.__name__,
tool_path=file_path,
schema_path=schema_path,
tool_code=source_code,
@ -142,7 +133,6 @@ def make_schema(tool_source_object, include, path):
# import json
# with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f:
# json.dump(schema, f, ensure_ascii=False, indent=4)
logger.info(f"schema made at {path}")
except Exception as e:
schema = {}
logger.error(f"Fail to make schema: {e}")

View file

@ -5,7 +5,7 @@ from pydantic import BaseModel
def remove_spaces(text):
return re.sub(r"\s+", " ", text)
return re.sub(r"\s+", " ", text).strip()
class DocstringParser(BaseModel):

View file

@ -17,7 +17,7 @@ def test_docstring_to_schema():
pd.DataFrame: The transformed DataFrame.
"""
expected = {
"description": " Some test desc. ",
"description": "Some test desc.",
"parameters": {
"properties": {
"features": {"type": "list", "description": "Columns to be processed."},
@ -97,47 +97,45 @@ def dummy_fn(df: pd.DataFrame) -> dict:
def test_convert_code_to_tool_schema_class():
expected = {
"DummyClass": {
"type": "class",
"description": "Completing missing values with simple strategies.",
"methods": {
"__init__": {
"description": "Initialize self. ",
"parameters": {
"properties": {
"features": {"type": "list", "description": "Columns to be processed."},
"strategy": {
"type": "str",
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
"default": "'mean'",
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
},
"fill_value": {
"type": "int",
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
"default": "None",
},
"type": "class",
"description": "Completing missing values with simple strategies.",
"methods": {
"__init__": {
"description": "Initialize self.",
"parameters": {
"properties": {
"features": {"type": "list", "description": "Columns to be processed."},
"strategy": {
"type": "str",
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
"default": "'mean'",
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
},
"fill_value": {
"type": "int",
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
"default": "None",
},
"required": ["features"],
},
},
"fit": {
"description": "Fit the FillMissingValue model. ",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
},
"transform": {
"description": "Transform the input DataFrame with the fitted model. ",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
"required": ["features"],
},
},
}
"fit": {
"description": "Fit the FillMissingValue model.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
},
"transform": {
"description": "Transform the input DataFrame with the fitted model.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
"required": ["df"],
},
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
},
},
}
schema = convert_code_to_tool_schema(DummyClass)
assert schema == expected
@ -145,14 +143,12 @@ def test_convert_code_to_tool_schema_class():
def test_convert_code_to_tool_schema_function():
expected = {
"dummy_fn": {
"type": "function",
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
"required": ["df"],
},
}
"type": "function",
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
"parameters": {
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
"required": ["df"],
},
}
schema = convert_code_to_tool_schema(dummy_fn)
assert schema == expected

View file

@ -14,18 +14,6 @@ def tool_registry_full():
return ToolRegistry(tool_types=ToolTypes)
@pytest.fixture
def schema_yaml(mocker):
mock_yaml_content = """
tool_name:
key1: value1
key2: value2
"""
mocker.patch("os.path.exists", return_value=True)
mocker.patch("builtins.open", mocker.mock_open(read_data=mock_yaml_content))
return mocker
# Test Initialization
def test_initialization(tool_registry):
assert isinstance(tool_registry, ToolRegistry)
@ -42,33 +30,46 @@ def test_initialize_with_tool_types(tool_registry_full):
assert "data_preprocess" in tool_registry_full.tool_types
# Test Tool Registration
def test_register_tool(tool_registry, schema_yaml):
tool_registry.register_tool("TestTool", "/path/to/tool")
assert "TestTool" in tool_registry.tools
class TestClassTool:
"""test class"""
def test_class_fn(self):
"""test class fn"""
pass
# Test Tool Registration with Non-existing Schema
def test_register_tool_no_schema(tool_registry, mocker):
mocker.patch("os.path.exists", return_value=False)
tool_registry.register_tool("TestTool", "/path/to/tool")
assert "TestTool" not in tool_registry.tools
def test_fn():
"""test function"""
pass
# Test Tool Registration Class
def test_register_tool_class(tool_registry):
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
assert "TestClassTool" in tool_registry.tools
# Test Tool Registration Function
def test_register_tool_fn(tool_registry):
tool_registry.register_tool("test_fn", "/path/to/tool", tool_source_object=test_fn)
assert "test_fn" in tool_registry.tools
# Test Tool Existence Checks
def test_has_tool(tool_registry, schema_yaml):
tool_registry.register_tool("TestTool", "/path/to/tool")
assert tool_registry.has_tool("TestTool")
def test_has_tool(tool_registry):
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
assert tool_registry.has_tool("TestClassTool")
assert not tool_registry.has_tool("NonexistentTool")
# Test Tool Retrieval
def test_get_tool(tool_registry, schema_yaml):
tool_registry.register_tool("TestTool", "/path/to/tool")
tool = tool_registry.get_tool("TestTool")
def test_get_tool(tool_registry):
tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool)
tool = tool_registry.get_tool("TestClassTool")
assert tool is not None
assert tool.name == "TestTool"
assert tool.name == "TestClassTool"
assert tool.path == "/path/to/tool"
assert "description" in tool.schemas
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
@ -83,12 +84,12 @@ def test_get_tool_type(tool_registry_full):
assert retrieved_type.name == "data_preprocess"
def test_get_tools_by_type(tool_registry, schema_yaml):
def test_get_tools_by_type(tool_registry):
tool_type_name = "TestType"
tool_name = "TestTool"
tool_path = "/path/to/tool"
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name)
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
assert tools_by_type is not None

View file

@ -14,7 +14,7 @@ from metagpt.utils.save_code import DATA_PATH, save_code_file
def test_save_code_file_python():
save_code_file("example", "print('Hello, World!')")
file_path = DATA_PATH / "output" / "example" / "code.py"
assert file_path.exists, f"File does not exist: {file_path}"
assert file_path.exists(), f"File does not exist: {file_path}"
content = file_path.read_text()
assert "print('Hello, World!')" in content, "File content does not match"
@ -35,7 +35,7 @@ async def test_save_code_file_notebook():
# Save as a Notebook file
save_code_file("example_nb", executor.nb, file_format="ipynb")
file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb"
assert file_path.exists, f"Notebook file does not exist: {file_path}"
assert file_path.exists(), f"Notebook file does not exist: {file_path}"
# Additional checks specific to notebook format
notebook = nbformat.read(file_path, as_version=4)