diff --git a/metagpt/roles/ci/ml_engineer.py b/metagpt/roles/ci/ml_engineer.py index 6fa6fe7b2..f8bcb2c89 100644 --- a/metagpt/roles/ci/ml_engineer.py +++ b/metagpt/roles/ci/ml_engineer.py @@ -3,7 +3,7 @@ from metagpt.actions.ci.execute_nb_code import ExecuteNbCode from metagpt.actions.ci.ml_action import UpdateDataColumns, WriteCodeWithToolsML from metagpt.logs import logger from metagpt.roles.ci.code_interpreter import CodeInterpreter -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType from metagpt.utils.common import any_to_str @@ -51,9 +51,9 @@ class MLEngineer(CodeInterpreter): async def _update_data_columns(self): current_task = self.planner.plan.current_task if current_task.task_type not in [ - ToolTypes.DATA_PREPROCESS.type_name, - ToolTypes.FEATURE_ENGINEERING.type_name, - ToolTypes.MODEL_TRAIN.type_name, + ToolType.DATA_PREPROCESS.type_name, + ToolType.FEATURE_ENGINEERING.type_name, + ToolType.MODEL_TRAIN.type_name, ]: return "" logger.info("Check columns in updated data") diff --git a/metagpt/tools/libs/data_preprocess.py b/metagpt/tools/libs/data_preprocess.py index 2cfa0b389..c9ca657a5 100644 --- a/metagpt/tools/libs/data_preprocess.py +++ b/metagpt/tools/libs/data_preprocess.py @@ -14,9 +14,9 @@ from sklearn.preprocessing import ( ) from metagpt.tools.tool_registry import register_tool -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType -TOOL_TYPE = ToolTypes.DATA_PREPROCESS.type_name +TOOL_TYPE = ToolType.DATA_PREPROCESS.type_name class MLProcess: diff --git a/metagpt/tools/libs/feature_engineering.py b/metagpt/tools/libs/feature_engineering.py index bbd16b681..325742105 100644 --- a/metagpt/tools/libs/feature_engineering.py +++ b/metagpt/tools/libs/feature_engineering.py @@ -17,9 +17,9 @@ from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures from metagpt.tools.libs.data_preprocess import MLProcess from metagpt.tools.tool_registry import register_tool -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType -TOOL_TYPE = ToolTypes.FEATURE_ENGINEERING.type_name +TOOL_TYPE = ToolType.FEATURE_ENGINEERING.type_name @register_tool(tool_type=TOOL_TYPE) diff --git a/metagpt/tools/libs/gpt_v_generator.py b/metagpt/tools/libs/gpt_v_generator.py index 63fda3e81..6953300d8 100644 --- a/metagpt/tools/libs/gpt_v_generator.py +++ b/metagpt/tools/libs/gpt_v_generator.py @@ -13,7 +13,7 @@ import requests from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.tools.tool_registry import register_tool -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image: @@ -31,7 +31,7 @@ Now, please generate the corresponding webpage code including HTML, CSS and Java @register_tool( - tool_type=ToolTypes.IMAGE2WEBPAGE.type_name, include_functions=["__init__", "generate_webpages", "save_webpages"] + tool_type=ToolType.IMAGE2WEBPAGE.type_name, include_functions=["__init__", "generate_webpages", "save_webpages"] ) class GPTvGenerator: """Class for generating webpages at once. diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index 6229a60e3..58f34a152 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -17,7 +17,7 @@ from PIL import Image, PngImagePlugin from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT from metagpt.logs import logger from metagpt.tools.tool_registry import register_tool -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType payload = { "prompt": "", @@ -54,7 +54,7 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" @register_tool( - tool_type=ToolTypes.STABLE_DIFFUSION.type_name, + tool_type=ToolType.STABLE_DIFFUSION.type_name, include_functions=["__init__", "simple_run_t2i", "run_t2i", "construct_payload", "save"], ) class SDEngine: diff --git a/metagpt/tools/libs/web_scraping.py b/metagpt/tools/libs/web_scraping.py index f983c1215..6fd3b9435 100644 --- a/metagpt/tools/libs/web_scraping.py +++ b/metagpt/tools/libs/web_scraping.py @@ -1,9 +1,9 @@ from metagpt.tools.tool_registry import register_tool -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper -@register_tool(tool_type=ToolTypes.WEBSCRAPING.type_name) +@register_tool(tool_type=ToolType.WEBSCRAPING.type_name) async def scrape_web_playwright(url, *urls): """ Scrape and save the HTML structure and inner text content of a web page using Playwright. diff --git a/metagpt/tools/tool_data_type.py b/metagpt/tools/tool_data_type.py index fe42b5721..0ae46fa5c 100644 --- a/metagpt/tools/tool_data_type.py +++ b/metagpt/tools/tool_data_type.py @@ -1,14 +1,14 @@ from pydantic import BaseModel -class ToolType(BaseModel): +class ToolTypeDef(BaseModel): name: str desc: str = "" usage_prompt: str = "" class ToolSchema(BaseModel): - name: str + description: str class Tool(BaseModel): diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index 299d62ca3..87645d35a 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -16,8 +16,8 @@ from pydantic import BaseModel, field_validator from metagpt.const import TOOL_SCHEMA_PATH from metagpt.logs import logger from metagpt.tools.tool_convert import convert_code_to_tool_schema -from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolTypeDef +from metagpt.tools.tool_type import ToolType class ToolRegistry(BaseModel): @@ -27,7 +27,7 @@ class ToolRegistry(BaseModel): @field_validator("tool_types", mode="before") @classmethod - def init_tool_types(cls, tool_types: ToolTypes): + def init_tool_types(cls, tool_types: ToolType): return {tool_type.type_name: tool_type.value for tool_type in tool_types} def register_tool( @@ -47,9 +47,9 @@ class ToolRegistry(BaseModel): if tool_type not in self.tool_types: # register new tool type on the fly logger.warning( - f"{tool_type} not previously defined, will create a temporary ToolType with just a name. This ToolType is only effective during this runtime. You may consider add this ToolType with more configs permanently at metagpt.tools.tool_types" + f"{tool_type} not previously defined, will create a temporary tool type with just a name. This tool type is only effective during this runtime. You may consider add this tool type with more configs permanently at metagpt.tools.tool_type" ) - temp_tool_type_obj = ToolType(name=tool_type) + temp_tool_type_obj = ToolTypeDef(name=tool_type) self.tool_types[tool_type] = temp_tool_type_obj if verbose: logger.info(f"tool type {tool_type} registered") @@ -97,7 +97,7 @@ class ToolRegistry(BaseModel): # Registry instance -TOOL_REGISTRY = ToolRegistry(tool_types=ToolTypes) +TOOL_REGISTRY = ToolRegistry(tool_types=ToolType) def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs): diff --git a/metagpt/tools/tool_types.py b/metagpt/tools/tool_type.py similarity index 71% rename from metagpt/tools/tool_types.py rename to metagpt/tools/tool_type.py index d96c0299c..6fa971c56 100644 --- a/metagpt/tools/tool_types.py +++ b/metagpt/tools/tool_type.py @@ -7,45 +7,45 @@ from metagpt.prompts.tool_types import ( MODEL_EVALUATE_PROMPT, MODEL_TRAIN_PROMPT, ) -from metagpt.tools.tool_data_type import ToolType +from metagpt.tools.tool_data_type import ToolTypeDef -class ToolTypes(Enum): - EDA = ToolType(name="eda", desc="For performing exploratory data analysis") - DATA_PREPROCESS = ToolType( +class ToolType(Enum): + EDA = ToolTypeDef(name="eda", desc="For performing exploratory data analysis") + DATA_PREPROCESS = ToolTypeDef( name="data_preprocess", desc="Only for changing value inplace.", usage_prompt=DATA_PREPROCESS_PROMPT, ) - FEATURE_ENGINEERING = ToolType( + FEATURE_ENGINEERING = ToolTypeDef( name="feature_engineering", desc="Only for creating new columns for input data.", usage_prompt=FEATURE_ENGINEERING_PROMPT, ) - MODEL_TRAIN = ToolType( + MODEL_TRAIN = ToolTypeDef( name="model_train", desc="Only for training model.", usage_prompt=MODEL_TRAIN_PROMPT, ) - MODEL_EVALUATE = ToolType( + MODEL_EVALUATE = ToolTypeDef( name="model_evaluate", desc="Only for evaluating model.", usage_prompt=MODEL_EVALUATE_PROMPT, ) - STABLE_DIFFUSION = ToolType( + STABLE_DIFFUSION = ToolTypeDef( name="stable_diffusion", desc="Related to text2image, image2image using stable diffusion model.", ) - IMAGE2WEBPAGE = ToolType( + IMAGE2WEBPAGE = ToolTypeDef( name="image2webpage", desc="For converting image into webpage code.", usage_prompt=IMAGE2WEBPAGE_PROMPT, ) - WEBSCRAPING = ToolType( + WEBSCRAPING = ToolTypeDef( name="web_scraping", desc="For scraping data from web pages.", ) - OTHER = ToolType(name="other", desc="Any tools not in the defined categories") + OTHER = ToolTypeDef(name="other", desc="Any tools not in the defined categories") def __missing__(self, key): return self.OTHER diff --git a/tests/metagpt/roles/ci/test_ml_engineer.py b/tests/metagpt/roles/ci/test_ml_engineer.py index 144201f85..3bf9f3b92 100644 --- a/tests/metagpt/roles/ci/test_ml_engineer.py +++ b/tests/metagpt/roles/ci/test_ml_engineer.py @@ -4,7 +4,7 @@ from metagpt.actions.ci.execute_nb_code import ExecuteNbCode from metagpt.logs import logger from metagpt.roles.ci.ml_engineer import MLEngineer from metagpt.schema import Message, Plan, Task -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType from tests.metagpt.actions.ci.test_debug_code import CODE, DebugContext, ErrorStr @@ -61,7 +61,7 @@ async def test_mle_update_data_columns(mocker): mle.planner.plan = MockPlan # manually update task type to test update - mle.planner.plan.current_task.task_type = ToolTypes.DATA_PREPROCESS.value + mle.planner.plan.current_task.task_type = ToolType.DATA_PREPROCESS.value result = await mle._update_data_columns() assert result is not None diff --git a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py index e41ddfa79..2fd487fb7 100644 --- a/tests/metagpt/tools/test_tool_registry.py +++ b/tests/metagpt/tools/test_tool_registry.py @@ -1,7 +1,7 @@ import pytest from metagpt.tools.tool_registry import ToolRegistry -from metagpt.tools.tool_types import ToolTypes +from metagpt.tools.tool_type import ToolType @pytest.fixture @@ -11,7 +11,7 @@ def tool_registry(): @pytest.fixture def tool_registry_full(): - return ToolRegistry(tool_types=ToolTypes) + return ToolRegistry(tool_types=ToolType) # Test Initialization