From 7f1584db9e5bd153f5f78f2f79d3d16970f20f0c Mon Sep 17 00:00:00 2001
From: mannaandpoem <1580466765@qq.com>
Date: Mon, 15 Jan 2024 17:26:35 +0800
Subject: [PATCH] 1. add test_vision.py 2. add save_webpages function in
vision.py and vision.yml
---
metagpt/tools/functions/libs/vision.py | 64 +++++++++++++++++---
metagpt/tools/functions/schemas/vision.yml | 20 +++++-
tests/metagpt/tools/functions/test_vision.py | 40 ++++++++++++
3 files changed, 113 insertions(+), 11 deletions(-)
create mode 100644 tests/metagpt/tools/functions/test_vision.py
diff --git a/metagpt/tools/functions/libs/vision.py b/metagpt/tools/functions/libs/vision.py
index 8c29b0567..b10ad7608 100644
--- a/metagpt/tools/functions/libs/vision.py
+++ b/metagpt/tools/functions/libs/vision.py
@@ -5,6 +5,8 @@
@Author : mannaandpoem
@File : vision.py
"""
+from pathlib import Path
+
import requests
import base64
@@ -34,8 +36,9 @@ Now, please generate the corresponding webpage code including HTML, CSS and Java
class Vision:
def __init__(self):
self.api_key = API_KEY
+ self.api_base = OPENAI_API_BASE
self.model = MODEL
- self.max_tokens = 4096
+ self.max_tokens = MAX_TOKENS
def analyze_layout(self, image_path):
return self.get_result(image_path, ANALYZE_LAYOUT_PROMPT)
@@ -43,7 +46,8 @@ class Vision:
def generate_web_pages(self, image_path):
layout = self.analyze_layout(image_path)
prompt = GENERATE_PROMPT + "\n\n # Context\n The layout information of the sketch image is: \n" + layout
- return self.get_result(image_path, prompt)
+ result = self.get_result(image_path, prompt)
+ return result
def get_result(self, image_path, prompt):
base64_image = self.encode_image(image_path)
@@ -67,17 +71,59 @@ class Vision:
],
"max_tokens": self.max_tokens,
}
- response = requests.post(f"{OPENAI_API_BASE}/chat/completions", headers=headers, json=payload)
- return response.json()["choices"][0]["message"]["content"]
+ response = requests.post(f"{self.api_base}/chat/completions", headers=headers, json=payload)
+
+ if response.status_code != 200:
+ raise ValueError(f"Request failed with status {response.status_code}, {response.text}")
+ else:
+ return response.json()["choices"][0]["message"]["content"]
@staticmethod
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
+ @staticmethod
+ def save_webpages(image_path, webpages) -> Path:
+ # Create a folder named webpages next to the image file to store the html, css and js files
+ webpages_path = Path(image_path).parent / "webpages"
+ webpages_path.mkdir(exist_ok=True)
-if __name__ == "__main__":
- image_path = "image.png"
- vision = Vision()
- rsp = vision.generate_web_pages(image_path=image_path)
- print(rsp)
+ try:
+ index_path = webpages_path / "index.html"
+ index = webpages.split("```html")[1].split("```")[0]
+ except IndexError:
+ raise ValueError("No html code found in the result, please check your image and try again.")
+
+ try:
+ if "styles.css" in index:
+ style_path = webpages_path / "styles.css"
+ elif "style.css" in index:
+ style_path = webpages_path / "style.css"
+ else:
+ style_path = None
+ style = webpages.split("```css")[1].split("```")[0] if style_path else ""
+
+ if "scripts.js" in index:
+ js_path = webpages_path / "scripts.js"
+ elif "script.js" in index:
+ js_path = webpages_path / "script.js"
+ else:
+ js_path = None
+ js = webpages.split("```javascript")[1].split("```")[0] if js_path else ""
+ except IndexError:
+ raise ValueError("No css or js code found in the result, please check your image and try again.")
+
+ try:
+ with open(index_path, "w") as f:
+ f.write(index)
+ if style_path:
+ with open(style_path, "w") as f:
+ f.write(style)
+ if js_path:
+ with open(js_path, "w") as f:
+ f.write(js)
+ except FileNotFoundError as e:
+ raise FileNotFoundError(f"Cannot save the webpages to {str(webpages_path)}") from e
+
+ return webpages_path
diff --git a/metagpt/tools/functions/schemas/vision.yml b/metagpt/tools/functions/schemas/vision.yml
index 795854e75..4cb247419 100644
--- a/metagpt/tools/functions/schemas/vision.yml
+++ b/metagpt/tools/functions/schemas/vision.yml
@@ -12,9 +12,25 @@ Vision:
image_path:
type: str
description: "The path of the image file"
-
required:
- image_path
returns:
type: str
- description: "Generated web page content."
\ No newline at end of file
+ description: "Generated webpages content."
+
+ save_webpages:
+ description: "Save webpages including all code(HTML, CSS and JavaScript) at once"
+ parameters:
+ properties:
+ image_path:
+ type: str
+ description: "The path of the image file"
+ webpages:
+ type: str
+ description: "The generated webpages content"
+ required:
+ - image_path
+ - webpages
+ returns:
+ type: Path
+ description: "The path of the saved webpages"
\ No newline at end of file
diff --git a/tests/metagpt/tools/functions/test_vision.py b/tests/metagpt/tools/functions/test_vision.py
new file mode 100644
index 000000000..0359f14f1
--- /dev/null
+++ b/tests/metagpt/tools/functions/test_vision.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/01/15
+@Author : mannaandpoem
+@File : test_vision.py
+"""
+import base64
+from unittest.mock import AsyncMock
+
+from pytest_mock import mocker
+
+from metagpt import logs
+from metagpt.tools.functions.libs.vision import Vision
+
+
+def test_vision_generate_web_pages():
+ image_path = "./image.png"
+ vision = Vision()
+ rsp = vision.generate_web_pages(image_path=image_path)
+ logs.logger.info(rsp)
+ assert "html" in rsp
+ assert "css" in rsp
+ assert "javascript" in rsp
+
+
+def test_save_webpages():
+ image_path = "./image.png"
+ vision = Vision()
+ webpages = """```html: \n
+ \n```
+ "```css: .class { ... } ```\n ```javascript: function() { ... }```"""
+ webpages_dir = vision.save_webpages(image_path=image_path, webpages=webpages)
+ logs.logger.info(webpages_dir)
+ assert webpages_dir.exists()
+ assert (webpages_dir / "index.html").exists()
+ assert (webpages_dir / "style.css").exists() or (webpages_dir / "styles.css").exists()
+ assert (webpages_dir / "script.js").exists() or (webpages_dir / "scripts.js").exists()
+
+