use pytest to mock, rm dependency

This commit is contained in:
yzlin 2024-02-01 11:27:08 +08:00
parent c44d08ceb0
commit 45acde0d65
2 changed files with 20 additions and 39 deletions

View file

@ -4,11 +4,10 @@
# @Desc :
import base64
import io
import json
import pytest
from aioresponses import aioresponses
from PIL import Image, ImageDraw
from requests_mock import Mocker
from metagpt.tools.libs.sd_engine import SDEngine
@ -30,49 +29,33 @@ def generate_mock_image_data():
return image_base64
def test_sd_tools():
engine = SDEngine(sd_url="http://localhost:7860")
# 使用 requests_mock.Mocker 替换 simple_run_t2i 的网络请求
mock_imgs = generate_mock_image_data()
with Mocker() as mocker:
# 指定模拟请求的返回值
mocker.post(engine.sd_t2i_url, json={"images": [mock_imgs]})
def test_sd_tools(mocker):
mock_response = mocker.MagicMock()
mock_response.json.return_value = {"images": [generate_mock_image_data()]}
mocker.patch("requests.Session.post", return_value=mock_response)
# 在被测试代码中调用 simple_run_t2i
result = engine.simple_run_t2i(engine.payload)
# 断言结果是否是指定的 Mock 返回值
assert len(result) == 1
engine = SDEngine(sd_url="http://example_localhost:7860")
prompt = "1boy, hansom"
engine.construct_payload(prompt)
engine.simple_run_t2i(engine.payload)
def test_sd_construct_payload():
engine = SDEngine(sd_url="http://localhost:7860")
engine = SDEngine(sd_url="http://example_localhost:7860")
prompt = "1boy, hansom"
engine.construct_payload(prompt)
assert "negative_prompt" in engine.payload
@pytest.mark.asyncio
async def test_sd_asyn_t2i():
engine = SDEngine(sd_url="http://example.com/mock_sd_t2i")
async def test_sd_asyn_t2i(mocker):
mock_post = mocker.patch("aiohttp.ClientSession.post")
mock_response = mocker.AsyncMock()
mock_response.read.return_value = json.dumps({"images": [generate_mock_image_data()]})
mock_post.return_value.__aenter__.return_value = mock_response
prompt = "1boy, hansom"
engine = SDEngine(sd_url="http://example_localhost:7860")
prompt = "1boy, hansom"
engine.construct_payload(prompt)
# 构建mock数据
mock_imgs = generate_mock_image_data()
mock_responses = aioresponses()
# 手动启动模拟
mock_responses.start()
try:
# 指定模拟请求的返回值
mock_responses.post("http://example.com/mock_sd_t2i/sdapi/v1/txt2img", payload={"images": [mock_imgs]})
# 在被测试代码中调用异步函数 run_t2i
await engine.run_t2i([engine.payload])
finally:
# 手动停止模拟
mock_responses.stop()
await engine.run_t2i([engine.payload])
assert "negative_prompt" in engine.payload