diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index bb87f1b62..c1f604df9 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -7,11 +7,10 @@ """ from enum import Enum -from metagpt.tools import tool_types # this registers all tool types from metagpt.tools import libs # this registers all tools from metagpt.tools.tool_registry import TOOL_REGISTRY -_ = tool_types, libs, TOOL_REGISTRY # Avoid pre-commit error +_ = libs, TOOL_REGISTRY # Avoid pre-commit error class SearchEngineType(Enum): diff --git a/metagpt/tools/libs/data_preprocess.py b/metagpt/tools/libs/data_preprocess.py index 307a6bc5b..9c571ad6b 100644 --- a/metagpt/tools/libs/data_preprocess.py +++ b/metagpt/tools/libs/data_preprocess.py @@ -19,14 +19,29 @@ from metagpt.tools.tool_types import ToolTypes TOOL_TYPE = ToolTypes.DATA_PREPROCESS.type_name -class MLProcess(object): - def fit(self, df): +class MLProcess: + def fit(self, df: pd.DataFrame): + """ + Fit a model to be used in subsequent transform. + + Args: + df (pd.DataFrame): The input DataFrame. + """ raise NotImplementedError - def transform(self, df): + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Transform the input DataFrame with the fitted model. + + Args: + df (pd.DataFrame): The input DataFrame. + + Returns: + pd.DataFrame: The transformed DataFrame. + """ raise NotImplementedError - def fit_transform(self, df) -> pd.DataFrame: + def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame: """ Fit and transform the input DataFrame. @@ -40,6 +55,49 @@ class MLProcess(object): return self.transform(df) +class DataPreprocessTool(MLProcess): + """ + Completing a data preprocessing operation. + """ + + def __init__(self, features: list): + """ + Initialize self. + + Args: + features (list): Columns to be processed. + """ + self.features = features + self.model = None # to be filled by specific subclass Tool + + def fit(self, df: pd.DataFrame): + """ + Fit a model to be used in subsequent transform. + + Args: + df (pd.DataFrame): The input DataFrame. + """ + if len(self.features) == 0: + return + self.model.fit(df[self.features]) + + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Transform the input DataFrame with the fitted model. + + Args: + df (pd.DataFrame): The input DataFrame. + + Returns: + pd.DataFrame: The transformed DataFrame. + """ + if len(self.features) == 0: + return df + new_df = df.copy() + new_df[self.features] = self.model.transform(new_df[self.features]) + return new_df + + @register_tool(tool_type=TOOL_TYPE) class FillMissingValue(MLProcess): """ @@ -58,282 +116,77 @@ class FillMissingValue(MLProcess): Defaults to None. """ self.features = features - self.strategy = strategy - self.fill_value = fill_value - self.si = None - - def fit(self, df: pd.DataFrame): - """ - Fit the FillMissingValue model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ - if len(self.features) == 0: - return - self.si = SimpleImputer(strategy=self.strategy, fill_value=self.fill_value) - self.si.fit(df[self.features]) - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ - if len(self.features) == 0: - return df - new_df = df.copy() - new_df[self.features] = self.si.transform(new_df[self.features]) - return new_df + self.model = SimpleImputer(strategy=strategy, fill_value=fill_value) @register_tool(tool_type=TOOL_TYPE) -class MinMaxScale(MLProcess): +class MinMaxScale(DataPreprocessTool): """ Transform features by scaling each feature to a range, which is (0, 1). """ def __init__(self, features: list): - """ - Initialize self. - - Args: - features (list): Columns to be processed. - """ self.features = features - self.mms = None - - def fit(self, df: pd.DataFrame): - """ - Fit the MinMaxScale model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ - self.mms = MinMaxScaler() - self.mms.fit(df[self.features]) - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ - new_df = df.copy() - new_df[self.features] = self.mms.transform(new_df[self.features]) - return new_df + self.model = MinMaxScaler() @register_tool(tool_type=TOOL_TYPE) -class StandardScale(MLProcess): +class StandardScale(DataPreprocessTool): """ Standardize features by removing the mean and scaling to unit variance. """ def __init__(self, features: list): - """ - Initialize self. - - Args: - features (list): Columns to be processed. - """ self.features = features - self.ss = None - - def fit(self, df: pd.DataFrame): - """ - Fit the StandardScale model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ - self.ss = StandardScaler() - self.ss.fit(df[self.features]) - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ - new_df = df.copy() - new_df[self.features] = self.ss.transform(new_df[self.features]) - return new_df + self.model = StandardScaler() @register_tool(tool_type=TOOL_TYPE) -class MaxAbsScale(MLProcess): +class MaxAbsScale(DataPreprocessTool): """ Scale each feature by its maximum absolute value. """ def __init__(self, features: list): - """ - Initialize self. - - Args: - features (list): Columns to be processed. - """ self.features = features - self.mas = None - - def fit(self, df: pd.DataFrame): - """ - Fit the MaxAbsScale model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ - self.mas = MaxAbsScaler() - self.mas.fit(df[self.features]) - - def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ - new_df = df.copy() - new_df[self.features] = self.mas.transform(new_df[self.features]) - return new_df + self.model = MaxAbsScaler() @register_tool(tool_type=TOOL_TYPE) -class RobustScale(MLProcess): +class RobustScale(DataPreprocessTool): """ Apply the RobustScaler to scale features using statistics that are robust to outliers. """ def __init__(self, features: list): - """ - Initialize the RobustScale instance with feature names. - - Args: - features (list): List of feature names to be scaled. - """ self.features = features - self.rs = None - - def fit(self, df: pd.DataFrame): - """ - Compute the median and IQR for scaling. - - Args: - df (pd.DataFrame): Dataframe containing the features. - """ - self.rs = RobustScaler() - self.rs.fit(df[self.features]) - - def transform(self, df: pd.DataFrame): - """ - Scale features using the previously computed median and IQR. - - Args: - df (pd.DataFrame): Dataframe containing the features to be scaled. - - Returns: - pd.DataFrame: A new dataframe with scaled features. - """ - new_df = df.copy() - new_df[self.features] = self.rs.transform(new_df[self.features]) - return new_df + self.model = RobustScaler() @register_tool(tool_type=TOOL_TYPE) -class OrdinalEncode(MLProcess): +class OrdinalEncode(DataPreprocessTool): """ Encode categorical features as ordinal integers. """ def __init__(self, features: list): - """ - Initialize the OrdinalEncode instance with feature names. - - Args: - features (list): List of categorical feature names to be encoded. - """ self.features = features - self.oe = None - - def fit(self, df: pd.DataFrame): - """ - Learn the ordinal encodings for the features. - - Args: - df (pd.DataFrame): Dataframe containing the categorical features. - """ - self.oe = OrdinalEncoder() - self.oe.fit(df[self.features]) - - def transform(self, df: pd.DataFrame): - """ - Convert the categorical features to ordinal integers. - - Args: - df (pd.DataFrame): Dataframe containing the categorical features to be encoded. - - Returns: - pd.DataFrame: A new dataframe with the encoded features. - """ - new_df = df.copy() - new_df[self.features] = self.oe.transform(new_df[self.features]) - return new_df + self.model = OrdinalEncoder() @register_tool(tool_type=TOOL_TYPE) -class OneHotEncode(MLProcess): +class OneHotEncode(DataPreprocessTool): """ Apply one-hot encoding to specified categorical columns, the original columns will be dropped. """ def __init__(self, features: list): - """ - Initialize self. - - Args: - features (list): Categorical columns to be one-hot encoded and dropped. - """ self.features = features - self.ohe = None - - def fit(self, df: pd.DataFrame): - """ - Fit the OneHotEncoding model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ - self.ohe = OneHotEncoder(handle_unknown="ignore", sparse=False) - self.ohe.fit(df[self.features]) + self.model = OneHotEncoder(handle_unknown="ignore", sparse=False) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ - ts_data = self.ohe.transform(df[self.features]) - new_columns = self.ohe.get_feature_names_out(self.features) + ts_data = self.model.transform(df[self.features]) + new_columns = self.model.get_feature_names_out(self.features) ts_data = pd.DataFrame(ts_data, columns=new_columns, index=df.index) new_df = df.drop(self.features, axis=1) new_df = pd.concat([new_df, ts_data], axis=1) @@ -341,7 +194,7 @@ class OneHotEncode(MLProcess): @register_tool(tool_type=TOOL_TYPE) -class LabelEncode(MLProcess): +class LabelEncode(DataPreprocessTool): """ Apply label encoding to specified categorical columns in-place. """ @@ -357,12 +210,6 @@ class LabelEncode(MLProcess): self.le_encoders = [] def fit(self, df: pd.DataFrame): - """ - Fit the LabelEncode model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ if len(self.features) == 0: return for col in self.features: @@ -370,15 +217,6 @@ class LabelEncode(MLProcess): self.le_encoders.append(le) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ if len(self.features) == 0: return df new_df = df.copy() diff --git a/metagpt/tools/libs/feature_engineering.py b/metagpt/tools/libs/feature_engineering.py index 6de5696d4..bbd16b681 100644 --- a/metagpt/tools/libs/feature_engineering.py +++ b/metagpt/tools/libs/feature_engineering.py @@ -45,12 +45,6 @@ class PolynomialExpansion(MLProcess): self.poly = PolynomialFeatures(degree=degree, include_bias=False) def fit(self, df: pd.DataFrame): - """ - Fit the PolynomialExpansion model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ if len(self.cols) == 0: return if len(self.cols) > 10: @@ -61,15 +55,6 @@ class PolynomialExpansion(MLProcess): self.poly.fit(df[self.cols].fillna(0)) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame without duplicated columns. - """ if len(self.cols) == 0: return df ts_data = self.poly.transform(df[self.cols].fillna(0)) @@ -97,24 +82,9 @@ class CatCount(MLProcess): self.encoder_dict = None def fit(self, df: pd.DataFrame): - """ - Fit the CatCount model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ self.encoder_dict = df[self.col].value_counts().to_dict() def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ new_df = df.copy() new_df[f"{self.col}_cnt"] = new_df[self.col].map(self.encoder_dict) return new_df @@ -139,24 +109,9 @@ class TargetMeanEncoder(MLProcess): self.encoder_dict = None def fit(self, df: pd.DataFrame): - """ - Fit the TargetMeanEncoder model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ self.encoder_dict = df.groupby(self.col)[self.label].mean().to_dict() def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ new_df = df.copy() new_df[f"{self.col}_target_mean"] = new_df[self.col].map(self.encoder_dict) return new_df @@ -185,12 +140,6 @@ class KFoldTargetMeanEncoder(MLProcess): self.encoder_dict = None def fit(self, df: pd.DataFrame): - """ - Fit the KFoldTargetMeanEncoder model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ tmp = df.copy() kf = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.random_state) @@ -203,15 +152,6 @@ class KFoldTargetMeanEncoder(MLProcess): self.encoder_dict = tmp.groupby(self.col)[col_name].mean().to_dict() def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ new_df = df.copy() new_df[f"{self.col}_kf_target_mean"] = new_df[self.col].map(self.encoder_dict) return new_df @@ -255,12 +195,6 @@ class CatCross(MLProcess): return new_col, comb_map def fit(self, df: pd.DataFrame): - """ - Fit the CatCross model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ for col in self.cols: if df[col].nunique() > self.max_cat_num: self.cols.remove(col) @@ -269,15 +203,6 @@ class CatCross(MLProcess): self.combs_map = dict(res) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ new_df = df.copy() for comb in self.combs: new_col = f"{comb[0]}_{comb[1]}" @@ -310,12 +235,6 @@ class GroupStat(MLProcess): self.group_df = None def fit(self, df: pd.DataFrame): - """ - Fit the GroupStat model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ group_df = df.groupby(self.group_col)[self.agg_col].agg(self.agg_funcs).reset_index() group_df.columns = [self.group_col] + [ f"{self.agg_col}_{agg_func}_by_{self.group_col}" for agg_func in self.agg_funcs @@ -323,15 +242,6 @@ class GroupStat(MLProcess): self.group_df = group_df def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ new_df = df.merge(self.group_df, on=self.group_col, how="left") return new_df @@ -355,25 +265,10 @@ class SplitBins(MLProcess): self.encoder = None def fit(self, df: pd.DataFrame): - """ - Fit the SplitBins model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ self.encoder = KBinsDiscretizer(strategy=self.strategy, encode="ordinal") self.encoder.fit(df[self.cols].fillna(0)) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ new_df = df.copy() new_df[self.cols] = self.encoder.transform(new_df[self.cols].fillna(0)) return new_df @@ -397,24 +292,9 @@ class ExtractTimeComps(MLProcess): self.time_comps = time_comps def fit(self, df: pd.DataFrame): - """ - Fit the ExtractTimeComps model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ pass def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame. - """ time_s = pd.to_datetime(df[self.time_col], errors="coerce") time_comps_df = pd.DataFrame() @@ -445,12 +325,6 @@ class GeneralSelection(MLProcess): self.feats = [] def fit(self, df: pd.DataFrame): - """ - Fit the GeneralSelection model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ feats = [f for f in df.columns if f != self.label_col] for col in df.columns: if df[col].isnull().sum() / df.shape[0] == 1: @@ -468,15 +342,6 @@ class GeneralSelection(MLProcess): self.feats = feats def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame contain label_col. - """ new_df = df[self.feats + [self.label_col]] return new_df @@ -501,12 +366,6 @@ class TreeBasedSelection(MLProcess): self.feats = None def fit(self, df: pd.DataFrame): - """ - Fit the TreeBasedSelection model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ params = { "boosting_type": "gbdt", "objective": "binary", @@ -538,15 +397,6 @@ class TreeBasedSelection(MLProcess): self.feats.append(self.label_col) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame contain label_col. - """ new_df = df[self.feats] return new_df @@ -571,12 +421,6 @@ class VarianceBasedSelection(MLProcess): self.selector = VarianceThreshold(threshold=self.threshold) def fit(self, df: pd.DataFrame): - """ - Fit the VarianceBasedSelection model. - - Args: - df (pd.DataFrame): The input DataFrame. - """ num_cols = df.select_dtypes(include=np.number).columns.tolist() cols = [f for f in num_cols if f not in [self.label_col]] @@ -585,14 +429,5 @@ class VarianceBasedSelection(MLProcess): self.feats.append(self.label_col) def transform(self, df: pd.DataFrame) -> pd.DataFrame: - """ - Transform the input DataFrame with the fitted model. - - Args: - df (pd.DataFrame): The input DataFrame. - - Returns: - pd.DataFrame: The transformed DataFrame contain label_col. - """ new_df = df[self.feats] return new_df diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py index b8377e67a..417a938e1 100644 --- a/metagpt/tools/tool_convert.py +++ b/metagpt/tools/tool_convert.py @@ -12,7 +12,8 @@ def convert_code_to_tool_schema(obj, include: list[str] = []): for name, method in inspect.getmembers(obj, inspect.isfunction): if include and name not in include: continue - method_doc = inspect.getdoc(method) + # method_doc = inspect.getdoc(method) + method_doc = get_class_method_docstring(obj, name) if method_doc: schema["methods"][name] = docstring_to_schema(method_doc) @@ -22,8 +23,6 @@ def convert_code_to_tool_schema(obj, include: list[str] = []): **docstring_to_schema(docstring), } - schema = {obj.__name__: schema} - return schema @@ -70,3 +69,13 @@ def docstring_to_schema(docstring: str): schema["returns"] = [{"type": ret[0], "description": remove_spaces(ret[1])} for ret in returns] return schema + + +def get_class_method_docstring(cls, method_name): + """Retrieve a method's docstring, searching the class hierarchy if necessary.""" + for base_class in cls.__mro__: + if method_name in base_class.__dict__: + method = base_class.__dict__[method_name] + if method.__doc__: + return method.__doc__ + return None # No docstring found in the class hierarchy diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index 5922e7f69..299d62ca3 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -39,7 +39,6 @@ class ToolRegistry(BaseModel): tool_type="other", tool_source_object=None, include_functions=[], - make_schema_if_not_exists=True, verbose=False, ): if self.has_tool(tool_name): @@ -57,19 +56,11 @@ class ToolRegistry(BaseModel): schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml" - if not os.path.exists(schema_path): - if make_schema_if_not_exists: - logger.warning(f"no schema found, will make schema at {schema_path}") - schema_dict = make_schema(tool_source_object, include_functions, schema_path) - else: - logger.warning(f"no schema found at assumed schema_path {schema_path}, skip registering {tool_name}") - return - else: - with open(schema_path, "r", encoding="utf-8") as f: - schema_dict = yaml.safe_load(f) - if not schema_dict: + schemas = make_schema(tool_source_object, include_functions, schema_path) + + if not schemas: return - schemas = schema_dict.get(tool_name) or list(schema_dict.values())[0] + schemas["tool_path"] = tool_path # corresponding code file path of the tool try: ToolSchema(**schemas) # validation @@ -78,11 +69,13 @@ class ToolRegistry(BaseModel): # logger.warning( # f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}" # ) + tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code) self.tools[tool_name] = tool self.tools_by_types[tool_type][tool_name] = tool if verbose: logger.info(f"{tool_name} registered") + logger.info(f"schema made at {str(schema_path)}, can be used for checking") def has_tool(self, key: str) -> Tool: return key in self.tools @@ -107,12 +100,10 @@ class ToolRegistry(BaseModel): TOOL_REGISTRY = ToolRegistry(tool_types=ToolTypes) -def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: str = "", **kwargs): +def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs): """register a tool to registry""" - def decorator(cls, tool_name=tool_name): - tool_name = tool_name or cls.__name__ - + def decorator(cls): # Get the file path where the function / class is defined and the source code file_path = inspect.getfile(cls) if "metagpt" in file_path: @@ -120,7 +111,7 @@ def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: st source_code = inspect.getsource(cls) TOOL_REGISTRY.register_tool( - tool_name=tool_name, + tool_name=cls.__name__, tool_path=file_path, schema_path=schema_path, tool_code=source_code, @@ -142,7 +133,6 @@ def make_schema(tool_source_object, include, path): # import json # with open(str(path).replace("yml", "json"), "w", encoding="utf-8") as f: # json.dump(schema, f, ensure_ascii=False, indent=4) - logger.info(f"schema made at {path}") except Exception as e: schema = {} logger.error(f"Fail to make schema: {e}") diff --git a/metagpt/utils/parse_docstring.py b/metagpt/utils/parse_docstring.py index 8a017e1f7..e91be8e75 100644 --- a/metagpt/utils/parse_docstring.py +++ b/metagpt/utils/parse_docstring.py @@ -5,7 +5,7 @@ from pydantic import BaseModel def remove_spaces(text): - return re.sub(r"\s+", " ", text) + return re.sub(r"\s+", " ", text).strip() class DocstringParser(BaseModel): diff --git a/tests/metagpt/tools/test_tool_convert.py b/tests/metagpt/tools/test_tool_convert.py index 1dad997bd..2ae2ea000 100644 --- a/tests/metagpt/tools/test_tool_convert.py +++ b/tests/metagpt/tools/test_tool_convert.py @@ -17,7 +17,7 @@ def test_docstring_to_schema(): pd.DataFrame: The transformed DataFrame. """ expected = { - "description": " Some test desc. ", + "description": "Some test desc.", "parameters": { "properties": { "features": {"type": "list", "description": "Columns to be processed."}, @@ -97,47 +97,45 @@ def dummy_fn(df: pd.DataFrame) -> dict: def test_convert_code_to_tool_schema_class(): expected = { - "DummyClass": { - "type": "class", - "description": "Completing missing values with simple strategies.", - "methods": { - "__init__": { - "description": "Initialize self. ", - "parameters": { - "properties": { - "features": {"type": "list", "description": "Columns to be processed."}, - "strategy": { - "type": "str", - "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.", - "default": "'mean'", - "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"], - }, - "fill_value": { - "type": "int", - "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.", - "default": "None", - }, + "type": "class", + "description": "Completing missing values with simple strategies.", + "methods": { + "__init__": { + "description": "Initialize self.", + "parameters": { + "properties": { + "features": {"type": "list", "description": "Columns to be processed."}, + "strategy": { + "type": "str", + "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.", + "default": "'mean'", + "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"], + }, + "fill_value": { + "type": "int", + "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.", + "default": "None", }, - "required": ["features"], }, - }, - "fit": { - "description": "Fit the FillMissingValue model. ", - "parameters": { - "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}}, - "required": ["df"], - }, - }, - "transform": { - "description": "Transform the input DataFrame with the fitted model. ", - "parameters": { - "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}}, - "required": ["df"], - }, - "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}], + "required": ["features"], }, }, - } + "fit": { + "description": "Fit the FillMissingValue model.", + "parameters": { + "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}}, + "required": ["df"], + }, + }, + "transform": { + "description": "Transform the input DataFrame with the fitted model.", + "parameters": { + "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}}, + "required": ["df"], + }, + "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}], + }, + }, } schema = convert_code_to_tool_schema(DummyClass) assert schema == expected @@ -145,14 +143,12 @@ def test_convert_code_to_tool_schema_class(): def test_convert_code_to_tool_schema_function(): expected = { - "dummy_fn": { - "type": "function", - "description": "Analyzes a DataFrame and categorizes its columns based on data types. ", - "parameters": { - "properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}}, - "required": ["df"], - }, - } + "type": "function", + "description": "Analyzes a DataFrame and categorizes its columns based on data types.", + "parameters": { + "properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}}, + "required": ["df"], + }, } schema = convert_code_to_tool_schema(dummy_fn) assert schema == expected diff --git a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py index bb5d7a0bd..e41ddfa79 100644 --- a/tests/metagpt/tools/test_tool_registry.py +++ b/tests/metagpt/tools/test_tool_registry.py @@ -14,18 +14,6 @@ def tool_registry_full(): return ToolRegistry(tool_types=ToolTypes) -@pytest.fixture -def schema_yaml(mocker): - mock_yaml_content = """ - tool_name: - key1: value1 - key2: value2 - """ - mocker.patch("os.path.exists", return_value=True) - mocker.patch("builtins.open", mocker.mock_open(read_data=mock_yaml_content)) - return mocker - - # Test Initialization def test_initialization(tool_registry): assert isinstance(tool_registry, ToolRegistry) @@ -42,33 +30,46 @@ def test_initialize_with_tool_types(tool_registry_full): assert "data_preprocess" in tool_registry_full.tool_types -# Test Tool Registration -def test_register_tool(tool_registry, schema_yaml): - tool_registry.register_tool("TestTool", "/path/to/tool") - assert "TestTool" in tool_registry.tools +class TestClassTool: + """test class""" + + def test_class_fn(self): + """test class fn""" + pass -# Test Tool Registration with Non-existing Schema -def test_register_tool_no_schema(tool_registry, mocker): - mocker.patch("os.path.exists", return_value=False) - tool_registry.register_tool("TestTool", "/path/to/tool") - assert "TestTool" not in tool_registry.tools +def test_fn(): + """test function""" + pass + + +# Test Tool Registration Class +def test_register_tool_class(tool_registry): + tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool) + assert "TestClassTool" in tool_registry.tools + + +# Test Tool Registration Function +def test_register_tool_fn(tool_registry): + tool_registry.register_tool("test_fn", "/path/to/tool", tool_source_object=test_fn) + assert "test_fn" in tool_registry.tools # Test Tool Existence Checks -def test_has_tool(tool_registry, schema_yaml): - tool_registry.register_tool("TestTool", "/path/to/tool") - assert tool_registry.has_tool("TestTool") +def test_has_tool(tool_registry): + tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool) + assert tool_registry.has_tool("TestClassTool") assert not tool_registry.has_tool("NonexistentTool") # Test Tool Retrieval -def test_get_tool(tool_registry, schema_yaml): - tool_registry.register_tool("TestTool", "/path/to/tool") - tool = tool_registry.get_tool("TestTool") +def test_get_tool(tool_registry): + tool_registry.register_tool("TestClassTool", "/path/to/tool", tool_source_object=TestClassTool) + tool = tool_registry.get_tool("TestClassTool") assert tool is not None - assert tool.name == "TestTool" + assert tool.name == "TestClassTool" assert tool.path == "/path/to/tool" + assert "description" in tool.schemas # Similar tests for has_tool_type, get_tool_type, get_tools_by_type @@ -83,12 +84,12 @@ def test_get_tool_type(tool_registry_full): assert retrieved_type.name == "data_preprocess" -def test_get_tools_by_type(tool_registry, schema_yaml): +def test_get_tools_by_type(tool_registry): tool_type_name = "TestType" tool_name = "TestTool" tool_path = "/path/to/tool" - tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name) + tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool) tools_by_type = tool_registry.get_tools_by_type(tool_type_name) assert tools_by_type is not None diff --git a/tests/metagpt/utils/test_save_code.py b/tests/metagpt/utils/test_save_code.py index 5ab08c454..57a19049b 100644 --- a/tests/metagpt/utils/test_save_code.py +++ b/tests/metagpt/utils/test_save_code.py @@ -14,7 +14,7 @@ from metagpt.utils.save_code import DATA_PATH, save_code_file def test_save_code_file_python(): save_code_file("example", "print('Hello, World!')") file_path = DATA_PATH / "output" / "example" / "code.py" - assert file_path.exists, f"File does not exist: {file_path}" + assert file_path.exists(), f"File does not exist: {file_path}" content = file_path.read_text() assert "print('Hello, World!')" in content, "File content does not match" @@ -35,7 +35,7 @@ async def test_save_code_file_notebook(): # Save as a Notebook file save_code_file("example_nb", executor.nb, file_format="ipynb") file_path = DATA_PATH / "output" / "example_nb" / "code.ipynb" - assert file_path.exists, f"Notebook file does not exist: {file_path}" + assert file_path.exists(), f"Notebook file does not exist: {file_path}" # Additional checks specific to notebook format notebook = nbformat.read(file_path, as_version=4)