From 7f1584db9e5bd153f5f78f2f79d3d16970f20f0c Mon Sep 17 00:00:00 2001 From: mannaandpoem <1580466765@qq.com> Date: Mon, 15 Jan 2024 17:26:35 +0800 Subject: [PATCH] 1. add test_vision.py 2. add save_webpages function in vision.py and vision.yml --- metagpt/tools/functions/libs/vision.py | 64 +++++++++++++++++--- metagpt/tools/functions/schemas/vision.yml | 20 +++++- tests/metagpt/tools/functions/test_vision.py | 40 ++++++++++++ 3 files changed, 113 insertions(+), 11 deletions(-) create mode 100644 tests/metagpt/tools/functions/test_vision.py diff --git a/metagpt/tools/functions/libs/vision.py b/metagpt/tools/functions/libs/vision.py index 8c29b0567..b10ad7608 100644 --- a/metagpt/tools/functions/libs/vision.py +++ b/metagpt/tools/functions/libs/vision.py @@ -5,6 +5,8 @@ @Author : mannaandpoem @File : vision.py """ +from pathlib import Path + import requests import base64 @@ -34,8 +36,9 @@ Now, please generate the corresponding webpage code including HTML, CSS and Java class Vision: def __init__(self): self.api_key = API_KEY + self.api_base = OPENAI_API_BASE self.model = MODEL - self.max_tokens = 4096 + self.max_tokens = MAX_TOKENS def analyze_layout(self, image_path): return self.get_result(image_path, ANALYZE_LAYOUT_PROMPT) @@ -43,7 +46,8 @@ class Vision: def generate_web_pages(self, image_path): layout = self.analyze_layout(image_path) prompt = GENERATE_PROMPT + "\n\n # Context\n The layout information of the sketch image is: \n" + layout - return self.get_result(image_path, prompt) + result = self.get_result(image_path, prompt) + return result def get_result(self, image_path, prompt): base64_image = self.encode_image(image_path) @@ -67,17 +71,59 @@ class Vision: ], "max_tokens": self.max_tokens, } - response = requests.post(f"{OPENAI_API_BASE}/chat/completions", headers=headers, json=payload) - return response.json()["choices"][0]["message"]["content"] + response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload) + + if response.status_code != 200: + raise ValueError(f"Request failed with status {response.status_code}, {response.text}") + else: + return response.json()["choices"][0]["message"]["content"] @staticmethod def encode_image(image_path): with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode('utf-8') + @staticmethod + def save_webpages(image_path, webpages) -> Path: + # 在当前目录下创建一个名为webpages的文件夹,用于存储html、css和js文件 + webpages_path = Path(image_path).parent / "webpages" + webpages_path.mkdir(exist_ok=True) -if __name__ == "__main__": - image_path = "image.png" - vision = Vision() - rsp = vision.generate_web_pages(image_path=image_path) - print(rsp) + try: + index_path = webpages_path / "index.html" + index = webpages.split("```html")[1].split("```")[0] + except IndexError: + raise ValueError("No html code found in the result, please check your image and try again.") + + try: + if "styles.css" in index: + style_path = webpages_path / "styles.css" + elif "style.css" in index: + style_path = webpages_path / "style.css" + else: + style_path = None + style = webpages.split("```css")[1].split("```")[0] if style_path else "" + + if "scripts.js" in index: + js_path = webpages_path / "scripts.js" + elif "script.js" in index: + js_path = webpages_path / "script.js" + else: + js_path = None + js = webpages.split("```javascript")[1].split("```")[0] if js_path else "" + except IndexError: + raise ValueError("No css or js code found in the result, please check your image and try again.") + + try: + with open(index_path, "w") as f: + f.write(index) + if style_path: + with open(style_path, "w") as f: + f.write(style) + if js_path: + with open(js_path, "w") as f: + f.write(js) + except FileNotFoundError as e: + raise FileNotFoundError(f"Cannot save the webpages to {str(webpages_path)}") from e + + return webpages_path diff --git a/metagpt/tools/functions/schemas/vision.yml b/metagpt/tools/functions/schemas/vision.yml index 795854e75..4cb247419 100644 --- a/metagpt/tools/functions/schemas/vision.yml +++ b/metagpt/tools/functions/schemas/vision.yml @@ -12,9 +12,25 @@ Vision: image_path: type: str description: "The path of the image file" - required: - image_path returns: type: str - description: "Generated web page content." \ No newline at end of file + description: "Generated webpages content." + + save_webpages: + description: "Save webpages including all code(HTML, CSS and JavaScript) at once" + parameters: + properties: + image_path: + type: str + description: "The path of the image file" + webpages: + type: str + description: "The generated webpages content" + required: + - image_path + - webpages + returns: + type: Path + description: "The path of the saved webpages" \ No newline at end of file diff --git a/tests/metagpt/tools/functions/test_vision.py b/tests/metagpt/tools/functions/test_vision.py new file mode 100644 index 000000000..0359f14f1 --- /dev/null +++ b/tests/metagpt/tools/functions/test_vision.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/01/15 +@Author : mannaandpoem +@File : test_vision.py +""" +import base64 +from unittest.mock import AsyncMock + +from pytest_mock import mocker + +from metagpt import logs +from metagpt.tools.functions.libs.vision import Vision + + +def test_vision_generate_web_pages(): + image_path = "./image.png" + vision = Vision() + rsp = vision.generate_web_pages(image_path=image_path) + logs.logger.info(rsp) + assert "html" in rsp + assert "css" in rsp + assert "javascript" in rsp + + +def test_save_webpages(): + image_path = "./image.png" + vision = Vision() + webpages = """```html: \n + \n``` + "```css: .class { ... } ```\n ```javascript: function() { ... }```""" + webpages_dir = vision.save_webpages(image_path=image_path, webpages=webpages) + logs.logger.info(webpages_dir) + assert webpages_dir.exists() + assert (webpages_dir / "index.html").exists() + assert (webpages_dir / "style.css").exists() or (webpages_dir / "styles.css").exists() + assert (webpages_dir / "script.js").exists() or (webpages_dir / "scripts.js").exists() + +