diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index de2988d2a..ba61fd496 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -3,11 +3,11 @@ # @Author : stellahong (stellahong@deepwisdom.ai) # @Desc : import base64 +import hashlib import io import json from os.path import join from typing import List -import hashlib import requests from aiohttp import ClientSession @@ -59,14 +59,14 @@ class SDEngine: # Define default payload settings for SD API self.payload = payload logger.info(self.sd_t2i_url) - + def construct_payload( - self, - prompt, - negtive_prompt=default_negative_prompt, - width=512, - height=512, - sd_model="galaxytimemachinesGTM_photoV20", + self, + prompt, + negtive_prompt=default_negative_prompt, + width=512, + height=512, + sd_model="galaxytimemachinesGTM_photoV20", ): # Configure the payload with provided inputs self.payload["prompt"] = prompt @@ -76,24 +76,24 @@ class SDEngine: self.payload["override_settings"]["sd_model_checkpoint"] = sd_model logger.info(f"call sd payload is {self.payload}") return self.payload - + def save(self, imgs, save_name=""): save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name) - + def simple_run_t2i(self, payload: dict, auto_save: bool = True): with requests.Session() as session: logger.debug(self.sd_t2i_url) rsp = session.post(self.sd_t2i_url, json=payload, timeout=600) - + results = rsp.json()["images"] if auto_save: save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6] self.save(results, save_name=f"output_{save_name}") return results - + async def run_t2i(self, payloads: List): # Asynchronously run the SD API for multiple prompts session = ClientSession() @@ -101,21 +101,21 @@ class SDEngine: results = await self.run(url=self.sd_t2i_url, payload=payload, session=session) self.save(results, save_name=f"output_{payload_idx}") await session.close() - + async def run(self, url, payload, session): # Perform the HTTP POST request to the SD API async with session.post(url, json=payload, timeout=600) as rsp: data = await rsp.read() - + rsp_json = json.loads(data) imgs = rsp_json["images"] logger.info(f"callback rsp json is {rsp_json.keys()}") return imgs - + async def run_i2i(self): # todo: 添加图生图接口调用 raise NotImplementedError - + async def run_sam(self): # todo:添加SAM接口调用 raise NotImplementedError @@ -133,14 +133,3 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): for idx, _img in enumerate(imgs): save_name = join(save_dir, save_name) decode_base64_to_image(_img, save_name=save_name) - - -if __name__ == "__main__": - engine = SDEngine() - prompt = "1girl, beautiful" - prompt = "1boy, hansom" - engine.construct_payload(prompt) - - engine.simple_run_t2i(engine.payload) - # event_loop = asyncio.get_event_loop() - # event_loop.run_until_complete(engine.run_t2i([engine.payload])) diff --git a/tests/metagpt/tools/functions/test_sd.py b/tests/metagpt/tools/functions/test_sd.py index 405ac9a32..142101cad 100644 --- a/tests/metagpt/tools/functions/test_sd.py +++ b/tests/metagpt/tools/functions/test_sd.py @@ -2,16 +2,29 @@ # @Date : 1/10/2024 10:07 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : +import pytest + from metagpt.tools.sd_engine import SDEngine + def test_sd_tools(): engine = SDEngine() prompt = "1boy, hansom" engine.construct_payload(prompt) engine.simple_run_t2i(engine.payload) - + + def test_sd_construct_payload(): engine = SDEngine() prompt = "1boy, hansom" engine.construct_payload(prompt) - assert "negative_prompt" in engine.payload \ No newline at end of file + assert "negative_prompt" in engine.payload + + +@pytest.mark.asyncio +async def test_sd_asyn_t2i(): + engine = SDEngine() + prompt = "1boy, hansom" + engine.construct_payload(prompt) + await engine.run_t2i([engine.payload]) + assert "negative_prompt" in engine.payload