From 1cabf2c503f2de5c037049af78923ad2faa2be4a Mon Sep 17 00:00:00 2001 From: yzlin Date: Thu, 18 Jan 2024 20:34:32 +0800 Subject: [PATCH] change register arg name, integrate image2web tool --- metagpt/prompts/tool_types.py | 4 +- metagpt/tools/__init__.py | 4 +- metagpt/tools/libs/__init__.py | 5 +- metagpt/tools/libs/data_preprocess.py | 18 +++---- metagpt/tools/libs/feature_engineering.py | 20 ++++---- .../vision.py => libs/gpt_v_generator.py} | 34 ++++++------- metagpt/tools/libs/sd_engine.py | 5 +- .../image2webpage/GPTvGenerator.yml} | 2 +- metagpt/tools/tool_data_type.py | 1 + metagpt/tools/tool_registry.py | 12 ++--- metagpt/tools/tool_types.py | 8 ++++ .../tools/functions/libs/test_vision.py | 48 ------------------- .../tools/libs/test_gpt_v_generator.py | 40 ++++++++++++++++ .../libs/{test_sd.py => test_sd_engine.py} | 0 tests/metagpt/tools/test_tool_registry.py | 2 +- 15 files changed, 100 insertions(+), 103 deletions(-) rename metagpt/tools/{functions/libs/vision.py => libs/gpt_v_generator.py} (85%) rename metagpt/tools/{functions/schemas/vision.yml => schemas/image2webpage/GPTvGenerator.yml} (93%) delete mode 100644 tests/metagpt/tools/functions/libs/test_vision.py create mode 100644 tests/metagpt/tools/libs/test_gpt_v_generator.py rename tests/metagpt/tools/libs/{test_sd.py => test_sd_engine.py} (100%) diff --git a/metagpt/prompts/tool_types.py b/metagpt/prompts/tool_types.py index 43ead78a6..c01a80310 100644 --- a/metagpt/prompts/tool_types.py +++ b/metagpt/prompts/tool_types.py @@ -39,7 +39,7 @@ The current task is about evaluating a model, please note the following: """ # Prompt for using tools of "vision" type -VISION_PROMPT = """ +IMAGE2WEBPAGE_PROMPT = """ The current task is about converting image into webpage code. please note the following: - Single-Step Code Generation: Execute the entire code generation process in a single step, encompassing HTML, CSS, and JavaScript. Avoid fragmenting the code generation into multiple separate steps to maintain consistency and simplify the development workflow. -""" \ No newline at end of file +""" diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index 23b51533d..f18d1d276 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -11,9 +11,7 @@ from metagpt.tools import tool_types # this registers all tool types from metagpt.tools import libs # this registers all tools from metagpt.tools.tool_registry import TOOL_REGISTRY -_ = tool_types # Avoid pre-commit error -_ = libs # Avoid pre-commit error -_ = TOOL_REGISTRY # Avoid pre-commit error +_, _, _ = tool_types, libs, TOOL_REGISTRY # Avoid pre-commit error class SearchEngineType(Enum): diff --git a/metagpt/tools/libs/__init__.py b/metagpt/tools/libs/__init__.py index 3d74674aa..b576997c9 100644 --- a/metagpt/tools/libs/__init__.py +++ b/metagpt/tools/libs/__init__.py @@ -7,7 +7,8 @@ from metagpt.tools.libs import ( data_preprocess, feature_engineering, + sd_engine, + gpt_v_generator, ) -_ = data_preprocess # Avoid pre-commit error -_ = feature_engineering # Avoid pre-commit error +_, _, _, _ = data_preprocess, feature_engineering, sd_engine, gpt_v_generator # Avoid pre-commit error diff --git a/metagpt/tools/libs/data_preprocess.py b/metagpt/tools/libs/data_preprocess.py index 7cc44263d..3891f9df0 100644 --- a/metagpt/tools/libs/data_preprocess.py +++ b/metagpt/tools/libs/data_preprocess.py @@ -31,7 +31,7 @@ class MLProcess(object): return self.transform(df) -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class FillMissingValue(MLProcess): def __init__( self, @@ -58,7 +58,7 @@ class FillMissingValue(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class MinMaxScale(MLProcess): def __init__( self, @@ -77,7 +77,7 @@ class MinMaxScale(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class StandardScale(MLProcess): def __init__( self, @@ -96,7 +96,7 @@ class StandardScale(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class MaxAbsScale(MLProcess): def __init__( self, @@ -115,7 +115,7 @@ class MaxAbsScale(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class RobustScale(MLProcess): def __init__( self, @@ -134,7 +134,7 @@ class RobustScale(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class OrdinalEncode(MLProcess): def __init__( self, @@ -153,7 +153,7 @@ class OrdinalEncode(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class OneHotEncode(MLProcess): def __init__( self, @@ -175,7 +175,7 @@ class OneHotEncode(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class LabelEncode(MLProcess): def __init__( self, @@ -204,7 +204,7 @@ class LabelEncode(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) def get_column_info(df: pd.DataFrame) -> dict: column_info = { "Category": [], diff --git a/metagpt/tools/libs/feature_engineering.py b/metagpt/tools/libs/feature_engineering.py index ed5c1be72..308150f9b 100644 --- a/metagpt/tools/libs/feature_engineering.py +++ b/metagpt/tools/libs/feature_engineering.py @@ -22,7 +22,7 @@ from metagpt.tools.tool_registry import register_tool TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class PolynomialExpansion(MLProcess): def __init__(self, cols: list, degree: int = 2, label_col: str = None): self.cols = cols @@ -53,7 +53,7 @@ class PolynomialExpansion(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class CatCount(MLProcess): def __init__(self, col: str): self.col = col @@ -68,7 +68,7 @@ class CatCount(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class TargetMeanEncoder(MLProcess): def __init__(self, col: str, label: str): self.col = col @@ -84,7 +84,7 @@ class TargetMeanEncoder(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class KFoldTargetMeanEncoder(MLProcess): def __init__(self, col: str, label: str, n_splits: int = 5, random_state: int = 2021): self.col = col @@ -111,7 +111,7 @@ class KFoldTargetMeanEncoder(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class CatCross(MLProcess): def __init__(self, cols: list, max_cat_num: int = 100): self.cols = cols @@ -147,7 +147,7 @@ class CatCross(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class GroupStat(MLProcess): def __init__(self, group_col: str, agg_col: str, agg_funcs: list): self.group_col = group_col @@ -167,7 +167,7 @@ class GroupStat(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class SplitBins(MLProcess): def __init__(self, cols: list, strategy: str = "quantile"): self.cols = cols @@ -184,7 +184,7 @@ class SplitBins(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class ExtractTimeComps(MLProcess): def __init__(self, time_col: str, time_comps: list): self.time_col = time_col @@ -213,7 +213,7 @@ class ExtractTimeComps(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class GeneralSelection(MLProcess): def __init__(self, label_col: str): self.label_col = label_col @@ -284,7 +284,7 @@ class TreeBasedSelection(MLProcess): return new_df -@register_tool(tool_type_name=TOOL_TYPE) +@register_tool(tool_type=TOOL_TYPE) class VarianceBasedSelection(MLProcess): def __init__(self, label_col: str, threshold: float = 0): self.label_col = label_col diff --git a/metagpt/tools/functions/libs/vision.py b/metagpt/tools/libs/gpt_v_generator.py similarity index 85% rename from metagpt/tools/functions/libs/vision.py rename to metagpt/tools/libs/gpt_v_generator.py index b10ad7608..58e547840 100644 --- a/metagpt/tools/functions/libs/vision.py +++ b/metagpt/tools/libs/gpt_v_generator.py @@ -5,18 +5,13 @@ @Author : mannaandpoem @File : vision.py """ +import base64 from pathlib import Path import requests -import base64 - -from metagpt.config import CONFIG - -OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL -API_KEY = CONFIG.OPENAI_API_KEY -MODEL = CONFIG.OPENAI_VISION_MODEL -MAX_TOKENS = CONFIG.VISION_MAX_TOKENS +from metagpt.tools.tool_data_type import ToolTypeEnum +from metagpt.tools.tool_registry import register_tool ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image: @@ -33,8 +28,15 @@ As the design pays tribute to large companies, sometimes it is normal for some c Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:""" -class Vision: +@register_tool(tool_type=ToolTypeEnum.IMAGE2WEBPAGE.value) +class GPTvGenerator: def __init__(self): + from metagpt.config import CONFIG + + OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL + API_KEY = CONFIG.OPENAI_API_KEY + MODEL = CONFIG.OPENAI_VISION_MODEL + MAX_TOKENS = CONFIG.VISION_MAX_TOKENS self.api_key = API_KEY self.api_base = OPENAI_API_BASE self.model = MODEL @@ -51,10 +53,7 @@ class Vision: def get_result(self, image_path, prompt): base64_image = self.encode_image(image_path) - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}" - } + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} payload = { "model": self.model, "messages": [ @@ -62,11 +61,8 @@ class Vision: "role": "user", "content": [ {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"} - } - ] + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, + ], } ], "max_tokens": self.max_tokens, @@ -81,7 +77,7 @@ class Vision: @staticmethod def encode_image(image_path): with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') + return base64.b64encode(image_file.read()).decode("utf-8") @staticmethod def save_webpages(image_path, webpages) -> Path: diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index ad63c2505..794758f77 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -13,7 +13,6 @@ import requests from aiohttp import ClientSession from PIL import Image, PngImagePlugin -from metagpt.config import CONFIG from metagpt.const import SD_OUTPUT_FILE_REPO from metagpt.logs import logger from metagpt.tools.tool_data_type import ToolTypeEnum @@ -53,9 +52,11 @@ payload = { default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" -@register_tool(tool_type_name=ToolTypeEnum.STABLE_DIFFUSION.value) +@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value) class SDEngine: def __init__(self, sd_url=""): + from metagpt.config import CONFIG + # Initialize the SDEngine with configuration self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL") self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}" diff --git a/metagpt/tools/functions/schemas/vision.yml b/metagpt/tools/schemas/image2webpage/GPTvGenerator.yml similarity index 93% rename from metagpt/tools/functions/schemas/vision.yml rename to metagpt/tools/schemas/image2webpage/GPTvGenerator.yml index 4cb247419..4087f7c12 100644 --- a/metagpt/tools/functions/schemas/vision.yml +++ b/metagpt/tools/schemas/image2webpage/GPTvGenerator.yml @@ -1,4 +1,4 @@ -Vision: +GPTvGenerator: type: class description: "Class for generating web pages at once." methods: diff --git a/metagpt/tools/tool_data_type.py b/metagpt/tools/tool_data_type.py index 8206afa59..45fb539a6 100644 --- a/metagpt/tools/tool_data_type.py +++ b/metagpt/tools/tool_data_type.py @@ -10,6 +10,7 @@ class ToolTypeEnum(Enum): MODEL_TRAIN = "model_train" MODEL_EVALUATE = "model_evaluate" STABLE_DIFFUSION = "stable_diffusion" + IMAGE2WEBPAGE = "image2webpage" OTHER = "other" def __missing__(self, key): diff --git a/metagpt/tools/tool_registry.py b/metagpt/tools/tool_registry.py index 5d743358c..0544d25ee 100644 --- a/metagpt/tools/tool_registry.py +++ b/metagpt/tools/tool_registry.py @@ -21,7 +21,7 @@ class ToolRegistry: def __init__(self): self.tools = {} self.tool_types = {} - self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...} + self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...} def register_tool_type(self, tool_type: ToolType): self.tool_types[tool_type.name] = tool_type @@ -33,13 +33,13 @@ class ToolRegistry: tool_path, schema_path=None, tool_code="", - tool_type_name="other", + tool_type="other", make_schema_if_not_exists=False, ): if self.has_tool(tool_name): return - schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type_name / f"{tool_name}.yml" + schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml" if not os.path.exists(schema_path): if make_schema_if_not_exists: @@ -62,7 +62,7 @@ class ToolRegistry: # ) tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code) self.tools[tool_name] = tool - self.tools_by_types[tool_type_name][tool_name] = tool + self.tools_by_types[tool_type][tool_name] = tool logger.info(f"{tool_name} registered") def has_tool(self, key): @@ -94,7 +94,7 @@ def register_tool_type(cls): return cls -def register_tool(tool_name="", tool_type_name="other", schema_path=None): +def register_tool(tool_name="", tool_type="other", schema_path=None): """register a tool to registry""" def decorator(cls, tool_name=tool_name): @@ -111,7 +111,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None): tool_path=file_path, schema_path=schema_path, tool_code=source_code, - tool_type_name=tool_type_name, + tool_type=tool_type, ) return cls diff --git a/metagpt/tools/tool_types.py b/metagpt/tools/tool_types.py index 2e22adc40..b5b233d53 100644 --- a/metagpt/tools/tool_types.py +++ b/metagpt/tools/tool_types.py @@ -1,6 +1,7 @@ from metagpt.prompts.tool_types import ( DATA_PREPROCESS_PROMPT, FEATURE_ENGINEERING_PROMPT, + IMAGE2WEBPAGE_PROMPT, MODEL_EVALUATE_PROMPT, MODEL_TRAIN_PROMPT, ) @@ -48,6 +49,13 @@ class StableDiffusion(ToolType): desc: str = "Related to text2image, image2image using stable diffusion model." +@register_tool_type +class Image2Webpage(ToolType): + name: str = ToolTypeEnum.IMAGE2WEBPAGE.value + desc: str = "For converting image into webpage code." + usage_prompt: str = IMAGE2WEBPAGE_PROMPT + + @register_tool_type class Other(ToolType): name: str = ToolTypeEnum.OTHER.value diff --git a/tests/metagpt/tools/functions/libs/test_vision.py b/tests/metagpt/tools/functions/libs/test_vision.py deleted file mode 100644 index f4f97c46a..000000000 --- a/tests/metagpt/tools/functions/libs/test_vision.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2024/01/15 -@Author : mannaandpoem -@File : test_vision.py -""" -import pytest - -from metagpt import logs -from metagpt.tools.functions.libs.vision import Vision - - -@pytest.fixture -def mock_webpages(): - return """```html\n\n -\n\n```\n -```css\n.class { ... }\n```\n -```javascript\nfunction() { ... }\n```\n""" - - -def test_vision_generate_webpages(mocker, mock_webpages): - mocker.patch( - "metagpt.tools.functions.libs.vision.Vision.generate_web_pages", - return_value=mock_webpages - ) - image_path = "image.png" - vision = Vision() - rsp = vision.generate_web_pages(image_path=image_path) - logs.logger.info(rsp) - assert "html" in rsp - assert "css" in rsp - assert "javascript" in rsp - - -def test_save_webpages(mocker, mock_webpages): - mocker.patch( - "metagpt.tools.functions.libs.vision.Vision.generate_web_pages", - return_value=mock_webpages - ) - image_path = "image.png" - vision = Vision() - webpages = vision.generate_web_pages(image_path) - webpages_dir = vision.save_webpages(image_path=image_path, webpages=webpages) - logs.logger.info(webpages_dir) - assert webpages_dir.exists() - - diff --git a/tests/metagpt/tools/libs/test_gpt_v_generator.py b/tests/metagpt/tools/libs/test_gpt_v_generator.py new file mode 100644 index 000000000..360ca4a75 --- /dev/null +++ b/tests/metagpt/tools/libs/test_gpt_v_generator.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/01/15 +@Author : mannaandpoem +@File : test_vision.py +""" +import pytest + +from metagpt import logs +from metagpt.tools.libs.gpt_v_generator import GPTvGenerator + + +@pytest.fixture +def mock_webpages(mocker): + mock_data = """```html\n\n +\n\n```\n +```css\n.class { ... }\n```\n +```javascript\nfunction() { ... }\n```\n""" + mocker.patch("metagpt.tools.libs.gpt_v_generator.GPTvGenerator.generate_web_pages", return_value=mock_data) + return mocker + + +def test_vision_generate_webpages(mock_webpages): + image_path = "image.png" + generator = GPTvGenerator() + rsp = generator.generate_web_pages(image_path=image_path) + logs.logger.info(rsp) + assert "html" in rsp + assert "css" in rsp + assert "javascript" in rsp + + +def test_save_webpages(mock_webpages): + image_path = "image.png" + generator = GPTvGenerator() + webpages = generator.generate_web_pages(image_path) + webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages) + logs.logger.info(webpages_dir) + assert webpages_dir.exists() diff --git a/tests/metagpt/tools/libs/test_sd.py b/tests/metagpt/tools/libs/test_sd_engine.py similarity index 100% rename from tests/metagpt/tools/libs/test_sd.py rename to tests/metagpt/tools/libs/test_sd_engine.py diff --git a/tests/metagpt/tools/test_tool_registry.py b/tests/metagpt/tools/test_tool_registry.py index fd758b141..582c368a8 100644 --- a/tests/metagpt/tools/test_tool_registry.py +++ b/tests/metagpt/tools/test_tool_registry.py @@ -88,7 +88,7 @@ def test_get_tools_by_type(tool_registry, schema_yaml): tool_type = ToolType(name=tool_type_name, desc="test") tool_registry.register_tool_type(tool_type) - tool_registry.register_tool(tool_name, tool_path, tool_type_name=tool_type_name) + tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name) tools_by_type = tool_registry.get_tools_by_type(tool_type_name) assert tools_by_type is not None