This commit is contained in:
stellahsr 2023-09-12 22:12:31 +08:00
parent c28034ccbc
commit 8df7c2c02c
3 changed files with 246 additions and 20 deletions

View file

@ -27,7 +27,7 @@ payload = {
"batch_size": 1,
"n_iter": 1,
"steps": 20,
"cfg_scale": 7,
"cfg_scale": 9,
"width": 512,
"height": 768,
"restore_faces": False,
@ -62,52 +62,54 @@ class SDEngine:
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
self,
prompt,
negative_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
**kwargs
):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
self.payload["negtive_prompt"] = negtive_prompt
self.payload["negative_prompt"] = negative_prompt
self.payload["width"] = width
self.payload["height"] = height
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
self.payload.update(**kwargs)
logger.info(f"call sd payload is {self.payload}")
return self.payload
def _save(self, imgs, save_name=""):
save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
if not os.path.exists(save_dir):
if not save_dir.exists():
os.makedirs(save_dir, exist_ok=True)
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
async def run_t2i(self, prompts: List):
async def run_t2i(self, prompts: List, save_name=""):
# Asynchronously run the SD API for multiple prompts
session = ClientSession()
for payload_idx, payload in enumerate(prompts):
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
self._save(results, save_name=f"output_{payload_idx}")
self._save(results, save_name=f"{save_name}_output_{payload_idx}")
await session.close()
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
rsp_json = json.loads(data)
imgs = rsp_json["images"]
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
async def run_sam(self):
# todo添加SAM接口调用
raise NotImplementedError
@ -128,8 +130,8 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
if __name__ == "__main__":
engine = SDEngine()
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
engine.construct_payload(prompt)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(engine.run_t2i(prompt))