diff --git a/metagpt/actions/design.py b/metagpt/actions/design.py new file mode 100644 index 000000000..968640e64 --- /dev/null +++ b/metagpt/actions/design.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# @Date : 2023/8/17 13:43 +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +# Standard library imports +from functools import wraps +from typing import Callable, Any, List, Optional, Tuple + +# Local library imports +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.utils.common import OutputParser +from metagpt.prompts.sd_design import ( + SD_PROMPT_KW_OPTIMIZE_TEMPLATE, + SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE, + FORMAT_INSTRUCTIONS, + PROMPT_OUTPUT_MAPPING +) +from metagpt.utils.resp_parse import flatten_json_structure, try_parse_json + +# A default template for the system primer. +SYSTEM_PRIMER_TEMPLATE = "Act like you are a terminal and always format your response as json. Always return exactly {answer_count} answers per question in English." + + +class Tool: + """Define a tool with its name, function and description.""" + + def __init__(self, name: str, func: Callable, description: str) -> None: + """Initialize tool.""" + self.name = name + self.func = func + self.description = description + + +# Decorator for the BaseModelAction to wrap it with system primer details +def system_primer_decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + system_primer = SYSTEM_PRIMER_TEMPLATE.format(answer_count=kwargs.get('answer_count', 1)) + logger.info(system_primer) + return await func(*args, system_primer=system_primer, **kwargs) + + return wrapper + + +class BaseModelAction(Action): + + def __init__(self, name: str = "", description: str = "", *args, **kwargs): + super().__init__(name, *args, **kwargs) + self.desc = description + + async def handle_response(self, resp: str) -> Any: + """Handle JSON response and extract value.""" + try: + resp_json = flatten_json_structure(try_parse_json(resp)) + logger.info(resp_json) + return resp_json + + except Exception as exp: + logger.error(f" JSON response {exp}") + return None + + @system_primer_decorator + async def run_optimize_or_improve(self, query: str, domain: str, template: str, answer_count: int = 1, + system_primer=None) -> List[str]: + """Run optimization or improvement based on the given template.""" + prompt = template.format(messages=query, domain=domain, answer_count=answer_count) + resp: str = await self._aask(prompt=prompt, system_msgs=[system_primer]) + result = await self.handle_response(resp) + return result or [query] + + +class SDPromptOptimize(BaseModelAction): + """ + Optimize graphical prompts based on keywords. + 扩充画图的提示词,根据keyword + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="PromptOptimize", *args, **kwargs) + + async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]: + """Run the optimization for the given query.""" + return await self.run_optimize_or_improve(query, domain, SD_PROMPT_KW_OPTIMIZE_TEMPLATE, + answer_count=answer_count) + + +class SDPromptImprove(BaseModelAction): + """Enhance the input prompt (recommended when the input prompt is long). + fixme: 接入提示词优化的FT模型 + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="PromptImprove", *args, **kwargs) + + async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]: + """Run the improvement for the given query.""" + return await self.run_optimize_or_improve(query, domain, SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE, + answer_count=answer_count) + + +class SDPromptExtend(BaseModelAction): + """Action class to extend the prompt.""" + + def __init__(self, name: str = "", tools: Optional[List[Tool]] = [], **kwargs): + super().__init__(name, description="Prompt Extend", **kwargs) + self.tools = tools + logger.info(self.tools) + + def _parse_tools(self) -> Tuple[str, str]: + """Parse tool names and descriptions.""" + tool_strings = [f"{tool.name}: {tool.description}" for tool in self.tools] + formatted_tools = "\n".join(tool_strings) + tool_names = ", ".join(tool.name for tool in self.tools) + return tool_names, formatted_tools + + async def run(self, query: str, answer_count: int = 1, domain: str = "realistic", + model_name="realisticVisionV30_v30VAE") -> str: + """Extend the prompt and get the "Final Action" from the output.""" + tool_names, formatted_tools = self._parse_tools() + msg = FORMAT_INSTRUCTIONS.format(query=query, tool_names=tool_names, + tool_description=formatted_tools, + model_name=model_name, domain=domain) + + resp = await self._aask(msg) + output_block = OutputParser.parse_data_with_mapping(resp, PROMPT_OUTPUT_MAPPING) + return output_block["Final Action"] + + diff --git a/metagpt/actions/ui_design.py b/metagpt/actions/ui_design.py new file mode 100644 index 000000000..693ca904d --- /dev/null +++ b/metagpt/actions/ui_design.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# @Date : 2023/8/17 13:43 +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +from typing import List, Union + +from metagpt.tools.sd_engine import SDEngine + +from metagpt.actions.design import BaseModelAction +from metagpt.prompts.sd_design import MODEL_SELECTION_PROMPT + + +class SDPromptRanker(BaseModelAction): + """ + Class responsible for ranking multiple prompts based on current requirements and + the underlying model to determine the most suitable prompt. + + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="Prompt ranker", *args, **kwargs) + + +class SDImgScorer(BaseModelAction): + """ + 根据多个SD的生成结果,进行美学评分,选出评分最高的图片 + Class responsible for aesthetically scoring multiple SD generated results and + selecting the highest scoring image. + + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="Image Scorer", *args, **kwargs) + + +class LoraSelection(BaseModelAction): + """ + Class responsible for selecting the most suitable Lora based on the + current model and requirements. + """ + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, *args, **kwargs) + + +class ModelSelection(BaseModelAction): + DEFAULT_MODEL_INFO = { + "realisticVisionV30_v30VAE": "Real Effects, Real Photo/Photography, v3.0", + "pixelmix_v10": "an anime model merge with finetuned lineart and eyes." + } + + def __init__(self, name="ModelSelection", *args, **kwargs): + super().__init__(name, description="Select models", *args, **kwargs) + + def add_models(self, model_name="", model_desc=""): + updated_info = {model_name: model_desc} if model_name else {} + return {**self.DEFAULT_MODEL_INFO, **updated_info} + + async def run(self, query: str, system_text: str = "model selection"): + prompt = MODEL_SELECTION_PROMPT.format(query=query, model_info=self.add_models()) + resp = await self._aask(prompt=prompt, system_msgs=[system_text]) + result = resp.split("||") + model_name = result[0].replace("Model:", "").strip() + domain = result[-1].replace("Domain:", "").strip() + return model_name, domain + + +class SDGeneration(BaseModelAction): + """Generates an image via the sd t2i API.""" + + def __init__(self, name: str = "", *args, **kwargs): + super().__init__(name, description="Stable Diffusion Generator", *args, **kwargs) + self.engine = SDEngine() + self.negative_prompts = {"realisticVisionV30_v30VAE": "worst quality, low quality, easynegative", + "pixelmix_v10": ""} + + def _construct_prompt(self, query: str, model_name: str) -> str: + """Constructs a prompt for the provided query and model.""" + negative_prompt = self.negative_prompts.get(model_name, "") + return self.engine.construct_payload(query, negative_prompt=negative_prompt, sd_model=model_name) + + async def _generate_image(self, queries: List[str], model_name: str, img_name: str) -> None: + """Generates image(s) using the provided queries and model name.""" + prompts = [self._construct_prompt(query, model_name) for query in queries] + await self.engine.run_t2i(prompts, save_name=img_name) + + async def run(self, query: Union[str, List[str]], model_name: str, **kwargs) -> None: + """ + Generate image via sd t2i API. + """ + img_name = kwargs.get("image_name", "") + + queries = [query] if isinstance(query, str) else query + await self._generate_image(queries, model_name, img_name) diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index 1d9cd0b2a..2e34be9de 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -27,7 +27,7 @@ payload = { "batch_size": 1, "n_iter": 1, "steps": 20, - "cfg_scale": 7, + "cfg_scale": 9, "width": 512, "height": 768, "restore_faces": False, @@ -62,52 +62,54 @@ class SDEngine: # Define default payload settings for SD API self.payload = payload logger.info(self.sd_t2i_url) - + def construct_payload( - self, - prompt, - negtive_prompt=default_negative_prompt, - width=512, - height=512, - sd_model="galaxytimemachinesGTM_photoV20", + self, + prompt, + negative_prompt=default_negative_prompt, + width=512, + height=512, + sd_model="galaxytimemachinesGTM_photoV20", + **kwargs ): # Configure the payload with provided inputs self.payload["prompt"] = prompt - self.payload["negtive_prompt"] = negtive_prompt + self.payload["negative_prompt"] = negative_prompt self.payload["width"] = width self.payload["height"] = height self.payload["override_settings"]["sd_model_checkpoint"] = sd_model + self.payload.update(**kwargs) logger.info(f"call sd payload is {self.payload}") return self.payload - + def _save(self, imgs, save_name=""): save_dir = WORKSPACE_ROOT / "resources" / "SD_Output" - if not os.path.exists(save_dir): + if not save_dir.exists(): os.makedirs(save_dir, exist_ok=True) batch_decode_base64_to_image(imgs, save_dir, save_name=save_name) - - async def run_t2i(self, prompts: List): + + async def run_t2i(self, prompts: List, save_name=""): # Asynchronously run the SD API for multiple prompts session = ClientSession() for payload_idx, payload in enumerate(prompts): results = await self.run(url=self.sd_t2i_url, payload=payload, session=session) - self._save(results, save_name=f"output_{payload_idx}") + self._save(results, save_name=f"{save_name}_output_{payload_idx}") await session.close() - + async def run(self, url, payload, session): # Perform the HTTP POST request to the SD API async with session.post(url, json=payload, timeout=600) as rsp: data = await rsp.read() - + rsp_json = json.loads(data) imgs = rsp_json["images"] logger.info(f"callback rsp json is {rsp_json.keys()}") return imgs - + async def run_i2i(self): # todo: 添加图生图接口调用 raise NotImplementedError - + async def run_sam(self): # todo:添加SAM接口调用 raise NotImplementedError @@ -128,8 +130,8 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): if __name__ == "__main__": engine = SDEngine() prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary" - + engine.construct_payload(prompt) - + event_loop = asyncio.get_event_loop() event_loop.run_until_complete(engine.run_t2i(prompt))