diff --git a/metagpt/prompts/tool_types.py b/metagpt/prompts/tool_types.py
index 43ead78a6..c01a80310 100644
--- a/metagpt/prompts/tool_types.py
+++ b/metagpt/prompts/tool_types.py
@@ -39,7 +39,7 @@ The current task is about evaluating a model, please note the following:
"""
# Prompt for using tools of "vision" type
-VISION_PROMPT = """
+IMAGE2WEBPAGE_PROMPT = """
The current task is about converting image into webpage code. please note the following:
- Single-Step Code Generation: Execute the entire code generation process in a single step, encompassing HTML, CSS, and JavaScript. Avoid fragmenting the code generation into multiple separate steps to maintain consistency and simplify the development workflow.
-"""
\ No newline at end of file
+"""
diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py
index 23b51533d..f18d1d276 100644
--- a/metagpt/tools/__init__.py
+++ b/metagpt/tools/__init__.py
@@ -11,9 +11,7 @@ from metagpt.tools import tool_types # this registers all tool types
from metagpt.tools import libs # this registers all tools
from metagpt.tools.tool_registry import TOOL_REGISTRY
-_ = tool_types # Avoid pre-commit error
-_ = libs # Avoid pre-commit error
-_ = TOOL_REGISTRY # Avoid pre-commit error
+_, _, _ = tool_types, libs, TOOL_REGISTRY # Avoid pre-commit error
class SearchEngineType(Enum):
diff --git a/metagpt/tools/libs/__init__.py b/metagpt/tools/libs/__init__.py
index 3d74674aa..b576997c9 100644
--- a/metagpt/tools/libs/__init__.py
+++ b/metagpt/tools/libs/__init__.py
@@ -7,7 +7,8 @@
from metagpt.tools.libs import (
data_preprocess,
feature_engineering,
+ sd_engine,
+ gpt_v_generator,
)
-_ = data_preprocess # Avoid pre-commit error
-_ = feature_engineering # Avoid pre-commit error
+_, _, _, _ = data_preprocess, feature_engineering, sd_engine, gpt_v_generator # Avoid pre-commit error
diff --git a/metagpt/tools/libs/data_preprocess.py b/metagpt/tools/libs/data_preprocess.py
index 7cc44263d..3891f9df0 100644
--- a/metagpt/tools/libs/data_preprocess.py
+++ b/metagpt/tools/libs/data_preprocess.py
@@ -31,7 +31,7 @@ class MLProcess(object):
return self.transform(df)
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class FillMissingValue(MLProcess):
def __init__(
self,
@@ -58,7 +58,7 @@ class FillMissingValue(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class MinMaxScale(MLProcess):
def __init__(
self,
@@ -77,7 +77,7 @@ class MinMaxScale(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class StandardScale(MLProcess):
def __init__(
self,
@@ -96,7 +96,7 @@ class StandardScale(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class MaxAbsScale(MLProcess):
def __init__(
self,
@@ -115,7 +115,7 @@ class MaxAbsScale(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class RobustScale(MLProcess):
def __init__(
self,
@@ -134,7 +134,7 @@ class RobustScale(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class OrdinalEncode(MLProcess):
def __init__(
self,
@@ -153,7 +153,7 @@ class OrdinalEncode(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class OneHotEncode(MLProcess):
def __init__(
self,
@@ -175,7 +175,7 @@ class OneHotEncode(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class LabelEncode(MLProcess):
def __init__(
self,
@@ -204,7 +204,7 @@ class LabelEncode(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
def get_column_info(df: pd.DataFrame) -> dict:
column_info = {
"Category": [],
diff --git a/metagpt/tools/libs/feature_engineering.py b/metagpt/tools/libs/feature_engineering.py
index ed5c1be72..308150f9b 100644
--- a/metagpt/tools/libs/feature_engineering.py
+++ b/metagpt/tools/libs/feature_engineering.py
@@ -22,7 +22,7 @@ from metagpt.tools.tool_registry import register_tool
TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class PolynomialExpansion(MLProcess):
def __init__(self, cols: list, degree: int = 2, label_col: str = None):
self.cols = cols
@@ -53,7 +53,7 @@ class PolynomialExpansion(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class CatCount(MLProcess):
def __init__(self, col: str):
self.col = col
@@ -68,7 +68,7 @@ class CatCount(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class TargetMeanEncoder(MLProcess):
def __init__(self, col: str, label: str):
self.col = col
@@ -84,7 +84,7 @@ class TargetMeanEncoder(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class KFoldTargetMeanEncoder(MLProcess):
def __init__(self, col: str, label: str, n_splits: int = 5, random_state: int = 2021):
self.col = col
@@ -111,7 +111,7 @@ class KFoldTargetMeanEncoder(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class CatCross(MLProcess):
def __init__(self, cols: list, max_cat_num: int = 100):
self.cols = cols
@@ -147,7 +147,7 @@ class CatCross(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class GroupStat(MLProcess):
def __init__(self, group_col: str, agg_col: str, agg_funcs: list):
self.group_col = group_col
@@ -167,7 +167,7 @@ class GroupStat(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class SplitBins(MLProcess):
def __init__(self, cols: list, strategy: str = "quantile"):
self.cols = cols
@@ -184,7 +184,7 @@ class SplitBins(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class ExtractTimeComps(MLProcess):
def __init__(self, time_col: str, time_comps: list):
self.time_col = time_col
@@ -213,7 +213,7 @@ class ExtractTimeComps(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class GeneralSelection(MLProcess):
def __init__(self, label_col: str):
self.label_col = label_col
@@ -284,7 +284,7 @@ class TreeBasedSelection(MLProcess):
return new_df
-@register_tool(tool_type_name=TOOL_TYPE)
+@register_tool(tool_type=TOOL_TYPE)
class VarianceBasedSelection(MLProcess):
def __init__(self, label_col: str, threshold: float = 0):
self.label_col = label_col
diff --git a/metagpt/tools/functions/libs/vision.py b/metagpt/tools/libs/gpt_v_generator.py
similarity index 85%
rename from metagpt/tools/functions/libs/vision.py
rename to metagpt/tools/libs/gpt_v_generator.py
index b10ad7608..58e547840 100644
--- a/metagpt/tools/functions/libs/vision.py
+++ b/metagpt/tools/libs/gpt_v_generator.py
@@ -5,18 +5,13 @@
@Author : mannaandpoem
@File : vision.py
"""
+import base64
from pathlib import Path
import requests
-import base64
-
-from metagpt.config import CONFIG
-
-OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL
-API_KEY = CONFIG.OPENAI_API_KEY
-MODEL = CONFIG.OPENAI_VISION_MODEL
-MAX_TOKENS = CONFIG.VISION_MAX_TOKENS
+from metagpt.tools.tool_data_type import ToolTypeEnum
+from metagpt.tools.tool_registry import register_tool
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image:
@@ -33,8 +28,15 @@ As the design pays tribute to large companies, sometimes it is normal for some c
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
-class Vision:
+@register_tool(tool_type=ToolTypeEnum.IMAGE2WEBPAGE.value)
+class GPTvGenerator:
def __init__(self):
+ from metagpt.config import CONFIG
+
+ OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL
+ API_KEY = CONFIG.OPENAI_API_KEY
+ MODEL = CONFIG.OPENAI_VISION_MODEL
+ MAX_TOKENS = CONFIG.VISION_MAX_TOKENS
self.api_key = API_KEY
self.api_base = OPENAI_API_BASE
self.model = MODEL
@@ -51,10 +53,7 @@ class Vision:
def get_result(self, image_path, prompt):
base64_image = self.encode_image(image_path)
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.api_key}"
- }
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
payload = {
"model": self.model,
"messages": [
@@ -62,11 +61,8 @@ class Vision:
"role": "user",
"content": [
{"type": "text", "text": prompt},
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
- }
- ]
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
+ ],
}
],
"max_tokens": self.max_tokens,
@@ -81,7 +77,7 @@ class Vision:
@staticmethod
def encode_image(image_path):
with open(image_path, "rb") as image_file:
- return base64.b64encode(image_file.read()).decode('utf-8')
+ return base64.b64encode(image_file.read()).decode("utf-8")
@staticmethod
def save_webpages(image_path, webpages) -> Path:
diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py
index ad63c2505..794758f77 100644
--- a/metagpt/tools/libs/sd_engine.py
+++ b/metagpt/tools/libs/sd_engine.py
@@ -13,7 +13,6 @@ import requests
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
-from metagpt.config import CONFIG
from metagpt.const import SD_OUTPUT_FILE_REPO
from metagpt.logs import logger
from metagpt.tools.tool_data_type import ToolTypeEnum
@@ -53,9 +52,11 @@ payload = {
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
-@register_tool(tool_type_name=ToolTypeEnum.STABLE_DIFFUSION.value)
+@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value)
class SDEngine:
def __init__(self, sd_url=""):
+ from metagpt.config import CONFIG
+
# Initialize the SDEngine with configuration
self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL")
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
diff --git a/metagpt/tools/functions/schemas/vision.yml b/metagpt/tools/schemas/image2webpage/GPTvGenerator.yml
similarity index 93%
rename from metagpt/tools/functions/schemas/vision.yml
rename to metagpt/tools/schemas/image2webpage/GPTvGenerator.yml
index 4cb247419..4087f7c12 100644
--- a/metagpt/tools/functions/schemas/vision.yml
+++ b/metagpt/tools/schemas/image2webpage/GPTvGenerator.yml
@@ -1,4 +1,4 @@
-Vision:
+GPTvGenerator:
type: class
description: "Class for generating web pages at once."
methods:
diff --git a/metagpt/tools/tool_data_type.py b/metagpt/tools/tool_data_type.py
index 8206afa59..45fb539a6 100644
--- a/metagpt/tools/tool_data_type.py
+++ b/metagpt/tools/tool_data_type.py
@@ -10,6 +10,7 @@ class ToolTypeEnum(Enum):
MODEL_TRAIN = "model_train"
MODEL_EVALUATE = "model_evaluate"
STABLE_DIFFUSION = "stable_diffusion"
+ IMAGE2WEBPAGE = "image2webpage"
OTHER = "other"
def __missing__(self, key):
diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py
index 5d743358c..0544d25ee 100644
--- a/metagpt/tools/tool_registry.py
+++ b/metagpt/tools/tool_registry.py
@@ -21,7 +21,7 @@ class ToolRegistry:
def __init__(self):
self.tools = {}
self.tool_types = {}
- self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...}
+ self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
def register_tool_type(self, tool_type: ToolType):
self.tool_types[tool_type.name] = tool_type
@@ -33,13 +33,13 @@ class ToolRegistry:
tool_path,
schema_path=None,
tool_code="",
- tool_type_name="other",
+ tool_type="other",
make_schema_if_not_exists=False,
):
if self.has_tool(tool_name):
return
- schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type_name / f"{tool_name}.yml"
+ schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
if not os.path.exists(schema_path):
if make_schema_if_not_exists:
@@ -62,7 +62,7 @@ class ToolRegistry:
# )
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
self.tools[tool_name] = tool
- self.tools_by_types[tool_type_name][tool_name] = tool
+ self.tools_by_types[tool_type][tool_name] = tool
logger.info(f"{tool_name} registered")
def has_tool(self, key):
@@ -94,7 +94,7 @@ def register_tool_type(cls):
return cls
-def register_tool(tool_name="", tool_type_name="other", schema_path=None):
+def register_tool(tool_name="", tool_type="other", schema_path=None):
"""register a tool to registry"""
def decorator(cls, tool_name=tool_name):
@@ -111,7 +111,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None):
tool_path=file_path,
schema_path=schema_path,
tool_code=source_code,
- tool_type_name=tool_type_name,
+ tool_type=tool_type,
)
return cls
diff --git a/metagpt/tools/tool_types.py b/metagpt/tools/tool_types.py
index 2e22adc40..b5b233d53 100644
--- a/metagpt/tools/tool_types.py
+++ b/metagpt/tools/tool_types.py
@@ -1,6 +1,7 @@
from metagpt.prompts.tool_types import (
DATA_PREPROCESS_PROMPT,
FEATURE_ENGINEERING_PROMPT,
+ IMAGE2WEBPAGE_PROMPT,
MODEL_EVALUATE_PROMPT,
MODEL_TRAIN_PROMPT,
)
@@ -48,6 +49,13 @@ class StableDiffusion(ToolType):
desc: str = "Related to text2image, image2image using stable diffusion model."
+@register_tool_type
+class Image2Webpage(ToolType):
+ name: str = ToolTypeEnum.IMAGE2WEBPAGE.value
+ desc: str = "For converting image into webpage code."
+ usage_prompt: str = IMAGE2WEBPAGE_PROMPT
+
+
@register_tool_type
class Other(ToolType):
name: str = ToolTypeEnum.OTHER.value
diff --git a/tests/metagpt/tools/functions/libs/test_vision.py b/tests/metagpt/tools/functions/libs/test_vision.py
deleted file mode 100644
index f4f97c46a..000000000
--- a/tests/metagpt/tools/functions/libs/test_vision.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-"""
-@Time : 2024/01/15
-@Author : mannaandpoem
-@File : test_vision.py
-"""
-import pytest
-
-from metagpt import logs
-from metagpt.tools.functions.libs.vision import Vision
-
-
-@pytest.fixture
-def mock_webpages():
- return """```html\n\n
-\n\n```\n
-```css\n.class { ... }\n```\n
-```javascript\nfunction() { ... }\n```\n"""
-
-
-def test_vision_generate_webpages(mocker, mock_webpages):
- mocker.patch(
- "metagpt.tools.functions.libs.vision.Vision.generate_web_pages",
- return_value=mock_webpages
- )
- image_path = "image.png"
- vision = Vision()
- rsp = vision.generate_web_pages(image_path=image_path)
- logs.logger.info(rsp)
- assert "html" in rsp
- assert "css" in rsp
- assert "javascript" in rsp
-
-
-def test_save_webpages(mocker, mock_webpages):
- mocker.patch(
- "metagpt.tools.functions.libs.vision.Vision.generate_web_pages",
- return_value=mock_webpages
- )
- image_path = "image.png"
- vision = Vision()
- webpages = vision.generate_web_pages(image_path)
- webpages_dir = vision.save_webpages(image_path=image_path, webpages=webpages)
- logs.logger.info(webpages_dir)
- assert webpages_dir.exists()
-
-
diff --git a/tests/metagpt/tools/libs/test_gpt_v_generator.py b/tests/metagpt/tools/libs/test_gpt_v_generator.py
new file mode 100644
index 000000000..360ca4a75
--- /dev/null
+++ b/tests/metagpt/tools/libs/test_gpt_v_generator.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/01/15
+@Author : mannaandpoem
+@File : test_vision.py
+"""
+import pytest
+
+from metagpt import logs
+from metagpt.tools.libs.gpt_v_generator import GPTvGenerator
+
+
+@pytest.fixture
+def mock_webpages(mocker):
+ mock_data = """```html\n\n
+\n\n```\n
+```css\n.class { ... }\n```\n
+```javascript\nfunction() { ... }\n```\n"""
+ mocker.patch("metagpt.tools.libs.gpt_v_generator.GPTvGenerator.generate_web_pages", return_value=mock_data)
+ return mocker
+
+
+def test_vision_generate_webpages(mock_webpages):
+ image_path = "image.png"
+ generator = GPTvGenerator()
+ rsp = generator.generate_web_pages(image_path=image_path)
+ logs.logger.info(rsp)
+ assert "html" in rsp
+ assert "css" in rsp
+ assert "javascript" in rsp
+
+
+def test_save_webpages(mock_webpages):
+ image_path = "image.png"
+ generator = GPTvGenerator()
+ webpages = generator.generate_web_pages(image_path)
+ webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
+ logs.logger.info(webpages_dir)
+ assert webpages_dir.exists()
diff --git a/tests/metagpt/tools/libs/test_sd.py b/tests/metagpt/tools/libs/test_sd_engine.py
similarity index 100%
rename from tests/metagpt/tools/libs/test_sd.py
rename to tests/metagpt/tools/libs/test_sd_engine.py
diff --git a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py
index fd758b141..582c368a8 100644
--- a/tests/metagpt/tools/test_tool_registry.py
+++ b/tests/metagpt/tools/test_tool_registry.py
@@ -88,7 +88,7 @@ def test_get_tools_by_type(tool_registry, schema_yaml):
tool_type = ToolType(name=tool_type_name, desc="test")
tool_registry.register_tool_type(tool_type)
- tool_registry.register_tool(tool_name, tool_path, tool_type_name=tool_type_name)
+ tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name)
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
assert tools_by_type is not None