mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
change register arg name, integrate image2web tool
This commit is contained in:
parent
9dc421b122
commit
1cabf2c503
15 changed files with 100 additions and 103 deletions
|
|
@ -11,9 +11,7 @@ from metagpt.tools import tool_types # this registers all tool types
|
|||
from metagpt.tools import libs # this registers all tools
|
||||
from metagpt.tools.tool_registry import TOOL_REGISTRY
|
||||
|
||||
_ = tool_types # Avoid pre-commit error
|
||||
_ = libs # Avoid pre-commit error
|
||||
_ = TOOL_REGISTRY # Avoid pre-commit error
|
||||
_, _, _ = tool_types, libs, TOOL_REGISTRY # Avoid pre-commit error
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@
|
|||
from metagpt.tools.libs import (
|
||||
data_preprocess,
|
||||
feature_engineering,
|
||||
sd_engine,
|
||||
gpt_v_generator,
|
||||
)
|
||||
|
||||
_ = data_preprocess # Avoid pre-commit error
|
||||
_ = feature_engineering # Avoid pre-commit error
|
||||
_, _, _, _ = data_preprocess, feature_engineering, sd_engine, gpt_v_generator # Avoid pre-commit error
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class MLProcess(object):
|
|||
return self.transform(df)
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class FillMissingValue(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -58,7 +58,7 @@ class FillMissingValue(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MinMaxScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -77,7 +77,7 @@ class MinMaxScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class StandardScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -96,7 +96,7 @@ class StandardScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class MaxAbsScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -115,7 +115,7 @@ class MaxAbsScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class RobustScale(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -134,7 +134,7 @@ class RobustScale(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OrdinalEncode(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -153,7 +153,7 @@ class OrdinalEncode(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class OneHotEncode(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -175,7 +175,7 @@ class OneHotEncode(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class LabelEncode(MLProcess):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -204,7 +204,7 @@ class LabelEncode(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
def get_column_info(df: pd.DataFrame) -> dict:
|
||||
column_info = {
|
||||
"Category": [],
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from metagpt.tools.tool_registry import register_tool
|
|||
TOOL_TYPE = ToolTypeEnum.FEATURE_ENGINEERING.value
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class PolynomialExpansion(MLProcess):
|
||||
def __init__(self, cols: list, degree: int = 2, label_col: str = None):
|
||||
self.cols = cols
|
||||
|
|
@ -53,7 +53,7 @@ class PolynomialExpansion(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class CatCount(MLProcess):
|
||||
def __init__(self, col: str):
|
||||
self.col = col
|
||||
|
|
@ -68,7 +68,7 @@ class CatCount(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class TargetMeanEncoder(MLProcess):
|
||||
def __init__(self, col: str, label: str):
|
||||
self.col = col
|
||||
|
|
@ -84,7 +84,7 @@ class TargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class KFoldTargetMeanEncoder(MLProcess):
|
||||
def __init__(self, col: str, label: str, n_splits: int = 5, random_state: int = 2021):
|
||||
self.col = col
|
||||
|
|
@ -111,7 +111,7 @@ class KFoldTargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class CatCross(MLProcess):
|
||||
def __init__(self, cols: list, max_cat_num: int = 100):
|
||||
self.cols = cols
|
||||
|
|
@ -147,7 +147,7 @@ class CatCross(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class GroupStat(MLProcess):
|
||||
def __init__(self, group_col: str, agg_col: str, agg_funcs: list):
|
||||
self.group_col = group_col
|
||||
|
|
@ -167,7 +167,7 @@ class GroupStat(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class SplitBins(MLProcess):
|
||||
def __init__(self, cols: list, strategy: str = "quantile"):
|
||||
self.cols = cols
|
||||
|
|
@ -184,7 +184,7 @@ class SplitBins(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class ExtractTimeComps(MLProcess):
|
||||
def __init__(self, time_col: str, time_comps: list):
|
||||
self.time_col = time_col
|
||||
|
|
@ -213,7 +213,7 @@ class ExtractTimeComps(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class GeneralSelection(MLProcess):
|
||||
def __init__(self, label_col: str):
|
||||
self.label_col = label_col
|
||||
|
|
@ -284,7 +284,7 @@ class TreeBasedSelection(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type_name=TOOL_TYPE)
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
class VarianceBasedSelection(MLProcess):
|
||||
def __init__(self, label_col: str, threshold: float = 0):
|
||||
self.label_col = label_col
|
||||
|
|
|
|||
|
|
@ -5,18 +5,13 @@
|
|||
@Author : mannaandpoem
|
||||
@File : vision.py
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
import base64
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL
|
||||
API_KEY = CONFIG.OPENAI_API_KEY
|
||||
MODEL = CONFIG.OPENAI_VISION_MODEL
|
||||
MAX_TOKENS = CONFIG.VISION_MAX_TOKENS
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
|
||||
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image:
|
||||
|
||||
|
|
@ -33,8 +28,15 @@ As the design pays tribute to large companies, sometimes it is normal for some c
|
|||
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
|
||||
|
||||
|
||||
class Vision:
|
||||
@register_tool(tool_type=ToolTypeEnum.IMAGE2WEBPAGE.value)
|
||||
class GPTvGenerator:
|
||||
def __init__(self):
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
OPENAI_API_BASE = CONFIG.OPENAI_BASE_URL
|
||||
API_KEY = CONFIG.OPENAI_API_KEY
|
||||
MODEL = CONFIG.OPENAI_VISION_MODEL
|
||||
MAX_TOKENS = CONFIG.VISION_MAX_TOKENS
|
||||
self.api_key = API_KEY
|
||||
self.api_base = OPENAI_API_BASE
|
||||
self.model = MODEL
|
||||
|
|
@ -51,10 +53,7 @@ class Vision:
|
|||
|
||||
def get_result(self, image_path, prompt):
|
||||
base64_image = self.encode_image(image_path)
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
|
|
@ -62,11 +61,8 @@ class Vision:
|
|||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
|
||||
}
|
||||
]
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": self.max_tokens,
|
||||
|
|
@ -81,7 +77,7 @@ class Vision:
|
|||
@staticmethod
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def save_webpages(image_path, webpages) -> Path:
|
||||
|
|
@ -13,7 +13,6 @@ import requests
|
|||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
|
|
@ -53,9 +52,11 @@ payload = {
|
|||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
||||
|
||||
@register_tool(tool_type_name=ToolTypeEnum.STABLE_DIFFUSION.value)
|
||||
@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value)
|
||||
class SDEngine:
|
||||
def __init__(self, sd_url=""):
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
# Initialize the SDEngine with configuration
|
||||
self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
Vision:
|
||||
GPTvGenerator:
|
||||
type: class
|
||||
description: "Class for generating web pages at once."
|
||||
methods:
|
||||
|
|
@ -10,6 +10,7 @@ class ToolTypeEnum(Enum):
|
|||
MODEL_TRAIN = "model_train"
|
||||
MODEL_EVALUATE = "model_evaluate"
|
||||
STABLE_DIFFUSION = "stable_diffusion"
|
||||
IMAGE2WEBPAGE = "image2webpage"
|
||||
OTHER = "other"
|
||||
|
||||
def __missing__(self, key):
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ class ToolRegistry:
|
|||
def __init__(self):
|
||||
self.tools = {}
|
||||
self.tool_types = {}
|
||||
self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type_name: {tool_name: {...}, ...}, ...}
|
||||
self.tools_by_types = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
|
||||
|
||||
def register_tool_type(self, tool_type: ToolType):
|
||||
self.tool_types[tool_type.name] = tool_type
|
||||
|
|
@ -33,13 +33,13 @@ class ToolRegistry:
|
|||
tool_path,
|
||||
schema_path=None,
|
||||
tool_code="",
|
||||
tool_type_name="other",
|
||||
tool_type="other",
|
||||
make_schema_if_not_exists=False,
|
||||
):
|
||||
if self.has_tool(tool_name):
|
||||
return
|
||||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type_name / f"{tool_name}.yml"
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
|
||||
|
||||
if not os.path.exists(schema_path):
|
||||
if make_schema_if_not_exists:
|
||||
|
|
@ -62,7 +62,7 @@ class ToolRegistry:
|
|||
# )
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
|
||||
self.tools[tool_name] = tool
|
||||
self.tools_by_types[tool_type_name][tool_name] = tool
|
||||
self.tools_by_types[tool_type][tool_name] = tool
|
||||
logger.info(f"{tool_name} registered")
|
||||
|
||||
def has_tool(self, key):
|
||||
|
|
@ -94,7 +94,7 @@ def register_tool_type(cls):
|
|||
return cls
|
||||
|
||||
|
||||
def register_tool(tool_name="", tool_type_name="other", schema_path=None):
|
||||
def register_tool(tool_name="", tool_type="other", schema_path=None):
|
||||
"""register a tool to registry"""
|
||||
|
||||
def decorator(cls, tool_name=tool_name):
|
||||
|
|
@ -111,7 +111,7 @@ def register_tool(tool_name="", tool_type_name="other", schema_path=None):
|
|||
tool_path=file_path,
|
||||
schema_path=schema_path,
|
||||
tool_code=source_code,
|
||||
tool_type_name=tool_type_name,
|
||||
tool_type=tool_type,
|
||||
)
|
||||
return cls
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from metagpt.prompts.tool_types import (
|
||||
DATA_PREPROCESS_PROMPT,
|
||||
FEATURE_ENGINEERING_PROMPT,
|
||||
IMAGE2WEBPAGE_PROMPT,
|
||||
MODEL_EVALUATE_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
|
|
@ -48,6 +49,13 @@ class StableDiffusion(ToolType):
|
|||
desc: str = "Related to text2image, image2image using stable diffusion model."
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class Image2Webpage(ToolType):
|
||||
name: str = ToolTypeEnum.IMAGE2WEBPAGE.value
|
||||
desc: str = "For converting image into webpage code."
|
||||
usage_prompt: str = IMAGE2WEBPAGE_PROMPT
|
||||
|
||||
|
||||
@register_tool_type
|
||||
class Other(ToolType):
|
||||
name: str = ToolTypeEnum.OTHER.value
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue