remove get_result method and improve gpt_v_generator.py and test_gpt_v_generator.py.

This commit is contained in:
mannaandpoem 2024-02-05 21:48:59 +08:00
parent 23c27627ce
commit 6015c16618
3 changed files with 88 additions and 85 deletions

View file

@ -5,15 +5,13 @@
@Author : mannaandpoem
@File : gpt_v_generator.py
"""
import base64
import os
from pathlib import Path
import requests
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.tools.tool_registry import register_tool
from metagpt.tools.tool_type import ToolType
from metagpt.utils.common import encode_image
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX, please generate layout information for this image:
@ -43,27 +41,26 @@ class GPTvGenerator:
def __init__(self):
"""Initialize GPTvGenerator class with default values from the configuration."""
from metagpt.config2 import config
from metagpt.llm import LLM
self.api_key = config.llm.api_key
self.api_base = config.llm.base_url
self.model = config.openai_vision_model
self.max_tokens = config.vision_max_tokens
self.llm = LLM(llm_config=config.get_openai_llm())
self.llm.model = "gpt-4-vision-preview"
def analyze_layout(self, image_path):
"""Analyze the layout of the given image and return the result.
async def analyze_layout(self, image_path: Path) -> str:
"""Asynchronously analyze the layout of the given image and return the result.
This is a helper method to generate a layout description based on the image.
Args:
image_path (str): Path of the image to analyze.
image_path (Path): Path of the image to analyze.
Returns:
str: The layout analysis result.
"""
return self.get_result(image_path, ANALYZE_LAYOUT_PROMPT)
return await self.llm.aask(msg=ANALYZE_LAYOUT_PROMPT, images=[encode_image(image_path)])
def generate_webpages(self, image_path):
"""Generate webpages including all code (HTML, CSS, and JavaScript) in one go based on the image.
async def generate_webpages(self, image_path: str) -> str:
"""Asynchronously generate webpages including all code (HTML, CSS, and JavaScript) in one go based on the image.
Args:
image_path (str): The path of the image file.
@ -71,58 +68,14 @@ class GPTvGenerator:
Returns:
str: Generated webpages content.
"""
layout = self.analyze_layout(image_path)
if isinstance(image_path, str):
image_path = Path(image_path)
layout = await self.analyze_layout(image_path)
prompt = GENERATE_PROMPT + "\n\n # Context\n The layout information of the sketch image is: \n" + layout
result = self.get_result(image_path, prompt)
return result
def get_result(self, image_path, prompt):
"""Get the result from the vision model based on the given image path and prompt.
Args:
image_path (str): Path of the image to analyze.
prompt (str): Prompt to use for the analysis.
Returns:
str: The model's response as a string.
"""
base64_image = self.encode_image(image_path)
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
],
}
],
"max_tokens": self.max_tokens,
}
response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
if response.status_code != 200:
raise ValueError(f"Request failed with status {response.status_code}, {response.text}")
else:
return response.json()["choices"][0]["message"]["content"]
return await self.llm.aask(msg=prompt, images=[encode_image(image_path)])
@staticmethod
def encode_image(image_path):
"""Encode the image at the given path to a base64 string.
Args:
image_path (str): Path of the image to encode.
Returns:
str: The base64 encoded string of the image.
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@staticmethod
def save_webpages(image_path, webpages) -> Path:
def save_webpages(image_path: str, webpages: str) -> Path:
"""Save webpages including all code (HTML, CSS, and JavaScript) at once.
Args:
@ -132,35 +85,29 @@ class GPTvGenerator:
Returns:
Path: The path of the saved webpages.
"""
# 在workspace目录下创建一个名为下webpages的文件夹用于存储html、css和js文件
# Create a folder called webpages in the workspace directory to store HTML, CSS, and JavaScript files
webpages_path = DEFAULT_WORKSPACE_ROOT / "webpages" / Path(image_path).stem
os.makedirs(webpages_path, exist_ok=True)
index_path = webpages_path / "index.html"
try:
index = webpages.split("```html")[1].split("```")[0]
except IndexError:
index = "No html code found in the result, please check your image and try again." + "\n" + webpages
try:
style_path = None
if "styles.css" in index:
style_path = webpages_path / "styles.css"
elif "style.css" in index:
style_path = webpages_path / "style.css"
else:
style_path = None
style = webpages.split("```css")[1].split("```")[0] if style_path else ""
js_path = None
if "scripts.js" in index:
js_path = webpages_path / "scripts.js"
elif "script.js" in index:
js_path = webpages_path / "script.js"
else:
js_path = None
js = webpages.split("```javascript")[1].split("```")[0] if js_path else ""
except IndexError:
raise ValueError("No css or js code found in the result, please check your image and try again.")
raise ValueError(f"No html or css or js code found in the result. \nWebpages: {webpages}")
try:
with open(index_path, "w", encoding="utf-8") as f:

View file

@ -15,7 +15,8 @@ def convert_code_to_tool_schema(obj, include: list[str] = []):
# method_doc = inspect.getdoc(method)
method_doc = get_class_method_docstring(obj, name)
if method_doc:
schema["methods"][name] = docstring_to_schema(method_doc)
function_type = "function" if not inspect.iscoroutinefunction(method) else "async_function"
schema["methods"][name] = {"type": function_type, **docstring_to_schema(method_doc)}
elif inspect.isfunction(obj):
schema = {