mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge branch 'ci_sd_ut_new' into 'code_intepreter'
SD unittest See merge request agents/data_agents_opt!70
This commit is contained in:
commit
e85f749031
3 changed files with 43 additions and 20 deletions
|
|
@ -13,7 +13,8 @@ import requests
|
|||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO
|
||||
#
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_data_type import ToolTypeEnum
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
|
|
@ -55,11 +56,9 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
|||
@register_tool(tool_type=ToolTypeEnum.STABLE_DIFFUSION.value)
|
||||
class SDEngine:
|
||||
def __init__(self, sd_url=""):
|
||||
from metagpt.config2 import config
|
||||
|
||||
# Initialize the SDEngine with configuration
|
||||
self.sd_url = sd_url if sd_url else config.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{config.get('SD_T2I_API')}"
|
||||
self.sd_url = sd_url
|
||||
self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
|
@ -82,7 +81,7 @@ class SDEngine:
|
|||
return self.payload
|
||||
|
||||
def save(self, imgs, save_name=""):
|
||||
save_dir = config.workspace_path / SD_OUTPUT_FILE_REPO
|
||||
save_dir = SOURCE_ROOT / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
|
@ -113,17 +112,10 @@ class SDEngine:
|
|||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def decode_base64_to_image(img, save_name):
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
|
||||
|
|
|
|||
|
|
@ -65,4 +65,4 @@ networkx~=3.2.1
|
|||
google-generativeai==0.3.2
|
||||
# playwright==1.40.0 # playwright extras require
|
||||
anytree
|
||||
ipywidgets==8.1.1
|
||||
ipywidgets==8.1.1
|
||||
|
|
@ -2,28 +2,59 @@
|
|||
# @Date : 1/10/2024 10:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from metagpt.tools.libs.sd_engine import SDEngine
|
||||
|
||||
|
||||
def test_sd_tools():
|
||||
engine = SDEngine()
|
||||
def generate_mock_image_data():
|
||||
# 创建一个简单的图片对象
|
||||
image = Image.new("RGB", (100, 100), color="white")
|
||||
draw = ImageDraw.Draw(image)
|
||||
draw.text((10, 10), "Mock Image", fill="black")
|
||||
|
||||
# 将图片转换为二进制数据
|
||||
with io.BytesIO() as buffer:
|
||||
image.save(buffer, format="PNG")
|
||||
image_binary = buffer.getvalue()
|
||||
|
||||
# 对图片二进制数据进行 base64 编码
|
||||
image_base64 = base64.b64encode(image_binary).decode("utf-8")
|
||||
|
||||
return image_base64
|
||||
|
||||
|
||||
def test_sd_tools(mocker):
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.json.return_value = {"images": [generate_mock_image_data()]}
|
||||
mocker.patch("requests.Session.post", return_value=mock_response)
|
||||
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
engine.simple_run_t2i(engine.payload)
|
||||
|
||||
|
||||
def test_sd_construct_payload():
|
||||
engine = SDEngine()
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
assert "negative_prompt" in engine.payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sd_asyn_t2i():
|
||||
engine = SDEngine()
|
||||
async def test_sd_asyn_t2i(mocker):
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.read.return_value = json.dumps({"images": [generate_mock_image_data()]})
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
engine = SDEngine(sd_url="http://example_localhost:7860")
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
await engine.run_t2i([engine.payload])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue