renaming and integrate sd tool, fix import issue

This commit is contained in:
yzlin 2024-01-13 12:28:52 +08:00
parent 46cd219e81
commit d7ab4d315d
8 changed files with 41 additions and 80 deletions

View file

@ -7,6 +7,13 @@
"""
from enum import Enum
from metagpt.tools import tool_types # this registers all tool types
from metagpt.tools.functions import libs # this registers all tools
from metagpt.tools.tool_registry import TOOL_REGISTRY
_ = tool_types # Avoid pre-commit error
_ = libs # Avoid pre-commit error
_ = TOOL_REGISTRY # Avoid pre-commit error
class SearchEngineType(Enum):
@ -26,62 +33,3 @@ class WebBrowserEngineType(Enum):
def __missing__(cls, key):
"""Default type conversion"""
return cls.CUSTOM
class ToolType(BaseModel):
name: str
module: str = ""
desc: str
usage_prompt: str = ""
TOOL_TYPE_MAPPINGS = {
"data_preprocess": ToolType(
name="data_preprocess",
module=str(TOOL_LIBS_PATH / "data_preprocess"),
desc="Only for changing value inplace.",
usage_prompt=DATA_PREPROCESS_PROMPT,
),
"feature_engineering": ToolType(
name="feature_engineering",
module=str(TOOL_LIBS_PATH / "feature_engineering"),
desc="Only for creating new columns for input data.",
usage_prompt=FEATURE_ENGINEERING_PROMPT,
),
"model_train": ToolType(
name="model_train",
module="",
desc="Only for training model.",
usage_prompt=MODEL_TRAIN_PROMPT,
),
"model_evaluate": ToolType(
name="model_evaluate",
module="",
desc="Only for evaluating model.",
usage_prompt=MODEL_EVALUATE_PROMPT,
),
"stable_diffusion": ToolType(
name="stable_diffusion",
module="metagpt.tools.sd_engine",
desc="Related to text2image, image2image using stable diffusion model.",
usage_prompt="",
),
"scrape_web": ToolType(
name="scrape_web",
module="metagpt.tools.functions.libs.scrape_web.scrape_web",
desc="Scrape data from web page.",
usage_prompt="",
),
"vision": ToolType(
name="vision",
module=str(TOOL_LIBS_PATH / "vision"),
desc="Only for converting image into webpage code.",
usage_prompt=VISION_PROMPT,
),
"other": ToolType(
name="other",
module="",
desc="Any tasks that do not fit into the previous categories",
usage_prompt="",
),
}

View file

@ -4,3 +4,10 @@
# @Author : lidanyang
# @File : __init__.py
# @Desc :
from metagpt.tools.functions.libs import (
data_preprocess,
feature_engineering,
)
_ = data_preprocess # Avoid pre-commit error
_ = feature_engineering # Avoid pre-commit error

View file

@ -14,8 +14,8 @@ from sklearn.preprocessing import (
)
from metagpt.tools.functions.libs.base import MLProcess
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_schema import ToolTypeEnum
TOOL_TYPE = ToolTypeEnum.DATA_PREPROCESS.value

View file

@ -16,8 +16,8 @@ from sklearn.model_selection import KFold
from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
from metagpt.tools.functions.libs.base import MLProcess
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_schema import ToolTypeEnum
TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value

View file

@ -16,6 +16,8 @@ from PIL import Image, PngImagePlugin
from metagpt.config import CONFIG
from metagpt.const import SD_OUTPUT_FILE_REPO
from metagpt.logs import logger
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
payload = {
"prompt": "",
@ -51,6 +53,7 @@ payload = {
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
@register_tool(tool_type_name=ToolTypeEnum.STABLE_DIFFUSION)
class SDEngine:
def __init__(self, sd_url=""):
# Initialize the SDEngine with configuration

View file

@ -8,6 +8,7 @@ class ToolTypeEnum(Enum):
FEATURE_ENGINEERING = "feature_engineering"
MODEL_TRAIN = "model_train"
MODEL_EVALUATE = "model_evaluate"
STABLE_DIFFUSION = "stable_diffusion"
OTHER = "other"
def __missing__(self, key):

View file

@ -5,28 +5,27 @@
@Author : garylin2099
@File : tool_registry.py
"""
import os
from collections import defaultdict
import inspect
import os
import re
from collections import defaultdict
import yaml
from metagpt.tools.tool_schema import ToolType, ToolSchema, Tool
from metagpt.logs import logger
from metagpt.const import TOOL_SCHEMA_PATH
from metagpt.logs import logger
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType
class ToolRegistry:
def __init__(self):
self.tools = {}
self.tool_types = {}
self.tools_by_types = defaultdict(
dict
) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...}
self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...}
def register_tool_type(self, tool_type: ToolType):
self.tool_types[tool_type.name] = tool_type
logger.info(f"{tool_type.name} registered")
def register_tool(
self,
@ -55,7 +54,7 @@ class ToolRegistry:
schema["tool_path"] = tool_path # corresponding code file path of the tool
try:
ToolSchema(**schema) # validation
except Exception as e:
except Exception:
pass
# logger.warning(
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
@ -67,19 +66,19 @@ class ToolRegistry:
def has_tool(self, key):
return key in self.tools
def get_tool(self, key):
return self.tools.get(key)
def get_tools_by_type(self, key):
return self.tools_by_types.get(key)
def has_tool_type(self, key):
return key in self.tool_types
def get_tool_type(self, key):
return self.tool_types.get(key)
def get_tool_types(self):
return self.tool_types
@ -99,7 +98,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None):
def decorator(cls, tool_name=tool_name):
tool_name = tool_name or cls.__name__
# Get the file path where the function / class is defined and the source code
file_path = inspect.getfile(cls)
if "metagpt" in file_path:
@ -119,9 +118,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None):
def make_schema(tool_code, path):
os.makedirs(
os.path.dirname(path), exist_ok=True
) # Create the necessary directories
os.makedirs(os.path.dirname(path), exist_ok=True) # Create the necessary directories
schema = {} # an empty schema for now
with open(path, "w", encoding="utf-8") as f:
yaml.dump(schema, f)

View file

@ -1,10 +1,10 @@
from metagpt.prompts.tool_type import (
DATA_PREPROCESS_PROMPT,
FEATURE_ENGINEERING_PROMPT,
MODEL_TRAIN_PROMPT,
MODEL_EVALUATE_PROMPT,
MODEL_TRAIN_PROMPT,
)
from metagpt.tools.tool_schema import ToolTypeEnum, ToolType
from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum
from metagpt.tools.tool_registry import register_tool_type
@ -36,8 +36,13 @@ class ModelEvaluate(ToolType):
usage_prompt: str = MODEL_EVALUATE_PROMPT
@register_tool_type
class StableDiffusion(ToolType):
name: str = ToolTypeEnum.STABLE_DIFFUSION.value
desc: str = "Related to text2image, image2image using stable diffusion model."
@register_tool_type
class Other(ToolType):
name: str = ToolTypeEnum.OTHER.value
desc: str = "Any tools not in the defined categories"
usage_prompt: str = ""