mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
增加非异步接口
sd工具yaml
This commit is contained in:
parent
e56caa6f5e
commit
a98edada1a
2 changed files with 88 additions and 26 deletions
49
metagpt/tools/functions/schemas/stable_diffusion.yml
Normal file
49
metagpt/tools/functions/schemas/stable_diffusion.yml
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
SDEngine:
|
||||
type: class
|
||||
description: "Generate image using stable diffusion model"
|
||||
methods:
|
||||
__init__:
|
||||
description: "Initialize the SDEngine instance."
|
||||
parameters:
|
||||
properties:
|
||||
sd_url:
|
||||
type: str
|
||||
description: "URL of the stable diffusion service."
|
||||
|
||||
simple_run_t2i:
|
||||
description: "Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images."
|
||||
parameters:
|
||||
properties:
|
||||
payload:
|
||||
type: dict
|
||||
description: "Dictionary of input parameters for the stable diffusion API."
|
||||
auto_save:
|
||||
type: bool
|
||||
description: "Save generated images automatically."
|
||||
required:
|
||||
- prompts
|
||||
construct_payload:
|
||||
description: "Modify and set the API parameters for image generation."
|
||||
parameters:
|
||||
properties:
|
||||
prompt:
|
||||
type: str
|
||||
description: "Text input for image generation."
|
||||
required:
|
||||
- prompt
|
||||
returns:
|
||||
payload:
|
||||
type: dict
|
||||
description: "Updated parameters for the stable diffusion API."
|
||||
save:
|
||||
description: "Save generated images to the output directory."
|
||||
parameters:
|
||||
properties:
|
||||
imgs:
|
||||
type: str
|
||||
description: "Generated images."
|
||||
save_name:
|
||||
type: str
|
||||
description: "Output image name. Default is empty."
|
||||
required:
|
||||
- imgs
|
||||
|
|
@ -2,13 +2,14 @@
|
|||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from os.path import join
|
||||
from typing import List
|
||||
import hashlib
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
|
|
@ -51,59 +52,70 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
|||
|
||||
|
||||
class SDEngine:
|
||||
def __init__(self):
|
||||
def __init__(self, sd_url=""):
|
||||
# Initialize the SDEngine with configuration
|
||||
self.sd_url = CONFIG.get("SD_URL")
|
||||
self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negtive_prompt"] = negtive_prompt
|
||||
self.payload["negative_prompt"] = negtive_prompt
|
||||
self.payload["width"] = width
|
||||
self.payload["height"] = height
|
||||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
|
||||
def save(self, imgs, save_name=""):
|
||||
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
|
||||
def simple_run_t2i(self, payload: dict, auto_save: bool = True):
|
||||
with requests.Session() as session:
|
||||
logger.debug(self.sd_t2i_url)
|
||||
rsp = session.post(self.sd_t2i_url, json=payload, timeout=600)
|
||||
|
||||
results = rsp.json()["images"]
|
||||
if auto_save:
|
||||
save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6]
|
||||
self.save(results, save_name=f"output_{save_name}")
|
||||
return results
|
||||
|
||||
async def run_t2i(self, payloads: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
for payload_idx, payload in enumerate(prompts):
|
||||
for payload_idx, payload in enumerate(payloads):
|
||||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self._save(results, save_name=f"output_{payload_idx}")
|
||||
self.save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
|
@ -125,9 +137,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
|
||||
if __name__ == "__main__":
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
||||
prompt = "1girl, beautiful"
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(engine.run_t2i(prompt))
|
||||
|
||||
engine.simple_run_t2i(engine.payload)
|
||||
# event_loop = asyncio.get_event_loop()
|
||||
# event_loop.run_until_complete(engine.run_t2i([engine.payload]))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue