From 1da50f1825bcea713eb03b5075b5b6a4209751fa Mon Sep 17 00:00:00 2001 From: yzlin Date: Fri, 2 Feb 2024 17:57:49 +0800 Subject: [PATCH] remove ToolTypesEnum --- metagpt/roles/ml_engineer.py | 8 +- metagpt/tools/libs/data_preprocess.py | 4 +- metagpt/tools/libs/feature_engineering.py | 4 +- metagpt/tools/libs/gpt_v_generator.py | 4 +- metagpt/tools/libs/sd_engine.py | 4 +- metagpt/tools/libs/web_scraping.py | 4 +- metagpt/tools/tool_data_type.py | 19 +---- metagpt/tools/tool_registry.py | 33 ++++---- metagpt/tools/tool_types.py | 98 +++++++++++++---------- tests/metagpt/roles/test_ml_engineer.py | 4 +- tests/metagpt/tools/test_tool_registry.py | 36 ++++----- 11 files changed, 109 insertions(+), 109 deletions(-) diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index 633c3306c..9d222b0bf 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -3,7 +3,7 @@ from metagpt.actions.execute_nb_code import ExecuteNbCode from metagpt.actions.ml_action import UpdateDataColumns, WriteCodeWithToolsML from metagpt.logs import logger from metagpt.roles.code_interpreter import CodeInterpreter -from metagpt.tools.tool_data_type import ToolTypeEnum +from metagpt.tools.tool_types import ToolTypes from metagpt.utils.common import any_to_str @@ -51,9 +51,9 @@ class MLEngineer(CodeInterpreter): async def _update_data_columns(self): current_task = self.planner.plan.current_task if current_task.task_type not in [ - ToolTypeEnum.DATA_PREPROCESS.value, - ToolTypeEnum.FEATURE_ENGINEERING.value, - ToolTypeEnum.MODEL_TRAIN.value, + ToolTypes.DATA_PREPROCESS.type_name, + ToolTypes.FEATURE_ENGINEERING.type_name, + ToolTypes.MODEL_TRAIN.type_name, ]: return "" logger.info("Check columns in updated data") diff --git a/metagpt/tools/libs/data_preprocess.py b/metagpt/tools/libs/data_preprocess.py index 0480e71a7..307a6bc5b 100644 --- a/metagpt/tools/libs/data_preprocess.py +++ b/metagpt/tools/libs/data_preprocess.py @@ -13,10 +13,10 @@ from sklearn.preprocessing import ( StandardScaler, ) -from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool +from metagpt.tools.tool_types import ToolTypes -TOOL_TYPE = ToolTypeEnum.DATA_PREPROCESS.value +TOOL_TYPE = ToolTypes.DATA_PREPROCESS.type_name class MLProcess(object): diff --git a/metagpt/tools/libs/feature_engineering.py b/metagpt/tools/libs/feature_engineering.py index 79e1c1b07..44cf98261 100644 --- a/metagpt/tools/libs/feature_engineering.py +++ b/metagpt/tools/libs/feature_engineering.py @@ -16,10 +16,10 @@ from sklearn.model_selection import KFold from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures from metagpt.tools.libs.data_preprocess import MLProcess -from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool +from metagpt.tools.tool_types import ToolTypes -TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value +TOOL_TYPE = ToolTypes.FEATURE_ENGINEERING.type_name @register_tool(tool_type=TOOL_TYPE) diff --git a/metagpt/tools/libs/gpt_v_generator.py b/metagpt/tools/libs/gpt_v_generator.py index bae8bcbc0..6a620f7e8 100644 --- a/metagpt/tools/libs/gpt_v_generator.py +++ b/metagpt/tools/libs/gpt_v_generator.py @@ -12,8 +12,8 @@ from pathlib import Path import requests from metagpt.const import DEFAULT_WORKSPACE_ROOT -from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool +from metagpt.tools.tool_types import ToolTypes ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image: @@ -30,7 +30,7 @@ As the design pays tribute to large companies, sometimes it is normal for some c Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:""" -@register_tool(tool_type=ToolTypeEnum.IMAGE2WEBPAGE.value) +@register_tool(tool_type=ToolTypes.IMAGE2WEBPAGE.type_name) class GPTvGenerator: def __init__(self): from metagpt.config2 import config diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index 7001eadf5..6fb16993e 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -16,8 +16,8 @@ from PIL import Image, PngImagePlugin # from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT from metagpt.logs import logger -from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool +from metagpt.tools.tool_types import ToolTypes payload = { "prompt": "", @@ -53,7 +53,7 @@ payload = { default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" -@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value) +@register_tool(tool_type=ToolTypes.STABLE_DIFFUSION.type_name) class SDEngine: def __init__(self, sd_url=""): # Initialize the SDEngine with configuration diff --git a/metagpt/tools/libs/web_scraping.py b/metagpt/tools/libs/web_scraping.py index 921fca809..b6db62d67 100644 --- a/metagpt/tools/libs/web_scraping.py +++ b/metagpt/tools/libs/web_scraping.py @@ -1,9 +1,9 @@ -from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool +from metagpt.tools.tool_types import ToolTypes from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper -@register_tool(tool_type=ToolTypeEnum.WEBSCRAPING.value) +@register_tool(tool_type=ToolTypes.WEBSCRAPING.type_name) async def scrape_web_playwright(url, *urls): """ Scrape and save the HTML structure and inner text content of a web page using Playwright. diff --git a/metagpt/tools/tool_data_type.py b/metagpt/tools/tool_data_type.py index 0c4eea4cc..fe42b5721 100644 --- a/metagpt/tools/tool_data_type.py +++ b/metagpt/tools/tool_data_type.py @@ -1,26 +1,9 @@ -from enum import Enum - from pydantic import BaseModel -class ToolTypeEnum(Enum): - EDA = "eda" - DATA_PREPROCESS = "data_preprocess" - FEATURE_ENGINEERING = "feature_engineering" - MODEL_TRAIN = "model_train" - MODEL_EVALUATE = "model_evaluate" - STABLE_DIFFUSION = "stable_diffusion" - IMAGE2WEBPAGE = "image2webpage" - WEBSCRAPING = "web_scraping" - OTHER = "other" - - def __missing__(self, key): - return self.OTHER - - class ToolType(BaseModel): name: str - desc: str + desc: str = "" usage_prompt: str = "" diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index 7e4ee5ead..5922e7f69 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -11,12 +11,13 @@ import re from collections import defaultdict import yaml -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from metagpt.const import TOOL_SCHEMA_PATH from metagpt.logs import logger from metagpt.tools.tool_convert import convert_code_to_tool_schema from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType +from metagpt.tools.tool_types import ToolTypes class ToolRegistry(BaseModel): @@ -24,16 +25,16 @@ class ToolRegistry(BaseModel): tool_types: dict = {} tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...} - def register_tool_type(self, tool_type: ToolType, verbose: bool = False): - self.tool_types[tool_type.name] = tool_type - if verbose: - logger.info(f"tool type {tool_type.name} registered") + @field_validator("tool_types", mode="before") + @classmethod + def init_tool_types(cls, tool_types: ToolTypes): + return {tool_type.type_name: tool_type.value for tool_type in tool_types} def register_tool( self, tool_name, tool_path, - schema_path=None, + schema_path="", tool_code="", tool_type="other", tool_source_object=None, @@ -44,6 +45,16 @@ class ToolRegistry(BaseModel): if self.has_tool(tool_name): return + if tool_type not in self.tool_types: + # register new tool type on the fly + logger.warning( + f"{tool_type} not previously defined, will create a temporary ToolType with just a name. This ToolType is only effective during this runtime. You may consider add this ToolType with more configs permanently at metagpt.tools.tool_types" + ) + temp_tool_type_obj = ToolType(name=tool_type) + self.tool_types[tool_type] = temp_tool_type_obj + if verbose: + logger.info(f"tool type {tool_type} registered") + schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml" if not os.path.exists(schema_path): @@ -93,16 +104,10 @@ class ToolRegistry(BaseModel): # Registry instance -TOOL_REGISTRY = ToolRegistry() +TOOL_REGISTRY = ToolRegistry(tool_types=ToolTypes) -def register_tool_type(cls): - """register a tool type to registry""" - TOOL_REGISTRY.register_tool_type(tool_type=cls()) - return cls - - -def register_tool(tool_name="", tool_type="other", schema_path=None, **kwargs): +def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: str = "", **kwargs): """register a tool to registry""" def decorator(cls, tool_name=tool_name): diff --git a/metagpt/tools/tool_types.py b/metagpt/tools/tool_types.py index 35c0772b1..40981f836 100644 --- a/metagpt/tools/tool_types.py +++ b/metagpt/tools/tool_types.py @@ -1,3 +1,5 @@ +from enum import Enum + from metagpt.prompts.tool_types import ( DATA_PREPROCESS_PROMPT, FEATURE_ENGINEERING_PROMPT, @@ -5,64 +7,74 @@ from metagpt.prompts.tool_types import ( MODEL_EVALUATE_PROMPT, MODEL_TRAIN_PROMPT, ) -from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum -from metagpt.tools.tool_registry import register_tool_type +from metagpt.tools.tool_data_type import ToolType + +Eda = ToolType(name="eda", desc="For performing exploratory data analysis") + +DataPreprocess = ToolType( + name="data_preprocess", + desc="Only for changing value inplace.", + usage_prompt=DATA_PREPROCESS_PROMPT, +) -@register_tool_type -class EDA(ToolType): - name: str = ToolTypeEnum.EDA.value - desc: str = "For performing exploratory data analysis" +FeatureEngineering = ToolType( + name="feature_engineering", + desc="Only for creating new columns for input data.", + usage_prompt=FEATURE_ENGINEERING_PROMPT, +) -@register_tool_type -class DataPreprocess(ToolType): - name: str = ToolTypeEnum.DATA_PREPROCESS.value - desc: str = "Only for changing value inplace." - usage_prompt: str = DATA_PREPROCESS_PROMPT +ModelTrain = ToolType( + name="model_train", + desc="Only for training model.", + usage_prompt=MODEL_TRAIN_PROMPT, +) -@register_tool_type -class FeatureEngineer(ToolType): - name: str = ToolTypeEnum.FEATURE_ENGINEERING.value - desc: str = "Only for creating new columns for input data." - usage_prompt: str = FEATURE_ENGINEERING_PROMPT +ModelEvaluate = ToolType( + name="model_evaluate", + desc="Only for evaluating model.", + usage_prompt=MODEL_EVALUATE_PROMPT, +) -@register_tool_type -class ModelTrain(ToolType): - name: str = ToolTypeEnum.MODEL_TRAIN.value - desc: str = "Only for training model." - usage_prompt: str = MODEL_TRAIN_PROMPT +StableDiffusion = ToolType( + name="stable_diffusion", + desc="Related to text2image, image2image using stable diffusion model.", +) -@register_tool_type -class ModelEvaluate(ToolType): - name: str = ToolTypeEnum.MODEL_EVALUATE.value - desc: str = "Only for evaluating model." - usage_prompt: str = MODEL_EVALUATE_PROMPT +Image2Webpage = ToolType( + name="image2webpage", + desc="For converting image into webpage code.", + usage_prompt=IMAGE2WEBPAGE_PROMPT, +) -@register_tool_type -class StableDiffusion(ToolType): - name: str = ToolTypeEnum.STABLE_DIFFUSION.value - desc: str = "Related to text2image, image2image using stable diffusion model." +WebScraping = ToolType( + name="web_scraping", + desc="For scraping data from web pages.", +) -@register_tool_type -class Image2Webpage(ToolType): - name: str = ToolTypeEnum.IMAGE2WEBPAGE.value - desc: str = "For converting image into webpage code." - usage_prompt: str = IMAGE2WEBPAGE_PROMPT +Other = ToolType(name="other", desc="Any tools not in the defined categories") -@register_tool_type -class WebScraping(ToolType): - name: str = ToolTypeEnum.WEBSCRAPING.value - desc: str = "For scraping data from web pages." +class ToolTypes(Enum): + EDA = Eda + DATA_PREPROCESS = DataPreprocess + FEATURE_ENGINEERING = FeatureEngineering + MODEL_TRAIN = ModelTrain + MODEL_EVALUATE = ModelEvaluate + STABLE_DIFFUSION = StableDiffusion + IMAGE2WEBPAGE = Image2Webpage + WEBSCRAPING = WebScraping + OTHER = Other + def __missing__(self, key): + return self.OTHER -@register_tool_type -class Other(ToolType): - name: str = ToolTypeEnum.OTHER.value - desc: str = "Any tools not in the defined categories" + @property + def type_name(self): + return self.value.name diff --git a/tests/metagpt/roles/test_ml_engineer.py b/tests/metagpt/roles/test_ml_engineer.py index fb1e67cb8..c00481019 100644 --- a/tests/metagpt/roles/test_ml_engineer.py +++ b/tests/metagpt/roles/test_ml_engineer.py @@ -4,7 +4,7 @@ from metagpt.actions.execute_nb_code import ExecuteNbCode from metagpt.logs import logger from metagpt.roles.ml_engineer import MLEngineer from metagpt.schema import Message, Plan, Task -from metagpt.tools.tool_data_type import ToolTypeEnum +from metagpt.tools.tool_types import ToolTypes from tests.metagpt.actions.test_debug_code import CODE, DebugContext, ErrorStr @@ -63,7 +63,7 @@ async def test_mle_update_data_columns(mocker): mle.planner.plan = MockPlan # manually update task type to test update - mle.planner.plan.current_task.task_type = ToolTypeEnum.DATA_PREPROCESS.value + mle.planner.plan.current_task.task_type = ToolTypes.DATA_PREPROCESS.value result = await mle._update_data_columns() assert result is not None diff --git a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py index c24122e39..bb5d7a0bd 100644 --- a/tests/metagpt/tools/test_tool_registry.py +++ b/tests/metagpt/tools/test_tool_registry.py @@ -1,7 +1,7 @@ import pytest from metagpt.tools.tool_registry import ToolRegistry -from metagpt.tools.tool_types import ToolType +from metagpt.tools.tool_types import ToolTypes @pytest.fixture @@ -9,6 +9,11 @@ def tool_registry(): return ToolRegistry() +@pytest.fixture +def tool_registry_full(): + return ToolRegistry(tool_types=ToolTypes) + + @pytest.fixture def schema_yaml(mocker): mock_yaml_content = """ @@ -29,11 +34,12 @@ def test_initialization(tool_registry): assert tool_registry.tools_by_types == {} -# Test Tool Type Registration -def test_register_tool_type(tool_registry): - tool_type = ToolType(name="TestType", desc="test") - tool_registry.register_tool_type(tool_type) - assert "TestType" in tool_registry.tool_types +# Test Initialization with tool types +def test_initialize_with_tool_types(tool_registry_full): + assert isinstance(tool_registry_full, ToolRegistry) + assert tool_registry_full.tools == {} + assert tool_registry_full.tools_by_types == {} + assert "data_preprocess" in tool_registry_full.tool_types # Test Tool Registration @@ -66,27 +72,21 @@ def test_get_tool(tool_registry, schema_yaml): # Similar tests for has_tool_type, get_tool_type, get_tools_by_type -def test_has_tool_type(tool_registry): - tool_type = ToolType(name="TestType", desc="test") - tool_registry.register_tool_type(tool_type) - assert tool_registry.has_tool_type("TestType") - assert not tool_registry.has_tool_type("NonexistentType") +def test_has_tool_type(tool_registry_full): + assert tool_registry_full.has_tool_type("data_preprocess") + assert not tool_registry_full.has_tool_type("NonexistentType") -def test_get_tool_type(tool_registry): - tool_type = ToolType(name="TestType", desc="test") - tool_registry.register_tool_type(tool_type) - retrieved_type = tool_registry.get_tool_type("TestType") +def test_get_tool_type(tool_registry_full): + retrieved_type = tool_registry_full.get_tool_type("data_preprocess") assert retrieved_type is not None - assert retrieved_type.name == "TestType" + assert retrieved_type.name == "data_preprocess" def test_get_tools_by_type(tool_registry, schema_yaml): tool_type_name = "TestType" tool_name = "TestTool" tool_path = "/path/to/tool" - tool_type = ToolType(name=tool_type_name, desc="test") - tool_registry.register_tool_type(tool_type) tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name)