mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
remove ToolTypesEnum
This commit is contained in:
parent
35438e7b03
commit
1da50f1825
11 changed files with 109 additions and 109 deletions
|
|
@ -3,7 +3,7 @@ from metagpt.actions.execute_nb_code import ExecuteNbCode
|
|||
from metagpt.actions.ml_action import UpdateDataColumns, WriteCodeWithToolsML
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.code_interpreter import CodeInterpreter
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
|
|
@ -51,9 +51,9 @@ class MLEngineer(CodeInterpreter):
|
|||
async def _update_data_columns(self):
|
||||
current_task = self.planner.plan.current_task
|
||||
if current_task.task_type not in [
|
||||
ToolTypeEnum.DATA_PREPROCESS.value,
|
||||
ToolTypeEnum.FEATURE_ENGINEERING.value,
|
||||
ToolTypeEnum.MODEL_TRAIN.value,
|
||||
ToolTypes.DATA_PREPROCESS.type_name,
|
||||
ToolTypes.FEATURE_ENGINEERING.type_name,
|
||||
ToolTypes.MODEL_TRAIN.type_name,
|
||||
]:
|
||||
return ""
|
||||
logger.info("Check columns in updated data")
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ from sklearn.preprocessing import (
|
|||
StandardScaler,
|
||||
)
|
||||
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
|
||||
TOOL_TYPE = ToolTypeEnum.DATA_PREPROCESS.value
|
||||
TOOL_TYPE = ToolTypes.DATA_PREPROCESS.type_name
|
||||
|
||||
|
||||
class MLProcess(object):
|
||||
|
|
|
|||
|
|
@ -16,10 +16,10 @@ from sklearn.model_selection import KFold
|
|||
from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
|
||||
|
||||
from metagpt.tools.libs.data_preprocess import MLProcess
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
|
||||
TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value
|
||||
TOOL_TYPE = ToolTypes.FEATURE_ENGINEERING.type_name
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from pathlib import Path
|
|||
import requests
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
|
||||
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image:
|
||||
|
||||
|
|
@ -30,7 +30,7 @@ As the design pays tribute to large companies, sometimes it is normal for some c
|
|||
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolTypeEnum.IMAGE2WEBPAGE.value)
|
||||
@register_tool(tool_type=ToolTypes.IMAGE2WEBPAGE.type_name)
|
||||
class GPTvGenerator:
|
||||
def __init__(self):
|
||||
from metagpt.config2 import config
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ from PIL import Image, PngImagePlugin
|
|||
#
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
|
|
@ -53,7 +53,7 @@ payload = {
|
|||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value)
|
||||
@register_tool(tool_type=ToolTypes.STABLE_DIFFUSION.type_name)
|
||||
class SDEngine:
|
||||
def __init__(self, sd_url=""):
|
||||
# Initialize the SDEngine with configuration
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolTypeEnum.WEBSCRAPING.value)
|
||||
@register_tool(tool_type=ToolTypes.WEBSCRAPING.type_name)
|
||||
async def scrape_web_playwright(url, *urls):
|
||||
"""
|
||||
Scrape and save the HTML structure and inner text content of a web page using Playwright.
|
||||
|
|
|
|||
|
|
@ -1,26 +1,9 @@
|
|||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolTypeEnum(Enum):
|
||||
EDA = "eda"
|
||||
DATA_PREPROCESS = "data_preprocess"
|
||||
FEATURE_ENGINEERING = "feature_engineering"
|
||||
MODEL_TRAIN = "model_train"
|
||||
MODEL_EVALUATE = "model_evaluate"
|
||||
STABLE_DIFFUSION = "stable_diffusion"
|
||||
IMAGE2WEBPAGE = "image2webpage"
|
||||
WEBSCRAPING = "web_scraping"
|
||||
OTHER = "other"
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OTHER
|
||||
|
||||
|
||||
class ToolType(BaseModel):
|
||||
name: str
|
||||
desc: str
|
||||
desc: str = ""
|
||||
usage_prompt: str = ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,12 +11,13 @@ import re
|
|||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from metagpt.const import TOOL_SCHEMA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolType
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
|
||||
|
||||
class ToolRegistry(BaseModel):
|
||||
|
|
@ -24,16 +25,16 @@ class ToolRegistry(BaseModel):
|
|||
tool_types: dict = {}
|
||||
tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
|
||||
|
||||
def register_tool_type(self, tool_type: ToolType, verbose: bool = False):
|
||||
self.tool_types[tool_type.name] = tool_type
|
||||
if verbose:
|
||||
logger.info(f"tool type {tool_type.name} registered")
|
||||
@field_validator("tool_types", mode="before")
|
||||
@classmethod
|
||||
def init_tool_types(cls, tool_types: ToolTypes):
|
||||
return {tool_type.type_name: tool_type.value for tool_type in tool_types}
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
tool_name,
|
||||
tool_path,
|
||||
schema_path=None,
|
||||
schema_path="",
|
||||
tool_code="",
|
||||
tool_type="other",
|
||||
tool_source_object=None,
|
||||
|
|
@ -44,6 +45,16 @@ class ToolRegistry(BaseModel):
|
|||
if self.has_tool(tool_name):
|
||||
return
|
||||
|
||||
if tool_type not in self.tool_types:
|
||||
# register new tool type on the fly
|
||||
logger.warning(
|
||||
f"{tool_type} not previously defined, will create a temporary ToolType with just a name. This ToolType is only effective during this runtime. You may consider add this ToolType with more configs permanently at metagpt.tools.tool_types"
|
||||
)
|
||||
temp_tool_type_obj = ToolType(name=tool_type)
|
||||
self.tool_types[tool_type] = temp_tool_type_obj
|
||||
if verbose:
|
||||
logger.info(f"tool type {tool_type} registered")
|
||||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
|
||||
|
||||
if not os.path.exists(schema_path):
|
||||
|
|
@ -93,16 +104,10 @@ class ToolRegistry(BaseModel):
|
|||
|
||||
|
||||
# Registry instance
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
TOOL_REGISTRY = ToolRegistry(tool_types=ToolTypes)
|
||||
|
||||
|
||||
def register_tool_type(cls):
|
||||
"""register a tool type to registry"""
|
||||
TOOL_REGISTRY.register_tool_type(tool_type=cls())
|
||||
return cls
|
||||
|
||||
|
||||
def register_tool(tool_name="", tool_type="other", schema_path=None, **kwargs):
|
||||
def register_tool(tool_name: str = "", tool_type: str = "other", schema_path: str = "", **kwargs):
|
||||
"""register a tool to registry"""
|
||||
|
||||
def decorator(cls, tool_name=tool_name):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from enum import Enum
|
||||
|
||||
from metagpt.prompts.tool_types import (
|
||||
DATA_PREPROCESS_PROMPT,
|
||||
FEATURE_ENGINEERING_PROMPT,
|
||||
|
|
@ -5,64 +7,74 @@ from metagpt.prompts.tool_types import (
|
|||
MODEL_EVALUATE_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
from metagpt.tools.tool_data_type import ToolType, ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool_type
|
||||
from metagpt.tools.tool_data_type import ToolType
|
||||
|
||||
Eda = ToolType(name="eda", desc="For performing exploratory data analysis")
|
||||
|
||||
DataPreprocess = ToolType(
|
||||
name="data_preprocess",
|
||||
desc="Only for changing value inplace.",
|
||||
usage_prompt=DATA_PREPROCESS_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class EDA(ToolType):
|
||||
name: str = ToolTypeEnum.EDA.value
|
||||
desc: str = "For performing exploratory data analysis"
|
||||
FeatureEngineering = ToolType(
|
||||
name="feature_engineering",
|
||||
desc="Only for creating new columns for input data.",
|
||||
usage_prompt=FEATURE_ENGINEERING_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class DataPreprocess(ToolType):
|
||||
name: str = ToolTypeEnum.DATA_PREPROCESS.value
|
||||
desc: str = "Only for changing value inplace."
|
||||
usage_prompt: str = DATA_PREPROCESS_PROMPT
|
||||
ModelTrain = ToolType(
|
||||
name="model_train",
|
||||
desc="Only for training model.",
|
||||
usage_prompt=MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class FeatureEngineer(ToolType):
|
||||
name: str = ToolTypeEnum.FEATURE_ENGINEERING.value
|
||||
desc: str = "Only for creating new columns for input data."
|
||||
usage_prompt: str = FEATURE_ENGINEERING_PROMPT
|
||||
ModelEvaluate = ToolType(
|
||||
name="model_evaluate",
|
||||
desc="Only for evaluating model.",
|
||||
usage_prompt=MODEL_EVALUATE_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class ModelTrain(ToolType):
|
||||
name: str = ToolTypeEnum.MODEL_TRAIN.value
|
||||
desc: str = "Only for training model."
|
||||
usage_prompt: str = MODEL_TRAIN_PROMPT
|
||||
StableDiffusion = ToolType(
|
||||
name="stable_diffusion",
|
||||
desc="Related to text2image, image2image using stable diffusion model.",
|
||||
)
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class ModelEvaluate(ToolType):
|
||||
name: str = ToolTypeEnum.MODEL_EVALUATE.value
|
||||
desc: str = "Only for evaluating model."
|
||||
usage_prompt: str = MODEL_EVALUATE_PROMPT
|
||||
Image2Webpage = ToolType(
|
||||
name="image2webpage",
|
||||
desc="For converting image into webpage code.",
|
||||
usage_prompt=IMAGE2WEBPAGE_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class StableDiffusion(ToolType):
|
||||
name: str = ToolTypeEnum.STABLE_DIFFUSION.value
|
||||
desc: str = "Related to text2image, image2image using stable diffusion model."
|
||||
WebScraping = ToolType(
|
||||
name="web_scraping",
|
||||
desc="For scraping data from web pages.",
|
||||
)
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class Image2Webpage(ToolType):
|
||||
name: str = ToolTypeEnum.IMAGE2WEBPAGE.value
|
||||
desc: str = "For converting image into webpage code."
|
||||
usage_prompt: str = IMAGE2WEBPAGE_PROMPT
|
||||
Other = ToolType(name="other", desc="Any tools not in the defined categories")
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class WebScraping(ToolType):
|
||||
name: str = ToolTypeEnum.WEBSCRAPING.value
|
||||
desc: str = "For scraping data from web pages."
|
||||
class ToolTypes(Enum):
|
||||
EDA = Eda
|
||||
DATA_PREPROCESS = DataPreprocess
|
||||
FEATURE_ENGINEERING = FeatureEngineering
|
||||
MODEL_TRAIN = ModelTrain
|
||||
MODEL_EVALUATE = ModelEvaluate
|
||||
STABLE_DIFFUSION = StableDiffusion
|
||||
IMAGE2WEBPAGE = Image2Webpage
|
||||
WEBSCRAPING = WebScraping
|
||||
OTHER = Other
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OTHER
|
||||
|
||||
@register_tool_type
|
||||
class Other(ToolType):
|
||||
name: str = ToolTypeEnum.OTHER.value
|
||||
desc: str = "Any tools not in the defined categories"
|
||||
@property
|
||||
def type_name(self):
|
||||
return self.value.name
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from metagpt.actions.execute_nb_code import ExecuteNbCode
|
|||
from metagpt.logs import logger
|
||||
from metagpt.roles.ml_engineer import MLEngineer
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
from tests.metagpt.actions.test_debug_code import CODE, DebugContext, ErrorStr
|
||||
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ async def test_mle_update_data_columns(mocker):
|
|||
mle.planner.plan = MockPlan
|
||||
|
||||
# manually update task type to test update
|
||||
mle.planner.plan.current_task.task_type = ToolTypeEnum.DATA_PREPROCESS.value
|
||||
mle.planner.plan.current_task.task_type = ToolTypes.DATA_PREPROCESS.value
|
||||
|
||||
result = await mle._update_data_columns()
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.tool_registry import ToolRegistry
|
||||
from metagpt.tools.tool_types import ToolType
|
||||
from metagpt.tools.tool_types import ToolTypes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -9,6 +9,11 @@ def tool_registry():
|
|||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry_full():
|
||||
return ToolRegistry(tool_types=ToolTypes)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def schema_yaml(mocker):
|
||||
mock_yaml_content = """
|
||||
|
|
@ -29,11 +34,12 @@ def test_initialization(tool_registry):
|
|||
assert tool_registry.tools_by_types == {}
|
||||
|
||||
|
||||
# Test Tool Type Registration
|
||||
def test_register_tool_type(tool_registry):
|
||||
tool_type = ToolType(name="TestType", desc="test")
|
||||
tool_registry.register_tool_type(tool_type)
|
||||
assert "TestType" in tool_registry.tool_types
|
||||
# Test Initialization with tool types
|
||||
def test_initialize_with_tool_types(tool_registry_full):
|
||||
assert isinstance(tool_registry_full, ToolRegistry)
|
||||
assert tool_registry_full.tools == {}
|
||||
assert tool_registry_full.tools_by_types == {}
|
||||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
|
||||
|
||||
# Test Tool Registration
|
||||
|
|
@ -66,27 +72,21 @@ def test_get_tool(tool_registry, schema_yaml):
|
|||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
def test_has_tool_type(tool_registry):
|
||||
tool_type = ToolType(name="TestType", desc="test")
|
||||
tool_registry.register_tool_type(tool_type)
|
||||
assert tool_registry.has_tool_type("TestType")
|
||||
assert not tool_registry.has_tool_type("NonexistentType")
|
||||
def test_has_tool_type(tool_registry_full):
|
||||
assert tool_registry_full.has_tool_type("data_preprocess")
|
||||
assert not tool_registry_full.has_tool_type("NonexistentType")
|
||||
|
||||
|
||||
def test_get_tool_type(tool_registry):
|
||||
tool_type = ToolType(name="TestType", desc="test")
|
||||
tool_registry.register_tool_type(tool_type)
|
||||
retrieved_type = tool_registry.get_tool_type("TestType")
|
||||
def test_get_tool_type(tool_registry_full):
|
||||
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
|
||||
assert retrieved_type is not None
|
||||
assert retrieved_type.name == "TestType"
|
||||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry, schema_yaml):
|
||||
tool_type_name = "TestType"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
tool_type = ToolType(name=tool_type_name, desc="test")
|
||||
tool_registry.register_tool_type(tool_type)
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue