This commit is contained in:
stellahsr 2024-04-12 17:12:13 +08:00
parent 780c68e43f
commit 5d6c9217e9

View file

@ -8,6 +8,7 @@ import base64
import hashlib
import io
import json
import os
from os.path import join
import requests
@ -68,7 +69,7 @@ class SDEngine:
Args:
sd_url (str, optional): URL of the stable diffusion service. Defaults to "".
"""
self.sd_url = sd_url
self.sd_url = os.getenv("sd_url") if sd_url else sd_url
self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img"
# Define default payload settings for SD API
self.payload = payload
@ -76,12 +77,12 @@ class SDEngine:
def construct_payload(
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
):
prompt: object,
negtive_prompt: object = default_negative_prompt,
width: object = 512,
height: object = 512,
sd_model: object = "galaxytimemachinesGTM_photoV20",
) -> object:
"""Modify and set the API parameters for image generation.
Args:
@ -179,3 +180,9 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
for idx, _img in enumerate(imgs):
save_name = join(save_dir, save_name)
decode_base64_to_image(_img, save_name=save_name)
if __name__ == "__main__":
sd = SDEngine(sd_url="http://172.31.0.51:49094")
payload = sd.construct_payload(prompt="a girl")
sd.simple_run_t2i(payload=payload)