From 5d6c9217e9432f3cf0500d995659720b420d10a0 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Fri, 12 Apr 2024 17:12:13 +0800 Subject: [PATCH 1/2] update --- metagpt/tools/libs/sd_engine.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index b62e39db8..aff8ad6b5 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -8,6 +8,7 @@ import base64 import hashlib import io import json +import os from os.path import join import requests @@ -68,7 +69,7 @@ class SDEngine: Args: sd_url (str, optional): URL of the stable diffusion service. Defaults to "". """ - self.sd_url = sd_url + self.sd_url = os.getenv("sd_url") if sd_url else sd_url self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img" # Define default payload settings for SD API self.payload = payload @@ -76,12 +77,12 @@ class SDEngine: def construct_payload( self, - prompt, - negtive_prompt=default_negative_prompt, - width=512, - height=512, - sd_model="galaxytimemachinesGTM_photoV20", - ): + prompt: object, + negtive_prompt: object = default_negative_prompt, + width: object = 512, + height: object = 512, + sd_model: object = "galaxytimemachinesGTM_photoV20", + ) -> object: """Modify and set the API parameters for image generation. Args: @@ -179,3 +180,9 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): for idx, _img in enumerate(imgs): save_name = join(save_dir, save_name) decode_base64_to_image(_img, save_name=save_name) + + +if __name__ == "__main__": + sd = SDEngine(sd_url="http://172.31.0.51:49094") + payload = sd.construct_payload(prompt="a girl") + sd.simple_run_t2i(payload=payload) From 04217d45e0d74a8c45b124acc02a81a57485ba82 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Fri, 12 Apr 2024 17:34:36 +0800 Subject: [PATCH 2/2] update code, set default url and output dir --- metagpt/const.py | 3 ++- metagpt/tools/libs/sd_engine.py | 11 ++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/metagpt/const.py b/metagpt/const.py index 484987a03..c01f92adc 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -103,12 +103,13 @@ TEST_OUTPUTS_FILE_REPO = "test_outputs" CODE_SUMMARIES_FILE_REPO = "docs/code_summary" CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary" RESOURCES_FILE_REPO = "resources" -SD_OUTPUT_FILE_REPO = "resources/sd_output" +SD_OUTPUT_FILE_REPO = DEFAULT_WORKSPACE_ROOT GRAPH_REPO_FILE_REPO = "docs/graph_repo" VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db" CLASS_VIEW_FILE_REPO = "docs/class_view" YAPI_URL = "http://yapi.deepwisdomai.com/" +SD_URL = "http://172.31.0.51:49094" DEFAULT_LANGUAGE = "English" DEFAULT_MAX_TOKENS = 1500 diff --git a/metagpt/tools/libs/sd_engine.py b/metagpt/tools/libs/sd_engine.py index aff8ad6b5..4cf7d2310 100644 --- a/metagpt/tools/libs/sd_engine.py +++ b/metagpt/tools/libs/sd_engine.py @@ -8,14 +8,13 @@ import base64 import hashlib import io import json -import os from os.path import join import requests from aiohttp import ClientSession from PIL import Image, PngImagePlugin -from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT +from metagpt.const import SD_OUTPUT_FILE_REPO, SD_URL, SOURCE_ROOT from metagpt.logs import logger from metagpt.tools.tool_registry import register_tool @@ -69,7 +68,7 @@ class SDEngine: Args: sd_url (str, optional): URL of the stable diffusion service. Defaults to "". """ - self.sd_url = os.getenv("sd_url") if sd_url else sd_url + self.sd_url = SD_URL if not sd_url else sd_url self.sd_t2i_url = f"{self.sd_url}/sdapi/v1/txt2img" # Define default payload settings for SD API self.payload = payload @@ -180,9 +179,3 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): for idx, _img in enumerate(imgs): save_name = join(save_dir, save_name) decode_base64_to_image(_img, save_name=save_name) - - -if __name__ == "__main__": - sd = SDEngine(sd_url="http://172.31.0.51:49094") - payload = sd.construct_payload(prompt="a girl") - sd.simple_run_t2i(payload=payload)