diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index f743d63c7..4ca46fc89 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -7,6 +7,13 @@ """ from enum import Enum +from metagpt.tools import tool_types # this registers all tool types +from metagpt.tools.functions import libs # this registers all tools +from metagpt.tools.tool_registry import TOOL_REGISTRY + +_ = tool_types # Avoid pre-commit error +_ = libs # Avoid pre-commit error +_ = TOOL_REGISTRY # Avoid pre-commit error class SearchEngineType(Enum): @@ -26,62 +33,3 @@ class WebBrowserEngineType(Enum): def __missing__(cls, key): """Default type conversion""" return cls.CUSTOM - - -class ToolType(BaseModel): - name: str - module: str = "" - desc: str - usage_prompt: str = "" - - -TOOL_TYPE_MAPPINGS = { - "data_preprocess": ToolType( - name="data_preprocess", - module=str(TOOL_LIBS_PATH / "data_preprocess"), - desc="Only for changing value inplace.", - usage_prompt=DATA_PREPROCESS_PROMPT, - ), - "feature_engineering": ToolType( - name="feature_engineering", - module=str(TOOL_LIBS_PATH / "feature_engineering"), - desc="Only for creating new columns for input data.", - usage_prompt=FEATURE_ENGINEERING_PROMPT, - ), - "model_train": ToolType( - name="model_train", - module="", - desc="Only for training model.", - usage_prompt=MODEL_TRAIN_PROMPT, - ), - "model_evaluate": ToolType( - name="model_evaluate", - module="", - desc="Only for evaluating model.", - usage_prompt=MODEL_EVALUATE_PROMPT, - ), - "stable_diffusion": ToolType( - name="stable_diffusion", - module="metagpt.tools.sd_engine", - desc="Related to text2image, image2image using stable diffusion model.", - usage_prompt="", - ), - "scrape_web": ToolType( - name="scrape_web", - module="metagpt.tools.functions.libs.scrape_web.scrape_web", - desc="Scrape data from web page.", - usage_prompt="", - ), - "vision": ToolType( - name="vision", - module=str(TOOL_LIBS_PATH / "vision"), - desc="Only for converting image into webpage code.", - usage_prompt=VISION_PROMPT, - ), - "other": ToolType( - name="other", - module="", - desc="Any tasks that do not fit into the previous categories", - usage_prompt="", - ), -} diff --git a/metagpt/tools/functions/libs/__init__.py b/metagpt/tools/functions/libs/__init__.py index a0a43f507..f0a61a7d9 100644 --- a/metagpt/tools/functions/libs/__init__.py +++ b/metagpt/tools/functions/libs/__init__.py @@ -4,3 +4,10 @@ # @Author : lidanyang # @File : __init__.py # @Desc : +from metagpt.tools.functions.libs import ( + data_preprocess, + feature_engineering, +) + +_ = data_preprocess # Avoid pre-commit error +_ = feature_engineering # Avoid pre-commit error diff --git a/metagpt/tools/functions/libs/data_preprocess.py b/metagpt/tools/functions/libs/data_preprocess.py index 59ede3ffc..019ffd34e 100644 --- a/metagpt/tools/functions/libs/data_preprocess.py +++ b/metagpt/tools/functions/libs/data_preprocess.py @@ -14,8 +14,8 @@ from sklearn.preprocessing import ( ) from metagpt.tools.functions.libs.base import MLProcess +from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool -from metagpt.tools.tool_schema import ToolTypeEnum TOOL_TYPE = ToolTypeEnum.DATA_PREPROCESS.value diff --git a/metagpt/tools/functions/libs/feature_engineering.py b/metagpt/tools/functions/libs/feature_engineering.py index 8b96cbd07..cd03592a6 100644 --- a/metagpt/tools/functions/libs/feature_engineering.py +++ b/metagpt/tools/functions/libs/feature_engineering.py @@ -16,8 +16,8 @@ from sklearn.model_selection import KFold from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures from metagpt.tools.functions.libs.base import MLProcess +from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool -from metagpt.tools.tool_schema import ToolTypeEnum TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index ba61fd496..2e3f36ef8 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -16,6 +16,8 @@ from PIL import Image, PngImagePlugin from metagpt.config import CONFIG from metagpt.const import SD_OUTPUT_FILE_REPO from metagpt.logs import logger +from metagpt.tools.tool_data_type import ToolTypeEnum +from metagpt.tools.tool_registry import register_tool payload = { "prompt": "", @@ -51,6 +53,7 @@ payload = { default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" +@register_tool(tool_type_name=ToolTypeEnum.STABLE_DIFFUSION) class SDEngine: def __init__(self, sd_url=""): # Initialize the SDEngine with configuration diff --git a/metagpt/tools/tool_schema.py b/metagpt/tools/tool_data_type.py similarity index 92% rename from metagpt/tools/tool_schema.py rename to metagpt/tools/tool_data_type.py index 2b90996e5..c767fef9b 100644 --- a/metagpt/tools/tool_schema.py +++ b/metagpt/tools/tool_data_type.py @@ -8,6 +8,7 @@ class ToolTypeEnum(Enum): FEATURE_ENGINEERING = "feature_engineering" MODEL_TRAIN = "model_train" MODEL_EVALUATE = "model_evaluate" + STABLE_DIFFUSION = "stable_diffusion" OTHER = "other" def __missing__(self, key): diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index 201c63c71..e6519bba9 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -5,28 +5,27 @@ @Author : garylin2099 @File : tool_registry.py """ -import os -from collections import defaultdict import inspect +import os import re +from collections import defaultdict import yaml -from metagpt.tools.tool_schema import ToolType, ToolSchema, Tool -from metagpt.logs import logger from metagpt.const import TOOL_SCHEMA_PATH +from metagpt.logs import logger +from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType class ToolRegistry: def __init__(self): self.tools = {} self.tool_types = {} - self.tools_by_types = defaultdict( - dict - ) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...} + self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...} def register_tool_type(self, tool_type: ToolType): self.tool_types[tool_type.name] = tool_type + logger.info(f"{tool_type.name} registered") def register_tool( self, @@ -55,7 +54,7 @@ class ToolRegistry: schema["tool_path"] = tool_path # corresponding code file path of the tool try: ToolSchema(**schema) # validation - except Exception as e: + except Exception: pass # logger.warning( # f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}" @@ -67,19 +66,19 @@ class ToolRegistry: def has_tool(self, key): return key in self.tools - + def get_tool(self, key): return self.tools.get(key) - + def get_tools_by_type(self, key): return self.tools_by_types.get(key) - + def has_tool_type(self, key): return key in self.tool_types def get_tool_type(self, key): return self.tool_types.get(key) - + def get_tool_types(self): return self.tool_types @@ -99,7 +98,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None): def decorator(cls, tool_name=tool_name): tool_name = tool_name or cls.__name__ - + # Get the file path where the function / class is defined and the source code file_path = inspect.getfile(cls) if "metagpt" in file_path: @@ -119,9 +118,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None): def make_schema(tool_code, path): - os.makedirs( - os.path.dirname(path), exist_ok=True - ) # Create the necessary directories + os.makedirs(os.path.dirname(path), exist_ok=True) # Create the necessary directories schema = {} # an empty schema for now with open(path, "w", encoding="utf-8") as f: yaml.dump(schema, f) diff --git a/metagpt/tools/tool_types.py b/metagpt/tools/tool_types.py index 9104f90b8..97eb574da 100644 --- a/metagpt/tools/tool_types.py +++ b/metagpt/tools/tool_types.py @@ -1,10 +1,10 @@ from metagpt.prompts.tool_type import ( DATA_PREPROCESS_PROMPT, FEATURE_ENGINEERING_PROMPT, - MODEL_TRAIN_PROMPT, MODEL_EVALUATE_PROMPT, + MODEL_TRAIN_PROMPT, ) -from metagpt.tools.tool_schema import ToolTypeEnum, ToolType +from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum from metagpt.tools.tool_registry import register_tool_type @@ -36,8 +36,13 @@ class ModelEvaluate(ToolType): usage_prompt: str = MODEL_EVALUATE_PROMPT +@register_tool_type +class StableDiffusion(ToolType): + name: str = ToolTypeEnum.STABLE_DIFFUSION.value + desc: str = "Related to text2image, image2image using stable diffusion model." + + @register_tool_type class Other(ToolType): name: str = ToolTypeEnum.OTHER.value desc: str = "Any tools not in the defined categories" - usage_prompt: str = ""