This commit is contained in:
stellahsr 2023-09-12 22:12:31 +08:00
parent c28034ccbc
commit 8df7c2c02c
3 changed files with 246 additions and 20 deletions

130
metagpt/actions/design.py Normal file
View file

@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
# @Date : 2023/8/17 13:43
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
# Standard library imports
from functools import wraps
from typing import Callable, Any, List, Optional, Tuple
# Local library imports
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.utils.common import OutputParser
from metagpt.prompts.sd_design import (
SD_PROMPT_KW_OPTIMIZE_TEMPLATE,
SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE,
FORMAT_INSTRUCTIONS,
PROMPT_OUTPUT_MAPPING
)
from metagpt.utils.resp_parse import flatten_json_structure, try_parse_json
# Default system primer; ``{answer_count}`` is substituted per request.
SYSTEM_PRIMER_TEMPLATE = (
    "Act like you are a terminal and always format your response as json. "
    "Always return exactly {answer_count} answers per question in English."
)
class Tool:
    """Bundle a callable with the name and description used to advertise it."""

    def __init__(self, name: str, func: Callable, description: str) -> None:
        """Store the tool's name, its callable and its description.

        Args:
            name: Identifier of the tool.
            func: Callable that implements the tool.
            description: Human-readable summary of what the tool does.
        """
        self.name, self.func, self.description = name, func, description
def system_primer_decorator(func):
    """Wrap an async action method so it receives a ``system_primer`` keyword.

    The primer is rendered from ``SYSTEM_PRIMER_TEMPLATE`` using the call's
    ``answer_count`` (1 when the caller supplies none), logged, and forwarded
    to ``func`` as the ``system_primer`` keyword argument.

    Args:
        func: Coroutine function that accepts a ``system_primer`` keyword.

    Returns:
        An async wrapper carrying the metadata of ``func``.
    """
    # Local import so the module's import block does not need to change.
    import inspect

    signature = inspect.signature(func)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            # Resolve arguments by parameter name so that a positionally
            # passed ``answer_count`` is honoured; ``kwargs.get`` alone would
            # silently fall back to 1 and contradict the prompt.
            bound = signature.bind_partial(*args, **kwargs).arguments
        except TypeError:
            # The call does not fit the signature; let ``func`` raise the
            # real error below instead of masking it here.
            bound = {}

        if "system_primer" in bound or "system_primer" in kwargs:
            # The caller supplied its own primer; injecting ours as well would
            # fail with "got multiple values for argument 'system_primer'".
            return await func(*args, **kwargs)

        answer_count = bound.get("answer_count", kwargs.get("answer_count", 1))
        system_primer = SYSTEM_PRIMER_TEMPLATE.format(answer_count=answer_count)
        logger.info(system_primer)
        return await func(*args, system_primer=system_primer, **kwargs)

    return wrapper
class BaseModelAction(Action):
    """Base action that queries the model via ``_aask`` and parses a JSON reply."""

    def __init__(self, name: str = "", description: str = "", *args, **kwargs):
        """Initialise the parent ``Action`` and record a description.

        Args:
            name: Action name, forwarded to ``Action.__init__``.
            description: Short label stored on ``self.desc``.
            *args: Extra positional arguments forwarded to ``Action.__init__``.
            **kwargs: Extra keyword arguments forwarded to ``Action.__init__``.
        """
        super().__init__(name, *args, **kwargs)
        self.desc = description

    async def handle_response(self, resp: str) -> Any:
        """Parse ``resp`` as JSON and flatten it.

        Args:
            resp: Raw text returned by the model.

        Returns:
            The flattened structure, or ``None`` when parsing, flattening or
            logging raised; the error is logged rather than propagated.
        """
        try:
            parsed = try_parse_json(resp)
            flattened = flatten_json_structure(parsed)
            logger.info(flattened)
        except Exception as exp:
            logger.error(f" JSON response {exp}")
            return None
        return flattened

    @system_primer_decorator
    async def run_optimize_or_improve(self, query: str, domain: str, template: str, answer_count: int = 1,
                                      system_primer=None) -> List[str]:
        """Fill ``template``, ask the model and return the parsed answers.

        Args:
            query: User prompt, substituted for ``messages`` in the template.
            domain: Style domain substituted into the template.
            template: Format string with ``messages``, ``domain`` and
                ``answer_count`` placeholders.
            answer_count: Number of answers requested from the model.
            system_primer: System message; supplied by the decorator.

        Returns:
            The parsed response, or ``[query]`` when the parsed response is
            empty or could not be parsed.
        """
        filled_prompt = template.format(messages=query, domain=domain, answer_count=answer_count)
        raw_reply: str = await self._aask(prompt=filled_prompt, system_msgs=[system_primer])
        answers = await self.handle_response(raw_reply)
        if answers:
            return answers
        # Fall back to the untouched query so callers always get something usable.
        return [query]
class SDPromptOptimize(BaseModelAction):
    """Optimize graphical prompts based on keywords.

    Expands a drawing prompt starting from the given keywords.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Create the action with the fixed description ``"PromptOptimize"``.

        Args:
            name: Action name.
            *args: Extra positional arguments forwarded to the base classes.
            **kwargs: Extra keyword arguments forwarded to the base classes.
        """
        # The description is passed positionally: combining ``description=...``
        # with a non-empty ``*args`` bound the first extra positional argument
        # to ``description`` as well and raised "got multiple values for
        # argument 'description'".
        super().__init__(name, "PromptOptimize", *args, **kwargs)

    async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]:
        """Run the keyword-based optimization for ``query``.

        Args:
            query: Prompt or keywords to expand.
            domain: Style domain substituted into the template.
            answer_count: Number of answers requested from the model.

        Returns:
            The result of ``run_optimize_or_improve`` for the keyword template.
        """
        return await self.run_optimize_or_improve(query, domain, SD_PROMPT_KW_OPTIMIZE_TEMPLATE,
                                                  answer_count=answer_count)
class SDPromptImprove(BaseModelAction):
    """Enhance the input prompt (recommended when the input prompt is long).

    FIXME: plug in the fine-tuned (FT) prompt-optimization model.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Create the action with the fixed description ``"PromptImprove"``.

        Args:
            name: Action name.
            *args: Extra positional arguments forwarded to the base classes.
            **kwargs: Extra keyword arguments forwarded to the base classes.
        """
        # Positional description: a ``description=`` keyword combined with a
        # non-empty ``*args`` raised "got multiple values for argument
        # 'description'".
        super().__init__(name, "PromptImprove", *args, **kwargs)

    async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]:
        """Run the improvement for ``query``.

        Args:
            query: Prompt to improve.
            domain: Style domain substituted into the template.
            answer_count: Number of answers requested from the model.

        Returns:
            The result of ``run_optimize_or_improve`` for the improve template.
        """
        return await self.run_optimize_or_improve(query, domain, SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE,
                                                  answer_count=answer_count)
class SDPromptExtend(BaseModelAction):
    """Action class to extend the prompt."""

    def __init__(self, name: str = "", tools: Optional[List[Tool]] = None, **kwargs):
        """Create the action and remember the tools it may advertise.

        Args:
            name: Action name.
            tools: Tools listed in the prompt; ``None`` means no tools.
            **kwargs: Extra keyword arguments forwarded to the base classes.
        """
        super().__init__(name, description="Prompt Extend", **kwargs)
        # A ``[]`` default would be one list shared by every instance, and an
        # explicit ``None`` would break iteration in ``_parse_tools``; give
        # each instance its own empty list instead.
        self.tools = tools if tools is not None else []
        logger.info(self.tools)

    def _parse_tools(self) -> Tuple[str, str]:
        """Render the configured tools for the prompt.

        Returns:
            A ``(tool_names, formatted_tools)`` pair: the names joined with
            ``", "`` and one ``"name: description"`` line per tool.
        """
        tool_strings = [f"{tool.name}: {tool.description}" for tool in self.tools]
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join(tool.name for tool in self.tools)
        return tool_names, formatted_tools

    async def run(self, query: str, answer_count: int = 1, domain: str = "realistic",
                  model_name="realisticVisionV30_v30VAE") -> str:
        """Extend the prompt and return the "Final Action" block of the reply.

        Args:
            query: Prompt to extend.
            answer_count: Accepted for interface parity; not used here.
            domain: Style domain substituted into the instructions.
            model_name: Model name substituted into the instructions.

        Returns:
            The value parsed under the ``"Final Action"`` key.
        """
        tool_names, formatted_tools = self._parse_tools()
        msg = FORMAT_INSTRUCTIONS.format(query=query, tool_names=tool_names,
                                         tool_description=formatted_tools,
                                         model_name=model_name, domain=domain)
        resp = await self._aask(msg)
        output_block = OutputParser.parse_data_with_mapping(resp, PROMPT_OUTPUT_MAPPING)
        return output_block["Final Action"]

View file

@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# @Date : 2023/8/17 13:43
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from typing import List, Union
from metagpt.tools.sd_engine import SDEngine
from metagpt.actions.design import BaseModelAction
from metagpt.prompts.sd_design import MODEL_SELECTION_PROMPT
class SDPromptRanker(BaseModelAction):
    """
    Class responsible for ranking multiple prompts based on current requirements and
    the underlying model to determine the most suitable prompt.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Create the action with the fixed description ``"Prompt ranker"``.

        Args:
            name: Action name.
            *args: Extra positional arguments forwarded to the base classes.
            **kwargs: Extra keyword arguments forwarded to the base classes.
        """
        # Positional description: ``description=...`` together with a
        # non-empty ``*args`` raised "got multiple values for argument
        # 'description'" because the first extra positional fills that slot.
        super().__init__(name, "Prompt ranker", *args, **kwargs)
class SDImgScorer(BaseModelAction):
    """
    Class responsible for aesthetically scoring multiple SD generated results and
    selecting the highest scoring image.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Create the action with the fixed description ``"Image Scorer"``.

        Args:
            name: Action name.
            *args: Extra positional arguments forwarded to the base classes.
            **kwargs: Extra keyword arguments forwarded to the base classes.
        """
        # Positional description: ``description=...`` together with a
        # non-empty ``*args`` raised "got multiple values for argument
        # 'description'" because the first extra positional fills that slot.
        super().__init__(name, "Image Scorer", *args, **kwargs)
class LoraSelection(BaseModelAction):
    """
    Class responsible for selecting the most suitable Lora based on the
    current model and requirements.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Forward all arguments to the base class unchanged.

        Args:
            name: Action name.
            *args: Extra positional arguments forwarded to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        # NOTE(review): no description is passed here, so the first extra
        # positional argument (if any) is taken by the base class as its
        # second parameter -- confirm whether a fixed description is intended.
        super().__init__(name, *args, **kwargs)
class ModelSelection(BaseModelAction):
    """Ask the model which SD checkpoint and style domain suit a query."""

    # Checkpoint name -> short description offered to the model as choices.
    DEFAULT_MODEL_INFO = {
        "realisticVisionV30_v30VAE": "Real Effects, Real Photo/Photography, v3.0",
        "pixelmix_v10": "an anime model merge with finetuned lineart and eyes."
    }

    def __init__(self, name="ModelSelection", *args, **kwargs):
        """Create the action with the fixed description ``"Select models"``.

        Args:
            name: Action name.
            *args: Extra positional arguments forwarded to the base classes.
            **kwargs: Extra keyword arguments forwarded to the base classes.
        """
        # Positional description: ``description=...`` together with a
        # non-empty ``*args`` raised "got multiple values for argument
        # 'description'" because the first extra positional fills that slot.
        super().__init__(name, "Select models", *args, **kwargs)

    def add_models(self, model_name="", model_desc=""):
        """Return the default model table, optionally extended by one entry.

        Args:
            model_name: Name of an extra model; ignored when empty.
            model_desc: Description of the extra model.

        Returns:
            A new dict; ``DEFAULT_MODEL_INFO`` itself is not modified.
        """
        updated_info = {model_name: model_desc} if model_name else {}
        return {**self.DEFAULT_MODEL_INFO, **updated_info}

    async def run(self, query: str, system_text: str = "model selection"):
        """Ask the model to choose a checkpoint and domain for ``query``.

        The reply is split on ``"||"``; the first part with ``"Model:"``
        removed is the model name, and the last part with ``"Domain:"``
        removed is the domain.

        Args:
            query: User requirement to match a model against.
            system_text: System message sent along with the prompt.

        Returns:
            A ``(model_name, domain)`` tuple of stripped strings.
        """
        prompt = MODEL_SELECTION_PROMPT.format(query=query, model_info=self.add_models())
        resp = await self._aask(prompt=prompt, system_msgs=[system_text])
        result = resp.split("||")
        # NOTE(review): when the reply has no "||", both values are derived
        # from the same text -- confirm whether that should be an error.
        model_name = result[0].replace("Model:", "").strip()
        domain = result[-1].replace("Domain:", "").strip()
        return model_name, domain
class SDGeneration(BaseModelAction):
    """Generates an image via the sd t2i API."""

    def __init__(self, name: str = "", *args, **kwargs):
        """Create the action, its ``SDEngine`` and the per-model negative prompts.

        Args:
            name: Action name.
            *args: Extra positional arguments forwarded to the base classes.
            **kwargs: Extra keyword arguments forwarded to the base classes.
        """
        # Positional description: ``description=...`` together with a
        # non-empty ``*args`` raised "got multiple values for argument
        # 'description'" because the first extra positional fills that slot.
        super().__init__(name, "Stable Diffusion Generator", *args, **kwargs)
        self.engine = SDEngine()
        # Model name -> negative prompt sent with every request for that model.
        self.negative_prompts = {"realisticVisionV30_v30VAE": "worst quality, low quality, easynegative",
                                 "pixelmix_v10": ""}

    def _construct_prompt(self, query: str, model_name: str) -> dict:
        """Build an independent request payload for one query.

        Args:
            query: Positive prompt text.
            model_name: Checkpoint name; also selects the negative prompt
                (empty string for unknown models).

        Returns:
            A payload dict that is not shared with the engine or with other
            payloads.
        """
        # Local import so the module's import block does not need to change.
        import copy

        negative_prompt = self.negative_prompts.get(model_name, "")
        payload = self.engine.construct_payload(query, negative_prompt=negative_prompt, sd_model=model_name)
        # The engine hands back its single internal payload dict and mutates
        # it on every call; without a copy, a batch of N queries would hold N
        # references to the last query's payload.
        return copy.deepcopy(payload)

    async def _generate_image(self, queries: List[str], model_name: str, img_name: str) -> None:
        """Generate image(s) for ``queries`` with ``model_name``.

        Args:
            queries: Positive prompts, one request each.
            model_name: Checkpoint used for every request.
            img_name: Name passed to the engine as ``save_name``.
        """
        prompts = [self._construct_prompt(query, model_name) for query in queries]
        await self.engine.run_t2i(prompts, save_name=img_name)

    async def run(self, query: Union[str, List[str]], model_name: str, **kwargs) -> None:
        """Generate image(s) via the sd t2i API.

        Args:
            query: A single prompt or a list of prompts.
            model_name: Checkpoint to generate with.
            **kwargs: ``image_name`` (default ``""``) is used as the save name.
        """
        img_name = kwargs.get("image_name", "")
        queries = [query] if isinstance(query, str) else query
        await self._generate_image(queries, model_name, img_name)

View file

@ -27,7 +27,7 @@ payload = {
"batch_size": 1,
"n_iter": 1,
"steps": 20,
"cfg_scale": 7,
"cfg_scale": 9,
"width": 512,
"height": 768,
"restore_faces": False,
@ -62,52 +62,54 @@ class SDEngine:
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(
    self,
    prompt,
    negative_prompt=default_negative_prompt,
    width=512,
    height=512,
    sd_model="galaxytimemachinesGTM_photoV20",
    **kwargs
):
    """Configure the engine payload and return a snapshot of it.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        width: Image width stored in the payload.
        height: Image height stored in the payload.
        sd_model: Checkpoint name written to
            ``override_settings.sd_model_checkpoint``.
        **kwargs: Additional top-level payload entries; they override the
            values set above.

    Returns:
        A deep copy of the updated ``self.payload``.
    """
    # Local import so the module's import block does not need to change.
    import copy

    # Configure the payload with provided inputs
    self.payload["prompt"] = prompt
    self.payload["negative_prompt"] = negative_prompt
    self.payload["width"] = width
    self.payload["height"] = height
    self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
    self.payload.update(**kwargs)
    logger.info(f"call sd payload is {self.payload}")
    # ``self.payload`` is mutated again on the next call; returning the dict
    # itself would make every previously returned payload change with it.
    return copy.deepcopy(self.payload)
def _save(self, imgs, save_name=""):
    """Decode ``imgs`` and write them under the workspace's SD output folder.

    Args:
        imgs: Base64-encoded images as returned by the SD API.
        save_name: Name passed through to ``batch_decode_base64_to_image``.
    """
    save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
    # ``exist_ok=True`` already makes creation idempotent, so a separate
    # existence check is redundant.
    os.makedirs(save_dir, exist_ok=True)
    batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
async def run_t2i(self, prompts: List, save_name=""):
    """Run the text-to-image API once per payload and save each result.

    Args:
        prompts: Request payloads, sent one after another.
        save_name: Optional prefix for the saved files. Files are named
            ``{save_name}_output_{index}``, or ``output_{index}`` when the
            prefix is empty.
    """
    # The context manager closes the session even when a request raises;
    # a manual ``close()`` after the loop would be skipped in that case.
    async with ClientSession() as session:
        for payload_idx, payload in enumerate(prompts):
            results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
            # Avoid a dangling leading underscore when no prefix is given.
            prefix = f"{save_name}_" if save_name else ""
            self._save(results, save_name=f"{prefix}output_{payload_idx}")
async def run(self, url, payload, session):
    """POST ``payload`` as JSON to ``url`` and return the ``"images"`` entry.

    Args:
        url: Endpoint of the SD API.
        payload: JSON-serialisable request body.
        session: Client session used to issue the request.

    Returns:
        The value stored under ``"images"`` in the decoded response.
    """
    async with session.post(url, json=payload, timeout=600) as response:
        raw_body = await response.read()
    decoded = json.loads(raw_body)
    images = decoded["images"]
    logger.info(f"callback rsp json is {decoded.keys()}")
    return images
async def run_i2i(self):
    """Placeholder for image-to-image generation; always raises ``NotImplementedError``."""
    # TODO: add the image-to-image API call
    raise NotImplementedError
async def run_sam(self):
    """Placeholder for the SAM call; always raises ``NotImplementedError``."""
    # TODO: add the SAM API call
    raise NotImplementedError
@ -128,8 +130,8 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
if __name__ == "__main__":
    # Manual smoke test: generate one image for a fixed prompt.
    engine = SDEngine()
    prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
    # ``run_t2i`` iterates over a list of payloads; handing it the raw prompt
    # string would send one request per character and drop the payload that
    # was just constructed.
    payload = engine.construct_payload(prompt)
    asyncio.run(engine.run_t2i([payload]))