mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add asyn sd ut
This commit is contained in:
parent
3be26cf94f
commit
12bc0104b6
2 changed files with 31 additions and 29 deletions
|
|
@ -3,11 +3,11 @@
|
|||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
# @Desc :
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
from os.path import join
|
||||
from typing import List
|
||||
import hashlib
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
|
|
@ -59,14 +59,14 @@ class SDEngine:
|
|||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
|
|
@ -76,24 +76,24 @@ class SDEngine:
|
|||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
|
||||
def save(self, imgs, save_name=""):
|
||||
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
||||
|
||||
def simple_run_t2i(self, payload: dict, auto_save: bool = True):
|
||||
with requests.Session() as session:
|
||||
logger.debug(self.sd_t2i_url)
|
||||
rsp = session.post(self.sd_t2i_url, json=payload, timeout=600)
|
||||
|
||||
|
||||
results = rsp.json()["images"]
|
||||
if auto_save:
|
||||
save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6]
|
||||
self.save(results, save_name=f"output_{save_name}")
|
||||
return results
|
||||
|
||||
|
||||
async def run_t2i(self, payloads: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
|
|
@ -101,21 +101,21 @@ class SDEngine:
|
|||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self.save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
|
@ -133,14 +133,3 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
for idx, _img in enumerate(imgs):
|
||||
save_name = join(save_dir, save_name)
|
||||
decode_base64_to_image(_img, save_name=save_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
engine = SDEngine()
|
||||
prompt = "1girl, beautiful"
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
engine.simple_run_t2i(engine.payload)
|
||||
# event_loop = asyncio.get_event_loop()
|
||||
# event_loop.run_until_complete(engine.run_t2i([engine.payload]))
|
||||
|
|
|
|||
|
|
@ -2,16 +2,29 @@
|
|||
# @Date : 1/10/2024 10:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
|
||||
def test_sd_tools():
|
||||
engine = SDEngine()
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
engine.simple_run_t2i(engine.payload)
|
||||
|
||||
|
||||
|
||||
def test_sd_construct_payload():
|
||||
engine = SDEngine()
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
assert "negative_prompt" in engine.payload
|
||||
assert "negative_prompt" in engine.payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sd_asyn_t2i():
|
||||
engine = SDEngine()
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
await engine.run_t2i([engine.payload])
|
||||
assert "negative_prompt" in engine.payload
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue