add asyn sd ut

This commit is contained in:
stellahsr 2024-01-11 22:43:02 +08:00
parent 3be26cf94f
commit 12bc0104b6
2 changed files with 31 additions and 29 deletions

View file

@ -3,11 +3,11 @@
# @Author : stellahong (stellahong@deepwisdom.ai)
# @Desc :
import base64
import hashlib
import io
import json
from os.path import join
from typing import List
import hashlib
import requests
from aiohttp import ClientSession
@ -59,14 +59,14 @@ class SDEngine:
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
@ -76,24 +76,24 @@ class SDEngine:
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
logger.info(f"call sd payload is {self.payload}")
return self.payload
def save(self, imgs, save_name=""):
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
def simple_run_t2i(self, payload: dict, auto_save: bool = True):
with requests.Session() as session:
logger.debug(self.sd_t2i_url)
rsp = session.post(self.sd_t2i_url, json=payload, timeout=600)
results = rsp.json()["images"]
if auto_save:
save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6]
self.save(results, save_name=f"output_{save_name}")
return results
async def run_t2i(self, payloads: List):
# Asynchronously run the SD API for multiple prompts
session = ClientSession()
@ -101,21 +101,21 @@ class SDEngine:
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
self.save(results, save_name=f"output_{payload_idx}")
await session.close()
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
rsp_json = json.loads(data)
imgs = rsp_json["images"]
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
async def run_sam(self):
# todo添加SAM接口调用
raise NotImplementedError
@ -133,14 +133,3 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
for idx, _img in enumerate(imgs):
save_name = join(save_dir, save_name)
decode_base64_to_image(_img, save_name=save_name)
if __name__ == "__main__":
engine = SDEngine()
prompt = "1girl, beautiful"
prompt = "1boy, hansom"
engine.construct_payload(prompt)
engine.simple_run_t2i(engine.payload)
# event_loop = asyncio.get_event_loop()
# event_loop.run_until_complete(engine.run_t2i([engine.payload]))

View file

@ -2,16 +2,29 @@
# @Date : 1/10/2024 10:07 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import pytest
from metagpt.tools.sd_engine import SDEngine
def test_sd_tools():
engine = SDEngine()
prompt = "1boy, hansom"
engine.construct_payload(prompt)
engine.simple_run_t2i(engine.payload)
def test_sd_construct_payload():
engine = SDEngine()
prompt = "1boy, hansom"
engine.construct_payload(prompt)
assert "negative_prompt" in engine.payload
assert "negative_prompt" in engine.payload
@pytest.mark.asyncio
async def test_sd_asyn_t2i():
engine = SDEngine()
prompt = "1boy, hansom"
engine.construct_payload(prompt)
await engine.run_t2i([engine.payload])
assert "negative_prompt" in engine.payload