mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
renaming and integrate sd tool, fix import issue
This commit is contained in:
parent
46cd219e81
commit
d7ab4d315d
8 changed files with 41 additions and 80 deletions
|
|
@ -7,6 +7,13 @@
|
|||
"""
|
||||
|
||||
from enum import Enum
|
||||
from metagpt.tools import tool_types # this registers all tool types
|
||||
from metagpt.tools.functions import libs # this registers all tools
|
||||
from metagpt.tools.tool_registry import TOOL_REGISTRY
|
||||
|
||||
_ = tool_types # Avoid pre-commit error
|
||||
_ = libs # Avoid pre-commit error
|
||||
_ = TOOL_REGISTRY # Avoid pre-commit error
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
|
|
@ -26,62 +33,3 @@ class WebBrowserEngineType(Enum):
|
|||
def __missing__(cls, key):
|
||||
"""Default type conversion"""
|
||||
return cls.CUSTOM
|
||||
|
||||
|
||||
class ToolType(BaseModel):
|
||||
name: str
|
||||
module: str = ""
|
||||
desc: str
|
||||
usage_prompt: str = ""
|
||||
|
||||
|
||||
TOOL_TYPE_MAPPINGS = {
|
||||
"data_preprocess": ToolType(
|
||||
name="data_preprocess",
|
||||
module=str(TOOL_LIBS_PATH / "data_preprocess"),
|
||||
desc="Only for changing value inplace.",
|
||||
usage_prompt=DATA_PREPROCESS_PROMPT,
|
||||
),
|
||||
"feature_engineering": ToolType(
|
||||
name="feature_engineering",
|
||||
module=str(TOOL_LIBS_PATH / "feature_engineering"),
|
||||
desc="Only for creating new columns for input data.",
|
||||
usage_prompt=FEATURE_ENGINEERING_PROMPT,
|
||||
),
|
||||
"model_train": ToolType(
|
||||
name="model_train",
|
||||
module="",
|
||||
desc="Only for training model.",
|
||||
usage_prompt=MODEL_TRAIN_PROMPT,
|
||||
),
|
||||
"model_evaluate": ToolType(
|
||||
name="model_evaluate",
|
||||
module="",
|
||||
desc="Only for evaluating model.",
|
||||
usage_prompt=MODEL_EVALUATE_PROMPT,
|
||||
),
|
||||
"stable_diffusion": ToolType(
|
||||
name="stable_diffusion",
|
||||
module="metagpt.tools.sd_engine",
|
||||
desc="Related to text2image, image2image using stable diffusion model.",
|
||||
usage_prompt="",
|
||||
),
|
||||
"scrape_web": ToolType(
|
||||
name="scrape_web",
|
||||
module="metagpt.tools.functions.libs.scrape_web.scrape_web",
|
||||
desc="Scrape data from web page.",
|
||||
usage_prompt="",
|
||||
),
|
||||
"vision": ToolType(
|
||||
name="vision",
|
||||
module=str(TOOL_LIBS_PATH / "vision"),
|
||||
desc="Only for converting image into webpage code.",
|
||||
usage_prompt=VISION_PROMPT,
|
||||
),
|
||||
"other": ToolType(
|
||||
name="other",
|
||||
module="",
|
||||
desc="Any tasks that do not fit into the previous categories",
|
||||
usage_prompt="",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,3 +4,10 @@
|
|||
# @Author : lidanyang
|
||||
# @File : __init__.py
|
||||
# @Desc :
|
||||
from metagpt.tools.functions.libs import (
|
||||
data_preprocess,
|
||||
feature_engineering,
|
||||
)
|
||||
|
||||
_ = data_preprocess # Avoid pre-commit error
|
||||
_ = feature_engineering # Avoid pre-commit error
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from sklearn.preprocessing import (
|
|||
)
|
||||
|
||||
from metagpt.tools.functions.libs.base import MLProcess
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_schema import ToolTypeEnum
|
||||
|
||||
TOOL_TYPE = ToolTypeEnum.DATA_PREPROCESS.value
|
||||
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ from sklearn.model_selection import KFold
|
|||
from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
|
||||
|
||||
from metagpt.tools.functions.libs.base import MLProcess
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_schema import ToolTypeEnum
|
||||
|
||||
TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ from PIL import Image, PngImagePlugin
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
|
|
@ -51,6 +53,7 @@ payload = {
|
|||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
||||
|
||||
@register_tool(tool_type_name=ToolTypeEnum.STABLE_DIFFUSION)
|
||||
class SDEngine:
|
||||
def __init__(self, sd_url=""):
|
||||
# Initialize the SDEngine with configuration
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ class ToolTypeEnum(Enum):
|
|||
FEATURE_ENGINEERING = "feature_engineering"
|
||||
MODEL_TRAIN = "model_train"
|
||||
MODEL_EVALUATE = "model_evaluate"
|
||||
STABLE_DIFFUSION = "stable_diffusion"
|
||||
OTHER = "other"
|
||||
|
||||
def __missing__(self, key):
|
||||
|
|
@ -5,28 +5,27 @@
|
|||
@Author : garylin2099
|
||||
@File : tool_registry.py
|
||||
"""
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
|
||||
from metagpt.tools.tool_schema import ToolType, ToolSchema, Tool
|
||||
from metagpt.logs import logger
|
||||
from metagpt.const import TOOL_SCHEMA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
def __init__(self):
|
||||
self.tools = {}
|
||||
self.tool_types = {}
|
||||
self.tools_by_types = defaultdict(
|
||||
dict
|
||||
) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...}
|
||||
self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...}
|
||||
|
||||
def register_tool_type(self, tool_type: ToolType):
|
||||
self.tool_types[tool_type.name] = tool_type
|
||||
logger.info(f"{tool_type.name} registered")
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
|
|
@ -55,7 +54,7 @@ class ToolRegistry:
|
|||
schema["tool_path"] = tool_path # corresponding code file path of the tool
|
||||
try:
|
||||
ToolSchema(**schema) # validation
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
# logger.warning(
|
||||
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
|
||||
|
|
@ -67,19 +66,19 @@ class ToolRegistry:
|
|||
|
||||
def has_tool(self, key):
|
||||
return key in self.tools
|
||||
|
||||
|
||||
def get_tool(self, key):
|
||||
return self.tools.get(key)
|
||||
|
||||
|
||||
def get_tools_by_type(self, key):
|
||||
return self.tools_by_types.get(key)
|
||||
|
||||
|
||||
def has_tool_type(self, key):
|
||||
return key in self.tool_types
|
||||
|
||||
def get_tool_type(self, key):
|
||||
return self.tool_types.get(key)
|
||||
|
||||
|
||||
def get_tool_types(self):
|
||||
return self.tool_types
|
||||
|
||||
|
|
@ -99,7 +98,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None):
|
|||
|
||||
def decorator(cls, tool_name=tool_name):
|
||||
tool_name = tool_name or cls.__name__
|
||||
|
||||
|
||||
# Get the file path where the function / class is defined and the source code
|
||||
file_path = inspect.getfile(cls)
|
||||
if "metagpt" in file_path:
|
||||
|
|
@ -119,9 +118,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None):
|
|||
|
||||
|
||||
def make_schema(tool_code, path):
|
||||
os.makedirs(
|
||||
os.path.dirname(path), exist_ok=True
|
||||
) # Create the necessary directories
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True) # Create the necessary directories
|
||||
schema = {} # an empty schema for now
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(schema, f)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from metagpt.prompts.tool_type import (
|
||||
DATA_PREPROCESS_PROMPT,
|
||||
FEATURE_ENGINEERING_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
MODEL_EVALUATE_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
from metagpt.tools.tool_schema import ToolTypeEnum, ToolType
|
||||
from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool_type
|
||||
|
||||
|
||||
|
|
@ -36,8 +36,13 @@ class ModelEvaluate(ToolType):
|
|||
usage_prompt: str = MODEL_EVALUATE_PROMPT
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class StableDiffusion(ToolType):
|
||||
name: str = ToolTypeEnum.STABLE_DIFFUSION.value
|
||||
desc: str = "Related to text2image, image2image using stable diffusion model."
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class Other(ToolType):
|
||||
name: str = ToolTypeEnum.OTHER.value
|
||||
desc: str = "Any tools not in the defined categories"
|
||||
usage_prompt: str = ""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue