From 5d6c9217e9432f3cf0500d995659720b420d10a0 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Fri, 12 Apr 2024 17:12:13 +0800 Subject: [PATCH] update --- metagpt/tools/libs/sd_engine.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index b62e39db8..aff8ad6b5 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -8,6 +8,7 @@ import base64 import hashlib import io import json +import os from os.path import join import requests @@ -68,7 +69,7 @@ class SDEngine: Args: sd_url (str, optional): URL of the stable diffusion service. Defaults to "". """ - self.sd_url = sd_url + self.sd_url = os.getenv("sd_url") if sd_url else sd_url self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img" # Define default payload settings for SD API self.payload = payload @@ -76,12 +77,12 @@ class SDEngine: def construct_payload( self, - prompt, - negtive_prompt=default_negative_prompt, - width=512, - height=512, - sd_model="galaxytimemachinesGTM_photoV20", - ): + prompt: object, + negtive_prompt: object = default_negative_prompt, + width: object = 512, + height: object = 512, + sd_model: object = "galaxytimemachinesGTM_photoV20", + ) -> object: """Modify and set the API parameters for image generation. Args: @@ -179,3 +180,9 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): for idx, _img in enumerate(imgs): save_name = join(save_dir, save_name) decode_base64_to_image(_img, save_name=save_name) + + +if __name__ == "__main__": + sd = SDEngine(sd_url="http://172.31.0.51:49094") + payload = sd.construct_payload(prompt="a girl") + sd.simple_run_t2i(payload=payload)