增加非异步接口

sd工具yaml
This commit is contained in:
stellahsr 2024-01-11 20:48:27 +08:00
parent e56caa6f5e
commit a98edada1a
2 changed files with 88 additions and 26 deletions

View file

@ -0,0 +1,49 @@
SDEngine:
type: class
description: "Generate image using stable diffusion model"
methods:
__init__:
description: "Initialize the SDEngine instance."
parameters:
properties:
sd_url:
type: str
description: "URL of the stable diffusion service."
simple_run_t2i:
description: "Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images."
parameters:
properties:
payload:
type: dict
description: "Dictionary of input parameters for the stable diffusion API."
auto_save:
type: bool
description: "Save generated images automatically."
required:
- prompts
construct_payload:
description: "Modify and set the API parameters for image generation."
parameters:
properties:
prompt:
type: str
description: "Text input for image generation."
required:
- prompt
returns:
payload:
type: dict
description: "Updated parameters for the stable diffusion API."
save:
description: "Save generated images to the output directory."
parameters:
properties:
imgs:
type: str
description: "Generated images."
save_name:
type: str
description: "Output image name. Default is empty."
required:
- imgs

View file

@ -2,13 +2,14 @@
# @Date : 2023/7/19 16:28
# @Author : stellahong (stellahong@deepwisdom.ai)
# @Desc :
import asyncio
import base64
import io
import json
from os.path import join
from typing import List
import hashlib
import requests
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
@ -51,59 +52,70 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
class SDEngine:
def __init__(self):
def __init__(self, sd_url=""):
# Initialize the SDEngine with configuration
self.sd_url = CONFIG.get("SD_URL")
self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL")
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
self.payload["negtive_prompt"] = negtive_prompt
self.payload["negative_prompt"] = negtive_prompt
self.payload["width"] = width
self.payload["height"] = height
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
logger.info(f"call sd payload is {self.payload}")
return self.payload
def _save(self, imgs, save_name=""):
def save(self, imgs, save_name=""):
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
async def run_t2i(self, prompts: List):
def simple_run_t2i(self, payload: dict, auto_save: bool = True):
with requests.Session() as session:
logger.debug(self.sd_t2i_url)
rsp = session.post(self.sd_t2i_url, json=payload, timeout=600)
results = rsp.json()["images"]
if auto_save:
save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6]
self.save(results, save_name=f"output_{save_name}")
return results
async def run_t2i(self, payloads: List):
# Asynchronously run the SD API for multiple prompts
session = ClientSession()
for payload_idx, payload in enumerate(prompts):
for payload_idx, payload in enumerate(payloads):
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
self._save(results, save_name=f"output_{payload_idx}")
self.save(results, save_name=f"output_{payload_idx}")
await session.close()
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
rsp_json = json.loads(data)
imgs = rsp_json["images"]
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
async def run_sam(self):
# todo添加SAM接口调用
raise NotImplementedError
@ -125,9 +137,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
if __name__ == "__main__":
engine = SDEngine()
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
prompt = "1girl, beautiful"
prompt = "1boy, hansom"
engine.construct_payload(prompt)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(engine.run_t2i(prompt))
engine.simple_run_t2i(engine.payload)
# event_loop = asyncio.get_event_loop()
# event_loop.run_until_complete(engine.run_t2i([engine.payload]))