update sd ut

This commit is contained in:
stellahsr 2024-01-31 16:03:16 +08:00
parent d60a4c1cdb
commit d74dab9bec
3 changed files with 63 additions and 21 deletions

View file

@ -9,7 +9,7 @@ from metagpt.roles.code_interpreter import CodeInterpreter
async def main():
web_url = 'https://pytorch.org/'
web_url = "https://pytorch.org/"
prompt = f"""This is a URL of webpage: '{web_url}' .
Firstly, utilize Selenium and WebDriver for rendering.
Secondly, convert image to a webpage including HTML, CSS and JS in one go.
@ -20,7 +20,7 @@ Note: All required dependencies and environments have been fully installed and c
await ci.run(prompt)
if __name__ == '__main__':
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -13,7 +13,8 @@ import requests
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
from metagpt.const import SD_OUTPUT_FILE_REPO
#
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
from metagpt.logs import logger
from metagpt.tools.tool_data_type import ToolTypeEnum
from metagpt.tools.tool_registry import register_tool
@ -82,7 +83,7 @@ class SDEngine:
return self.payload
def save(self, imgs, save_name=""):
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
save_dir = SOURCE_ROOT / SD_OUTPUT_FILE_REPO
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
@ -113,17 +114,10 @@ class SDEngine:
rsp_json = json.loads(data)
imgs = rsp_json["images"]
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
async def run_sam(self):
# todo添加SAM接口调用
raise NotImplementedError
def decode_base64_to_image(img, save_name):
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))

View file

@ -2,20 +2,51 @@
# @Date : 1/10/2024 10:07 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import base64
import io
import pytest
from aioresponses import aioresponses
from PIL import Image, ImageDraw
from requests_mock import Mocker
from metagpt.tools.libs.sd_engine import SDEngine
def generate_mock_image_data():
# 创建一个简单的图片对象
image = Image.new("RGB", (100, 100), color="white")
draw = ImageDraw.Draw(image)
draw.text((10, 10), "Mock Image", fill="black")
# 将图片转换为二进制数据
with io.BytesIO() as buffer:
image.save(buffer, format="PNG")
image_binary = buffer.getvalue()
# 对图片二进制数据进行 base64 编码
image_base64 = base64.b64encode(image_binary).decode("utf-8")
return image_base64
def test_sd_tools():
engine = SDEngine()
prompt = "1boy, hansom"
engine.construct_payload(prompt)
engine.simple_run_t2i(engine.payload)
engine = SDEngine(sd_url="http://localhost:7860")
# 使用 requests_mock.Mocker 替换 simple_run_t2i 的网络请求
mock_imgs = generate_mock_image_data()
with Mocker() as mocker:
# 指定模拟请求的返回值
mocker.post(engine.sd_t2i_url, json={"images": [mock_imgs]})
# 在被测试代码中调用 simple_run_t2i
result = engine.simple_run_t2i(engine.payload)
# 断言结果是否是指定的 Mock 返回值
assert len(result) == 1
def test_sd_construct_payload():
engine = SDEngine()
engine = SDEngine(sd_url="http://localhost:7860")
prompt = "1boy, hansom"
engine.construct_payload(prompt)
assert "negative_prompt" in engine.payload
@ -23,8 +54,25 @@ def test_sd_construct_payload():
@pytest.mark.asyncio
async def test_sd_asyn_t2i():
engine = SDEngine()
prompt = "1boy, hansom"
engine = SDEngine(sd_url="http://example.com/mock_sd_t2i")
prompt = "1boy, hansom"
engine.construct_payload(prompt)
await engine.run_t2i([engine.payload])
assert "negative_prompt" in engine.payload
# 构建mock数据
mock_imgs = generate_mock_image_data()
mock_responses = aioresponses()
# 手动启动模拟
mock_responses.start()
try:
# 指定模拟请求的返回值
mock_responses.post("http://example.com/mock_sd_t2i/sdapi/v1/txt2img", payload={"images": [mock_imgs]})
# 在被测试代码中调用异步函数 run_t2i
await engine.run_t2i([engine.payload])
finally:
# 手动停止模拟
mock_responses.stop()