mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
change register arg name, integrate image2web tool
This commit is contained in:
parent
9dc421b122
commit
1cabf2c503
15 changed files with 100 additions and 103 deletions
|
|
@ -39,7 +39,7 @@ The current task is about evaluating a model, please note the following:
|
|||
"""
|
||||
|
||||
# Prompt for using tools of "vision" type
|
||||
VISION_PROMPT = """
|
||||
IMAGE2WEBPAGE_PROMPT = """
|
||||
The current task is about converting image into webpage code. please note the following:
|
||||
- Single-Step Code Generation: Execute the entire code generation process in a single step, encompassing HTML, CSS, and JavaScript. Avoid fragmenting the code generation into multiple separate steps to maintain consistency and simplify the development workflow.
|
||||
"""
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -11,9 +11,7 @@ from metagpt.tools import tool_types # this registers all tool types
|
|||
from metagpt.tools import libs # this registers all tools
|
||||
from metagpt.tools.tool_registry import TOOL_REGISTRY
|
||||
|
||||
_ = tool_types # Avoid pre-commit error
|
||||
_ = libs # Avoid pre-commit error
|
||||
_ = TOOL_REGISTRY # Avoid pre-commit error
|
||||
_, _, _ = tool_types, libs, TOOL_REGISTRY # Avoid pre-commit error
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@
|
|||
from metagpt.tools.libs import (
|
||||
data_preprocess,
|
||||
feature_engineering,
|
||||
sd_engine,
|
||||
gpt_v_generator,
|
||||
)
|
||||
|
||||
_ = data_preprocess # Avoid pre-commit error
|
||||
_ = feature_engineering # Avoid pre-commit error
|
||||
_, _, _, _ = data_preprocess, feature_engineering, sd_engine, gpt_v_generator # Avoid pre-commit error
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class MLProcess(object):
|
|||
return self.transform(df)
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class FillMissingValue(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -58,7 +58,7 @@ class FillMissingValue(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MinMaxScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -77,7 +77,7 @@ class MinMaxScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class StandardScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -96,7 +96,7 @@ class StandardScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MaxAbsScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -115,7 +115,7 @@ class MaxAbsScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class RobustScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -134,7 +134,7 @@ class RobustScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OrdinalEncode(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -153,7 +153,7 @@ class OrdinalEncode(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OneHotEncode(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -175,7 +175,7 @@ class OneHotEncode(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class LabelEncode(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -204,7 +204,7 @@ class LabelEncode(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
def get_column_info(df: pd.DataFrame) -> dict:
|
||||
column_info = {
|
||||
"Category": [],
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from metagpt.tools.tool_registry import register_tool
|
|||
TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class PolynomialExpansion(MLProcess):
|
||||
def __init__(self, cols: list, degree: int = 2, label_col: str = None):
|
||||
self.cols = cols
|
||||
|
|
@ -53,7 +53,7 @@ class PolynomialExpansion(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class CatCount(MLProcess):
|
||||
def __init__(self, col: str):
|
||||
self.col = col
|
||||
|
|
@ -68,7 +68,7 @@ class CatCount(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class TargetMeanEncoder(MLProcess):
|
||||
def __init__(self, col: str, label: str):
|
||||
self.col = col
|
||||
|
|
@ -84,7 +84,7 @@ class TargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class KFoldTargetMeanEncoder(MLProcess):
|
||||
def __init__(self, col: str, label: str, n_splits: int = 5, random_state: int = 2021):
|
||||
self.col = col
|
||||
|
|
@ -111,7 +111,7 @@ class KFoldTargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class CatCross(MLProcess):
|
||||
def __init__(self, cols: list, max_cat_num: int = 100):
|
||||
self.cols = cols
|
||||
|
|
@ -147,7 +147,7 @@ class CatCross(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class GroupStat(MLProcess):
|
||||
def __init__(self, group_col: str, agg_col: str, agg_funcs: list):
|
||||
self.group_col = group_col
|
||||
|
|
@ -167,7 +167,7 @@ class GroupStat(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class SplitBins(MLProcess):
|
||||
def __init__(self, cols: list, strategy: str = "quantile"):
|
||||
self.cols = cols
|
||||
|
|
@ -184,7 +184,7 @@ class SplitBins(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class ExtractTimeComps(MLProcess):
|
||||
def __init__(self, time_col: str, time_comps: list):
|
||||
self.time_col = time_col
|
||||
|
|
@ -213,7 +213,7 @@ class ExtractTimeComps(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class GeneralSelection(MLProcess):
|
||||
def __init__(self, label_col: str):
|
||||
self.label_col = label_col
|
||||
|
|
@ -284,7 +284,7 @@ class TreeBasedSelection(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class VarianceBasedSelection(MLProcess):
|
||||
def __init__(self, label_col: str, threshold: float = 0):
|
||||
self.label_col = label_col
|
||||
|
|
|
|||
|
|
@ -5,18 +5,13 @@
|
|||
@Author : mannaandpoem
|
||||
@File : vision.py
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
import base64
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL
|
||||
API_KEY = CONFIG.OPENAI_API_KEY
|
||||
MODEL = CONFIG.OPENAI_VISION_MODEL
|
||||
MAX_TOKENS = CONFIG.VISION_MAX_TOKENS
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
|
||||
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image:
|
||||
|
||||
|
|
@ -33,8 +28,15 @@ As the design pays tribute to large companies, sometimes it is normal for some c
|
|||
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
|
||||
|
||||
|
||||
class Vision:
|
||||
@register_tool(tool_type=ToolTypeEnum.IMAGE2WEBPAGE.value)
|
||||
class GPTvGenerator:
|
||||
def __init__(self):
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL
|
||||
API_KEY = CONFIG.OPENAI_API_KEY
|
||||
MODEL = CONFIG.OPENAI_VISION_MODEL
|
||||
MAX_TOKENS = CONFIG.VISION_MAX_TOKENS
|
||||
self.api_key = API_KEY
|
||||
self.api_base = OPENAI_API_BASE
|
||||
self.model = MODEL
|
||||
|
|
@ -51,10 +53,7 @@ class Vision:
|
|||
|
||||
def get_result(self, image_path, prompt):
|
||||
base64_image = self.encode_image(image_path)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
|
|
@ -62,11 +61,8 @@ class Vision:
|
|||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
|
||||
}
|
||||
]
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": self.max_tokens,
|
||||
|
|
@ -81,7 +77,7 @@ class Vision:
|
|||
@staticmethod
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def save_webpages(image_path, webpages) -> Path:
|
||||
|
|
@ -13,7 +13,6 @@ import requests
|
|||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
|
|
@ -53,9 +52,11 @@ payload = {
|
|||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
||||
|
||||
@register_tool(tool_type_name=ToolTypeEnum.STABLE_DIFFUSION.value)
|
||||
@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value)
|
||||
class SDEngine:
|
||||
def __init__(self, sd_url=""):
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
# Initialize the SDEngine with configuration
|
||||
self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
Vision:
|
||||
GPTvGenerator:
|
||||
type: class
|
||||
description: "Class for generating web pages at once."
|
||||
methods:
|
||||
|
|
@ -10,6 +10,7 @@ class ToolTypeEnum(Enum):
|
|||
MODEL_TRAIN = "model_train"
|
||||
MODEL_EVALUATE = "model_evaluate"
|
||||
STABLE_DIFFUSION = "stable_diffusion"
|
||||
IMAGE2WEBPAGE = "image2webpage"
|
||||
OTHER = "other"
|
||||
|
||||
def __missing__(self, key):
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class ToolRegistry:
|
|||
def __init__(self):
|
||||
self.tools = {}
|
||||
self.tool_types = {}
|
||||
self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...}
|
||||
self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
|
||||
|
||||
def register_tool_type(self, tool_type: ToolType):
|
||||
self.tool_types[tool_type.name] = tool_type
|
||||
|
|
@ -33,13 +33,13 @@ class ToolRegistry:
|
|||
tool_path,
|
||||
schema_path=None,
|
||||
tool_code="",
|
||||
tool_type_name="other",
|
||||
tool_type="other",
|
||||
make_schema_if_not_exists=False,
|
||||
):
|
||||
if self.has_tool(tool_name):
|
||||
return
|
||||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type_name / f"{tool_name}.yml"
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
|
||||
|
||||
if not os.path.exists(schema_path):
|
||||
if make_schema_if_not_exists:
|
||||
|
|
@ -62,7 +62,7 @@ class ToolRegistry:
|
|||
# )
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
|
||||
self.tools[tool_name] = tool
|
||||
self.tools_by_types[tool_type_name][tool_name] = tool
|
||||
self.tools_by_types[tool_type][tool_name] = tool
|
||||
logger.info(f"{tool_name} registered")
|
||||
|
||||
def has_tool(self, key):
|
||||
|
|
@ -94,7 +94,7 @@ def register_tool_type(cls):
|
|||
return cls
|
||||
|
||||
|
||||
def register_tool(tool_name="", tool_type_name="other", schema_path=None):
|
||||
def register_tool(tool_name="", tool_type="other", schema_path=None):
|
||||
"""register a tool to registry"""
|
||||
|
||||
def decorator(cls, tool_name=tool_name):
|
||||
|
|
@ -111,7 +111,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None):
|
|||
tool_path=file_path,
|
||||
schema_path=schema_path,
|
||||
tool_code=source_code,
|
||||
tool_type_name=tool_type_name,
|
||||
tool_type=tool_type,
|
||||
)
|
||||
return cls
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from metagpt.prompts.tool_types import (
|
||||
DATA_PREPROCESS_PROMPT,
|
||||
FEATURE_ENGINEERING_PROMPT,
|
||||
IMAGE2WEBPAGE_PROMPT,
|
||||
MODEL_EVALUATE_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
|
|
@ -48,6 +49,13 @@ class StableDiffusion(ToolType):
|
|||
desc: str = "Related to text2image, image2image using stable diffusion model."
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class Image2Webpage(ToolType):
|
||||
name: str = ToolTypeEnum.IMAGE2WEBPAGE.value
|
||||
desc: str = "For converting image into webpage code."
|
||||
usage_prompt: str = IMAGE2WEBPAGE_PROMPT
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class Other(ToolType):
|
||||
name: str = ToolTypeEnum.OTHER.value
|
||||
|
|
|
|||
|
|
@ -1,48 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/15
|
||||
@Author : mannaandpoem
|
||||
@File : test_vision.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt import logs
|
||||
from metagpt.tools.functions.libs.vision import Vision
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_webpages():
|
||||
return """```html\n<html>\n<script src="scripts.js"></script>
|
||||
<link rel="stylesheet" href="styles.css(">\n</html>\n```\n
|
||||
```css\n.class { ... }\n```\n
|
||||
```javascript\nfunction() { ... }\n```\n"""
|
||||
|
||||
|
||||
def test_vision_generate_webpages(mocker, mock_webpages):
|
||||
mocker.patch(
|
||||
"metagpt.tools.functions.libs.vision.Vision.generate_web_pages",
|
||||
return_value=mock_webpages
|
||||
)
|
||||
image_path = "image.png"
|
||||
vision = Vision()
|
||||
rsp = vision.generate_web_pages(image_path=image_path)
|
||||
logs.logger.info(rsp)
|
||||
assert "html" in rsp
|
||||
assert "css" in rsp
|
||||
assert "javascript" in rsp
|
||||
|
||||
|
||||
def test_save_webpages(mocker, mock_webpages):
|
||||
mocker.patch(
|
||||
"metagpt.tools.functions.libs.vision.Vision.generate_web_pages",
|
||||
return_value=mock_webpages
|
||||
)
|
||||
image_path = "image.png"
|
||||
vision = Vision()
|
||||
webpages = vision.generate_web_pages(image_path)
|
||||
webpages_dir = vision.save_webpages(image_path=image_path, webpages=webpages)
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
|
||||
|
||||
40
tests/metagpt/tools/libs/test_gpt_v_generator.py
Normal file
40
tests/metagpt/tools/libs/test_gpt_v_generator.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/15
|
||||
@Author : mannaandpoem
|
||||
@File : test_vision.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt import logs
|
||||
from metagpt.tools.libs.gpt_v_generator import GPTvGenerator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_webpages(mocker):
|
||||
mock_data = """```html\n<html>\n<script src="scripts.js"></script>
|
||||
<link rel="stylesheet" href="styles.css(">\n</html>\n```\n
|
||||
```css\n.class { ... }\n```\n
|
||||
```javascript\nfunction() { ... }\n```\n"""
|
||||
mocker.patch("metagpt.tools.libs.gpt_v_generator.GPTvGenerator.generate_web_pages", return_value=mock_data)
|
||||
return mocker
|
||||
|
||||
|
||||
def test_vision_generate_webpages(mock_webpages):
|
||||
image_path = "image.png"
|
||||
generator = GPTvGenerator()
|
||||
rsp = generator.generate_web_pages(image_path=image_path)
|
||||
logs.logger.info(rsp)
|
||||
assert "html" in rsp
|
||||
assert "css" in rsp
|
||||
assert "javascript" in rsp
|
||||
|
||||
|
||||
def test_save_webpages(mock_webpages):
|
||||
image_path = "image.png"
|
||||
generator = GPTvGenerator()
|
||||
webpages = generator.generate_web_pages(image_path)
|
||||
webpages_dir = generator.save_webpages(image_path=image_path, webpages=webpages)
|
||||
logs.logger.info(webpages_dir)
|
||||
assert webpages_dir.exists()
|
||||
|
|
@ -88,7 +88,7 @@ def test_get_tools_by_type(tool_registry, schema_yaml):
|
|||
tool_type = ToolType(name=tool_type_name, desc="test")
|
||||
tool_registry.register_tool_type(tool_type)
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type_name=tool_type_name)
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue