diff --git a/metagpt/tools/functions/schemas/stable_diffusion.yml b/metagpt/tools/functions/schemas/stable_diffusion.yml new file mode 100644 index 000000000..119449caa --- /dev/null +++ b/metagpt/tools/functions/schemas/stable_diffusion.yml @@ -0,0 +1,49 @@ +SDEngine: + type: class + description: "Generate image using stable diffusion model" + methods: + __init__: + description: "Initialize the SDEngine instance." + parameters: + properties: + sd_url: + type: str + description: "URL of the stable diffusion service." + + simple_run_t2i: + description: "Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images." + parameters: + properties: + payload: + type: dict + description: "Dictionary of input parameters for the stable diffusion API." + auto_save: + type: bool + description: "Save generated images automatically." + required: + - prompts + construct_payload: + description: "Modify and set the API parameters for image generation." + parameters: + properties: + prompt: + type: str + description: "Text input for image generation." + required: + - prompt + returns: + payload: + type: dict + description: "Updated parameters for the stable diffusion API." + save: + description: "Save generated images to the output directory." + parameters: + properties: + imgs: + type: str + description: "Generated images." + save_name: + type: str + description: "Output image name. Default is empty." + required: + - imgs diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index c4d9d2df4..de2988d2a 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -2,13 +2,14 @@ # @Date : 2023/7/19 16:28 # @Author : stellahong (stellahong@deepwisdom.ai) # @Desc : -import asyncio import base64 import io import json from os.path import join from typing import List +import hashlib +import requests from aiohttp import ClientSession from PIL import Image, PngImagePlugin @@ -51,59 +52,70 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" class SDEngine: - def __init__(self): + def __init__(self, sd_url=""): # Initialize the SDEngine with configuration - self.sd_url = CONFIG.get("SD_URL") + self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL") self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}" # Define default payload settings for SD API self.payload = payload logger.info(self.sd_t2i_url) - + def construct_payload( - self, - prompt, - negtive_prompt=default_negative_prompt, - width=512, - height=512, - sd_model="galaxytimemachinesGTM_photoV20", + self, + prompt, + negtive_prompt=default_negative_prompt, + width=512, + height=512, + sd_model="galaxytimemachinesGTM_photoV20", ): # Configure the payload with provided inputs self.payload["prompt"] = prompt - self.payload["negtive_prompt"] = negtive_prompt + self.payload["negative_prompt"] = negtive_prompt self.payload["width"] = width self.payload["height"] = height self.payload["override_settings"]["sd_model_checkpoint"] = sd_model logger.info(f"call sd payload is {self.payload}") return self.payload - - def _save(self, imgs, save_name=""): + + def save(self, imgs, save_name=""): save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name) - - async def run_t2i(self, prompts: List): + + def simple_run_t2i(self, payload: dict, auto_save: bool = True): + with requests.Session() as session: + logger.debug(self.sd_t2i_url) + rsp = session.post(self.sd_t2i_url, json=payload, timeout=600) + + results = rsp.json()["images"] + if auto_save: + save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6] + self.save(results, save_name=f"output_{save_name}") + return results + + async def run_t2i(self, payloads: List): # Asynchronously run the SD API for multiple prompts session = ClientSession() - for payload_idx, payload in enumerate(prompts): + for payload_idx, payload in enumerate(payloads): results = await self.run(url=self.sd_t2i_url, payload=payload, session=session) - self._save(results, save_name=f"output_{payload_idx}") + self.save(results, save_name=f"output_{payload_idx}") await session.close() - + async def run(self, url, payload, session): # Perform the HTTP POST request to the SD API async with session.post(url, json=payload, timeout=600) as rsp: data = await rsp.read() - + rsp_json = json.loads(data) imgs = rsp_json["images"] logger.info(f"callback rsp json is {rsp_json.keys()}") return imgs - + async def run_i2i(self): # todo: 添加图生图接口调用 raise NotImplementedError - + async def run_sam(self): # todo:添加SAM接口调用 raise NotImplementedError @@ -125,9 +137,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): if __name__ == "__main__": engine = SDEngine() - prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary" - + prompt = "1girl, beautiful" + prompt = "1boy, hansom" engine.construct_payload(prompt) - - event_loop = asyncio.get_event_loop() - event_loop.run_until_complete(engine.run_t2i(prompt)) + + engine.simple_run_t2i(engine.payload) + # event_loop = asyncio.get_event_loop() + # event_loop.run_until_complete(engine.run_t2i([engine.payload]))