From d74dab9bec1a42503984b9acd1c247d8b151b323 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Wed, 31 Jan 2024 16:03:16 +0800 Subject: [PATCH 1/4] update sd ut --- examples/imitate_webpage.py | 4 +- metagpt/tools/libs/sd_engine.py | 14 ++--- tests/metagpt/tools/libs/test_sd_engine.py | 66 +++++++++++++++++++--- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/examples/imitate_webpage.py b/examples/imitate_webpage.py index 6c12c7eda..b69101861 100644 --- a/examples/imitate_webpage.py +++ b/examples/imitate_webpage.py @@ -9,7 +9,7 @@ from metagpt.roles.code_interpreter import CodeInterpreter async def main(): - web_url = 'https://pytorch.org/' + web_url = "https://pytorch.org/" prompt = f"""This is a URL of webpage: '{web_url}' . Firstly, utilize Selenium and WebDriver for rendering. Secondly, convert image to a webpage including HTML, CSS and JS in one go. @@ -20,7 +20,7 @@ Note: All required dependencies and environments have been fully installed and c await ci.run(prompt) -if __name__ == '__main__': +if __name__ == "__main__": import asyncio asyncio.run(main()) diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index 794758f77..7f182f380 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -13,7 +13,8 @@ import requests from aiohttp import ClientSession from PIL import Image, PngImagePlugin -from metagpt.const import SD_OUTPUT_FILE_REPO +# +from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT from metagpt.logs import logger from metagpt.tools.tool_data_type import ToolTypeEnum from metagpt.tools.tool_registry import register_tool @@ -82,7 +83,7 @@ class SDEngine: return self.payload def save(self, imgs, save_name=""): - save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO + save_dir = SOURCE_ROOT / SD_OUTPUT_FILE_REPO if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name) @@ -113,17 +114,10 @@ class SDEngine: rsp_json = json.loads(data) imgs = rsp_json["images"] + logger.info(f"callback rsp json is {rsp_json.keys()}") return imgs - async def run_i2i(self): - # todo: 添加图生图接口调用 - raise NotImplementedError - - async def run_sam(self): - # todo:添加SAM接口调用 - raise NotImplementedError - def decode_base64_to_image(img, save_name): image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0]))) diff --git a/tests/metagpt/tools/libs/test_sd_engine.py b/tests/metagpt/tools/libs/test_sd_engine.py index 363cf96b9..322976806 100644 --- a/tests/metagpt/tools/libs/test_sd_engine.py +++ b/tests/metagpt/tools/libs/test_sd_engine.py @@ -2,20 +2,51 @@ # @Date : 1/10/2024 10:07 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : +import base64 +import io + import pytest +from aioresponses import aioresponses +from PIL import Image, ImageDraw +from requests_mock import Mocker from metagpt.tools.libs.sd_engine import SDEngine +def generate_mock_image_data(): + # 创建一个简单的图片对象 + image = Image.new("RGB", (100, 100), color="white") + draw = ImageDraw.Draw(image) + draw.text((10, 10), "Mock Image", fill="black") + + # 将图片转换为二进制数据 + with io.BytesIO() as buffer: + image.save(buffer, format="PNG") + image_binary = buffer.getvalue() + + # 对图片二进制数据进行 base64 编码 + image_base64 = base64.b64encode(image_binary).decode("utf-8") + + return image_base64 + + def test_sd_tools(): - engine = SDEngine() - prompt = "1boy, hansom" - engine.construct_payload(prompt) - engine.simple_run_t2i(engine.payload) + engine = SDEngine(sd_url="http://localhost:7860") + # 使用 requests_mock.Mocker 替换 simple_run_t2i 的网络请求 + mock_imgs = generate_mock_image_data() + with Mocker() as mocker: + # 指定模拟请求的返回值 + mocker.post(engine.sd_t2i_url, json={"images": [mock_imgs]}) + + # 在被测试代码中调用 simple_run_t2i + result = engine.simple_run_t2i(engine.payload) + + # 断言结果是否是指定的 Mock 返回值 + assert len(result) == 1 def test_sd_construct_payload(): - engine = SDEngine() + engine = SDEngine(sd_url="http://localhost:7860") prompt = "1boy, hansom" engine.construct_payload(prompt) assert "negative_prompt" in engine.payload @@ -23,8 +54,25 @@ def test_sd_construct_payload(): @pytest.mark.asyncio async def test_sd_asyn_t2i(): - engine = SDEngine() - prompt = "1boy, hansom" + engine = SDEngine(sd_url="http://example.com/mock_sd_t2i") + + prompt = "1boy, hansom" engine.construct_payload(prompt) - await engine.run_t2i([engine.payload]) - assert "negative_prompt" in engine.payload + # 构建mock数据 + mock_imgs = generate_mock_image_data() + + mock_responses = aioresponses() + + # 手动启动模拟 + mock_responses.start() + + try: + # 指定模拟请求的返回值 + mock_responses.post("http://example.com/mock_sd_t2i/sdapi/v1/txt2img", payload={"images": [mock_imgs]}) + + # 在被测试代码中调用异步函数 run_t2i + await engine.run_t2i([engine.payload]) + + finally: + # 手动停止模拟 + mock_responses.stop() From 28b0323d7552f109f186b06bdc7505c93db5be85 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Wed, 31 Jan 2024 16:14:33 +0800 Subject: [PATCH 2/4] add package for test_sd_engine --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 66b3c9fc0..4a9c0ab30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -66,3 +66,5 @@ google-generativeai==0.3.2 # playwright==1.40.0 # playwright extras require anytree ipywidgets==8.1.1 +aioresponses +requests_mock \ No newline at end of file From c44d08ceb05ee177915506a84fc40b021ef4698c Mon Sep 17 00:00:00 2001 From: stellahsr Date: Wed, 31 Jan 2024 16:30:50 +0800 Subject: [PATCH 3/4] rm config get in dev --- metagpt/tools/libs/sd_engine.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index 57a025f3c..7001eadf5 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -56,11 +56,9 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" @register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value) class SDEngine: def __init__(self, sd_url=""): - from metagpt.config2 import config - # Initialize the SDEngine with configuration - self.sd_url = sd_url if sd_url else config.get("SD_URL") - self.sd_t2i_url = f"{self.sd_url}{config.get('SD_T2I_API')}" + self.sd_url = sd_url + self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img" # Define default payload settings for SD API self.payload = payload logger.info(self.sd_t2i_url) From 45acde0d65abc7ee712aeccef2141d0846dbbb56 Mon Sep 17 00:00:00 2001 From: yzlin Date: Thu, 1 Feb 2024 11:27:08 +0800 Subject: [PATCH 4/4] use pytest to mock, rm dependency --- requirements.txt | 4 +- tests/metagpt/tools/libs/test_sd_engine.py | 55 ++++++++-------------- 2 files changed, 20 insertions(+), 39 deletions(-) diff --git a/requirements.txt b/requirements.txt index 4a9c0ab30..dff615bdc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -65,6 +65,4 @@ networkx~=3.2.1 google-generativeai==0.3.2 # playwright==1.40.0 # playwright extras require anytree -ipywidgets==8.1.1 -aioresponses -requests_mock \ No newline at end of file +ipywidgets==8.1.1 \ No newline at end of file diff --git a/tests/metagpt/tools/libs/test_sd_engine.py b/tests/metagpt/tools/libs/test_sd_engine.py index 322976806..e2c46e72a 100644 --- a/tests/metagpt/tools/libs/test_sd_engine.py +++ b/tests/metagpt/tools/libs/test_sd_engine.py @@ -4,11 +4,10 @@ # @Desc : import base64 import io +import json import pytest -from aioresponses import aioresponses from PIL import Image, ImageDraw -from requests_mock import Mocker from metagpt.tools.libs.sd_engine import SDEngine @@ -30,49 +29,33 @@ def generate_mock_image_data(): return image_base64 -def test_sd_tools(): - engine = SDEngine(sd_url="http://localhost:7860") - # 使用 requests_mock.Mocker 替换 simple_run_t2i 的网络请求 - mock_imgs = generate_mock_image_data() - with Mocker() as mocker: - # 指定模拟请求的返回值 - mocker.post(engine.sd_t2i_url, json={"images": [mock_imgs]}) +def test_sd_tools(mocker): + mock_response = mocker.MagicMock() + mock_response.json.return_value = {"images": [generate_mock_image_data()]} + mocker.patch("requests.Session.post", return_value=mock_response) - # 在被测试代码中调用 simple_run_t2i - result = engine.simple_run_t2i(engine.payload) - - # 断言结果是否是指定的 Mock 返回值 - assert len(result) == 1 + engine = SDEngine(sd_url="http://example_localhost:7860") + prompt = "1boy, hansom" + engine.construct_payload(prompt) + engine.simple_run_t2i(engine.payload) def test_sd_construct_payload(): - engine = SDEngine(sd_url="http://localhost:7860") + engine = SDEngine(sd_url="http://example_localhost:7860") prompt = "1boy, hansom" engine.construct_payload(prompt) assert "negative_prompt" in engine.payload @pytest.mark.asyncio -async def test_sd_asyn_t2i(): - engine = SDEngine(sd_url="http://example.com/mock_sd_t2i") +async def test_sd_asyn_t2i(mocker): + mock_post = mocker.patch("aiohttp.ClientSession.post") + mock_response = mocker.AsyncMock() + mock_response.read.return_value = json.dumps({"images": [generate_mock_image_data()]}) + mock_post.return_value.__aenter__.return_value = mock_response - prompt = "1boy, hansom" + engine = SDEngine(sd_url="http://example_localhost:7860") + prompt = "1boy, hansom" engine.construct_payload(prompt) - # 构建mock数据 - mock_imgs = generate_mock_image_data() - - mock_responses = aioresponses() - - # 手动启动模拟 - mock_responses.start() - - try: - # 指定模拟请求的返回值 - mock_responses.post("http://example.com/mock_sd_t2i/sdapi/v1/txt2img", payload={"images": [mock_imgs]}) - - # 在被测试代码中调用异步函数 run_t2i - await engine.run_t2i([engine.payload]) - - finally: - # 手动停止模拟 - mock_responses.stop() + await engine.run_t2i([engine.payload]) + assert "negative_prompt" in engine.payload