mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-23 15:48:11 +02:00
update
This commit is contained in:
parent
c28034ccbc
commit
8df7c2c02c
3 changed files with 246 additions and 20 deletions
|
|
@ -27,7 +27,7 @@ payload = {
|
|||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 20,
|
||||
"cfg_scale": 7,
|
||||
"cfg_scale": 9,
|
||||
"width": 512,
|
||||
"height": 768,
|
||||
"restore_faces": False,
|
||||
|
|
@ -62,52 +62,54 @@ class SDEngine:
|
|||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
**kwargs
|
||||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negtive_prompt"] = negtive_prompt
|
||||
self.payload["negative_prompt"] = negative_prompt
|
||||
self.payload["width"] = width
|
||||
self.payload["height"] = height
|
||||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
self.payload.update(**kwargs)
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
|
||||
if not os.path.exists(save_dir):
|
||||
if not save_dir.exists():
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
|
||||
async def run_t2i(self, prompts: List, save_name=""):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
for payload_idx, payload in enumerate(prompts):
|
||||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self._save(results, save_name=f"output_{payload_idx}")
|
||||
self._save(results, save_name=f"{save_name}_output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
|
@ -128,8 +130,8 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
if __name__ == "__main__":
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
||||
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(engine.run_t2i(prompt))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue