remove ToolTypesEnum

This commit is contained in:
yzlin 2024-02-02 17:57:49 +08:00
parent 35438e7b03
commit 1da50f1825
11 changed files with 109 additions and 109 deletions

View file

@ -3,7 +3,7 @@ from metagpt.actions.execute_nb_code import ExecuteNbCode
from metagpt.actions.ml_action import UpdateDataColumns, WriteCodeWithToolsML
from metagpt.logs import logger
from metagpt.roles.code_interpreter import CodeInterpreter
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_types import ToolTypes
from metagpt.utils.common import any_to_str
@ -51,9 +51,9 @@ class MLEngineer(CodeInterpreter):
async def _update_data_columns(self):
current_task = self.planner.plan.current_task
if current_task.task_type not in [
ToolTypeEnum.DATA_PREPROCESS.value,
ToolTypeEnum.FEATURE_ENGINEERING.value,
ToolTypeEnum.MODEL_TRAIN.value,
ToolTypes.DATA_PREPROCESS.type_name,
ToolTypes.FEATURE_ENGINEERING.type_name,
ToolTypes.MODEL_TRAIN.type_name,
]:
return ""
logger.info("Check columns in updated data")

View file

@ -13,10 +13,10 @@ from sklearn.preprocessing import (
StandardScaler,
)
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_types import ToolTypes
TOOL_TYPE = ToolTypeEnum.DATA_PREPROCESS.value
TOOL_TYPE = ToolTypes.DATA_PREPROCESS.type_name
class MLProcess(object):

View file

@ -16,10 +16,10 @@ from sklearn.model_selection import KFold
from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
from metagpt.tools.libs.data_preprocess import MLProcess
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_types import ToolTypes
TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value
TOOL_TYPE = ToolTypes.FEATURE_ENGINEERING.type_name
@register_tool(tool_type=TOOL_TYPE)

View file

@ -12,8 +12,8 @@ from pathlib import Path
import requests
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_types import ToolTypes
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image:
@ -30,7 +30,7 @@ As the design pays tribute to large companies, sometimes it is normal for some c
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
@register_tool(tool_type=ToolTypeEnum.IMAGE2WEBPAGE.value)
@register_tool(tool_type=ToolTypes.IMAGE2WEBPAGE.type_name)
class GPTvGenerator:
def __init__(self):
from metagpt.config2 import config

View file

@ -16,8 +16,8 @@ from PIL import Image, PngImagePlugin
#
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
from metagpt.logs import logger
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_types import ToolTypes
payload = {
"prompt": "",
@ -53,7 +53,7 @@ payload = {
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value)
@register_tool(tool_type=ToolTypes.STABLE_DIFFUSION.type_name)
class SDEngine:
def __init__(self, sd_url=""):
# Initialize the SDEngine with configuration

View file

@ -1,9 +1,9 @@
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_types import ToolTypes
from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper
@register_tool(tool_type=ToolTypeEnum.WEBSCRAPING.value)
@register_tool(tool_type=ToolTypes.WEBSCRAPING.type_name)
async def scrape_web_playwright(url, *urls):
"""
Scrape and save the HTML structure and inner text content of a web page using Playwright.

View file

@ -1,26 +1,9 @@
from enum import Enum
from pydantic import BaseModel
class ToolTypeEnum(Enum):
EDA = "eda"
DATA_PREPROCESS = "data_preprocess"
FEATURE_ENGINEERING = "feature_engineering"
MODEL_TRAIN = "model_train"
MODEL_EVALUATE = "model_evaluate"
STABLE_DIFFUSION = "stable_diffusion"
IMAGE2WEBPAGE = "image2webpage"
WEBSCRAPING = "web_scraping"
OTHER = "other"
def __missing__(self, key):
return self.OTHER
class ToolType(BaseModel):
name: str
desc: str
desc: str = ""
usage_prompt: str = ""

View file

@ -11,12 +11,13 @@ import re
from collections import defaultdict
import yaml
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from metagpt.const import TOOL_SCHEMA_PATH
from metagpt.logs import logger
from metagpt.tools.tool_convert import convert_code_to_tool_schema
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType
from metagpt.tools.tool_types import ToolTypes
class ToolRegistry(BaseModel):
@ -24,16 +25,16 @@ class ToolRegistry(BaseModel):
tool_types: dict = {}
tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
def register_tool_type(self, tool_type: ToolType, verbose: bool = False):
self.tool_types[tool_type.name] = tool_type
if verbose:
logger.info(f"tool type {tool_type.name} registered")
@field_validator("tool_types", mode="before")
@classmethod
def init_tool_types(cls, tool_types: ToolTypes):
return {tool_type.type_name: tool_type.value for tool_type in tool_types}
def register_tool(
self,
tool_name,
tool_path,
schema_path=None,
schema_path="",
tool_code="",
tool_type="other",
tool_source_object=None,
@ -44,6 +45,16 @@ class ToolRegistry(BaseModel):
if self.has_tool(tool_name):
return
if tool_type not in self.tool_types:
# register new tool type on the fly
logger.warning(
f"{tool_type} not previously defined, will create a temporary ToolType with just a name. This ToolType is only effective during this runtime. You may consider add this ToolType with more configs permanently at metagpt.tools.tool_types"
)
temp_tool_type_obj = ToolType(name=tool_type)
self.tool_types[tool_type] = temp_tool_type_obj
if verbose:
logger.info(f"tool type {tool_type} registered")
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
if not os.path.exists(schema_path):
@ -93,16 +104,10 @@ class ToolRegistry(BaseModel):
# Registry instance
TOOL_REGISTRY = ToolRegistry()
TOOL_REGISTRY = ToolRegistry(tool_types=ToolTypes)
def register_tool_type(cls):
"""register a tool type to registry"""
TOOL_REGISTRY.register_tool_type(tool_type=cls())
return cls
def register_tool(tool_name="", tool_type="other", schema_path=None, **kwargs):
def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: str = "", **kwargs):
"""register a tool to registry"""
def decorator(cls, tool_name=tool_name):

View file

@ -1,3 +1,5 @@
from enum import Enum
from metagpt.prompts.tool_types import (
DATA_PREPROCESS_PROMPT,
FEATURE_ENGINEERING_PROMPT,
@ -5,64 +7,74 @@ from metagpt.prompts.tool_types import (
MODEL_EVALUATE_PROMPT,
MODEL_TRAIN_PROMPT,
)
from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum
from metagpt.tools.tool_registry import register_tool_type
from metagpt.tools.tool_data_type import ToolType
Eda = ToolType(name="eda", desc="For performing exploratory data analysis")
DataPreprocess = ToolType(
name="data_preprocess",
desc="Only for changing value inplace.",
usage_prompt=DATA_PREPROCESS_PROMPT,
)
@register_tool_type
class EDA(ToolType):
name: str = ToolTypeEnum.EDA.value
desc: str = "For performing exploratory data analysis"
FeatureEngineering = ToolType(
name="feature_engineering",
desc="Only for creating new columns for input data.",
usage_prompt=FEATURE_ENGINEERING_PROMPT,
)
@register_tool_type
class DataPreprocess(ToolType):
name: str = ToolTypeEnum.DATA_PREPROCESS.value
desc: str = "Only for changing value inplace."
usage_prompt: str = DATA_PREPROCESS_PROMPT
ModelTrain = ToolType(
name="model_train",
desc="Only for training model.",
usage_prompt=MODEL_TRAIN_PROMPT,
)
@register_tool_type
class FeatureEngineer(ToolType):
name: str = ToolTypeEnum.FEATURE_ENGINEERING.value
desc: str = "Only for creating new columns for input data."
usage_prompt: str = FEATURE_ENGINEERING_PROMPT
ModelEvaluate = ToolType(
name="model_evaluate",
desc="Only for evaluating model.",
usage_prompt=MODEL_EVALUATE_PROMPT,
)
@register_tool_type
class ModelTrain(ToolType):
name: str = ToolTypeEnum.MODEL_TRAIN.value
desc: str = "Only for training model."
usage_prompt: str = MODEL_TRAIN_PROMPT
StableDiffusion = ToolType(
name="stable_diffusion",
desc="Related to text2image, image2image using stable diffusion model.",
)
@register_tool_type
class ModelEvaluate(ToolType):
name: str = ToolTypeEnum.MODEL_EVALUATE.value
desc: str = "Only for evaluating model."
usage_prompt: str = MODEL_EVALUATE_PROMPT
Image2Webpage = ToolType(
name="image2webpage",
desc="For converting image into webpage code.",
usage_prompt=IMAGE2WEBPAGE_PROMPT,
)
@register_tool_type
class StableDiffusion(ToolType):
name: str = ToolTypeEnum.STABLE_DIFFUSION.value
desc: str = "Related to text2image, image2image using stable diffusion model."
WebScraping = ToolType(
name="web_scraping",
desc="For scraping data from web pages.",
)
@register_tool_type
class Image2Webpage(ToolType):
name: str = ToolTypeEnum.IMAGE2WEBPAGE.value
desc: str = "For converting image into webpage code."
usage_prompt: str = IMAGE2WEBPAGE_PROMPT
Other = ToolType(name="other", desc="Any tools not in the defined categories")
@register_tool_type
class WebScraping(ToolType):
name: str = ToolTypeEnum.WEBSCRAPING.value
desc: str = "For scraping data from web pages."
class ToolTypes(Enum):
EDA = Eda
DATA_PREPROCESS = DataPreprocess
FEATURE_ENGINEERING = FeatureEngineering
MODEL_TRAIN = ModelTrain
MODEL_EVALUATE = ModelEvaluate
STABLE_DIFFUSION = StableDiffusion
IMAGE2WEBPAGE = Image2Webpage
WEBSCRAPING = WebScraping
OTHER = Other
def __missing__(self, key):
return self.OTHER
@register_tool_type
class Other(ToolType):
name: str = ToolTypeEnum.OTHER.value
desc: str = "Any tools not in the defined categories"
@property
def type_name(self):
return self.value.name

View file

@ -4,7 +4,7 @@ from metagpt.actions.execute_nb_code import ExecuteNbCode
from metagpt.logs import logger
from metagpt.roles.ml_engineer import MLEngineer
from metagpt.schema import Message, Plan, Task
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_types import ToolTypes
from tests.metagpt.actions.test_debug_code import CODE, DebugContext, ErrorStr
@ -63,7 +63,7 @@ async def test_mle_update_data_columns(mocker):
mle.planner.plan = MockPlan
# manually update task type to test update
mle.planner.plan.current_task.task_type = ToolTypeEnum.DATA_PREPROCESS.value
mle.planner.plan.current_task.task_type = ToolTypes.DATA_PREPROCESS.value
result = await mle._update_data_columns()
assert result is not None

View file

@ -1,7 +1,7 @@
import pytest
from metagpt.tools.tool_registry import ToolRegistry
from metagpt.tools.tool_types import ToolType
from metagpt.tools.tool_types import ToolTypes
@pytest.fixture
@ -9,6 +9,11 @@ def tool_registry():
return ToolRegistry()
@pytest.fixture
def tool_registry_full():
return ToolRegistry(tool_types=ToolTypes)
@pytest.fixture
def schema_yaml(mocker):
mock_yaml_content = """
@ -29,11 +34,12 @@ def test_initialization(tool_registry):
assert tool_registry.tools_by_types == {}
# Test Tool Type Registration
def test_register_tool_type(tool_registry):
tool_type = ToolType(name="TestType", desc="test")
tool_registry.register_tool_type(tool_type)
assert "TestType" in tool_registry.tool_types
# Test Initialization with tool types
def test_initialize_with_tool_types(tool_registry_full):
assert isinstance(tool_registry_full, ToolRegistry)
assert tool_registry_full.tools == {}
assert tool_registry_full.tools_by_types == {}
assert "data_preprocess" in tool_registry_full.tool_types
# Test Tool Registration
@ -66,27 +72,21 @@ def test_get_tool(tool_registry, schema_yaml):
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
def test_has_tool_type(tool_registry):
tool_type = ToolType(name="TestType", desc="test")
tool_registry.register_tool_type(tool_type)
assert tool_registry.has_tool_type("TestType")
assert not tool_registry.has_tool_type("NonexistentType")
def test_has_tool_type(tool_registry_full):
assert tool_registry_full.has_tool_type("data_preprocess")
assert not tool_registry_full.has_tool_type("NonexistentType")
def test_get_tool_type(tool_registry):
tool_type = ToolType(name="TestType", desc="test")
tool_registry.register_tool_type(tool_type)
retrieved_type = tool_registry.get_tool_type("TestType")
def test_get_tool_type(tool_registry_full):
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
assert retrieved_type is not None
assert retrieved_type.name == "TestType"
assert retrieved_type.name == "data_preprocess"
def test_get_tools_by_type(tool_registry, schema_yaml):
tool_type_name = "TestType"
tool_name = "TestTool"
tool_path = "/path/to/tool"
tool_type = ToolType(name=tool_type_name, desc="test")
tool_registry.register_tool_type(tool_type)
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name)