mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
update
This commit is contained in:
parent
c28034ccbc
commit
8df7c2c02c
3 changed files with 246 additions and 20 deletions
130
metagpt/actions/design.py
Normal file
130
metagpt/actions/design.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/8/17 13:43
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
# Standard library imports
|
||||
from functools import wraps
|
||||
from typing import Callable, Any, List, Optional, Tuple
|
||||
|
||||
# Local library imports
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.prompts.sd_design import (
|
||||
SD_PROMPT_KW_OPTIMIZE_TEMPLATE,
|
||||
SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE,
|
||||
FORMAT_INSTRUCTIONS,
|
||||
PROMPT_OUTPUT_MAPPING
|
||||
)
|
||||
from metagpt.utils.resp_parse import flatten_json_structure, try_parse_json
|
||||
|
||||
# Default system-primer template. The `{answer_count}` placeholder is filled in
# by `system_primer_decorator` before the primer is handed to the wrapped call.
SYSTEM_PRIMER_TEMPLATE = "Act like you are a terminal and always format your response as json. Always return exactly {answer_count} answers per question in English."
|
||||
|
||||
|
||||
class Tool:
    """Describe a tool by its name, the callable that implements it, and a description."""

    def __init__(self, name: str, func: Callable, description: str) -> None:
        """Store the tool's attributes.

        Args:
            name: Identifier of the tool.
            func: Callable associated with the tool.
            description: Human-readable explanation of the tool.
        """
        self.name = name
        self.func = func
        self.description = description

    def __repr__(self) -> str:
        """Return a readable representation so logged tool lists are meaningful."""
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"func={self.func!r}, description={self.description!r})"
        )
|
||||
|
||||
|
||||
# Decorator for BaseModelAction coroutines: builds the system primer and injects it.
def system_primer_decorator(func):
    """Wrap an async callable so that it receives a ``system_primer`` keyword.

    The primer is ``SYSTEM_PRIMER_TEMPLATE`` formatted with the call's
    ``answer_count``. The count is resolved whether it was passed by keyword
    or positionally, and falls back to 1 when it was not passed at all.
    If the caller passes a non-None ``system_primer`` keyword, that value is
    forwarded unchanged instead of building a new one.

    Args:
        func: The coroutine function to wrap; it must accept ``system_primer``.

    Returns:
        The wrapping coroutine function.
    """
    import inspect  # local import: only needed to resolve positional arguments

    signature = inspect.signature(func)

    @wraps(func)
    async def wrapper(*args, **kwargs):
        # A caller-provided primer would otherwise collide with the injected
        # keyword and raise "got multiple values for keyword argument".
        explicit_primer = kwargs.pop("system_primer", None)
        if explicit_primer is not None:
            return await func(*args, system_primer=explicit_primer, **kwargs)

        # Bind the call against the wrapped signature so that a positional
        # answer_count is seen too; kwargs alone would miss it.
        try:
            bound_arguments = signature.bind_partial(*args, **kwargs).arguments
        except TypeError:
            # Let the real call below report the malformed invocation.
            bound_arguments = {}
        answer_count = bound_arguments.get("answer_count", kwargs.get("answer_count", 1))

        system_primer = SYSTEM_PRIMER_TEMPLATE.format(answer_count=answer_count)
        logger.info(system_primer)
        return await func(*args, system_primer=system_primer, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
class BaseModelAction(Action):
    """Shared base for the prompt actions: asks the LLM and parses a JSON reply."""

    def __init__(self, name: str = "", description: str = "", *args, **kwargs):
        """Pass ``name`` and any extra arguments to ``Action``; keep ``description`` as ``self.desc``."""
        super().__init__(name, *args, **kwargs)
        self.desc = description

    async def handle_response(self, resp: str) -> Any:
        """Parse ``resp`` as JSON and flatten it.

        Returns:
            The flattened structure, or ``None`` if parsing, flattening or
            logging raised (the error is logged).
        """
        try:
            parsed = try_parse_json(resp)
            flattened = flatten_json_structure(parsed)
            logger.info(flattened)
        except Exception as exp:
            logger.error(f" JSON response {exp}")
            return None
        return flattened

    @system_primer_decorator
    async def run_optimize_or_improve(self, query: str, domain: str, template: str, answer_count: int = 1,
                                      system_primer=None) -> List[str]:
        """Fill ``template``, query the LLM and return the parsed answers.

        Args:
            query: Text inserted into the template's ``messages`` field.
            domain: Value for the template's ``domain`` field.
            template: Format string with ``messages``, ``domain`` and ``answer_count`` fields.
            answer_count: Number of answers requested from the model.
            system_primer: System message; supplied by ``system_primer_decorator``.

        Returns:
            The parsed response, or ``[query]`` when the parsed response is empty or ``None``.
        """
        filled_prompt = template.format(messages=query, domain=domain, answer_count=answer_count)
        raw_reply: str = await self._aask(prompt=filled_prompt, system_msgs=[system_primer])
        parsed = await self.handle_response(raw_reply)
        if parsed:
            return parsed
        return [query]
|
||||
|
||||
|
||||
class SDPromptOptimize(BaseModelAction):
    """
    Optimize graphical prompts based on keywords.
    Expands the drawing prompt according to the given keywords.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Initialize the action with the fixed description "PromptOptimize"."""
        # The description is passed positionally: combining description=... with
        # a non-empty *args makes the first extra positional argument land in
        # the same parameter and raises TypeError (multiple values).
        super().__init__(name, "PromptOptimize", *args, **kwargs)

    async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]:
        """Run the keyword-based optimization for ``query``.

        Args:
            query: The prompt to optimize.
            domain: Style domain inserted into the template.
            answer_count: Number of answers requested.

        Returns:
            The result of ``run_optimize_or_improve`` for the keyword template.
        """
        return await self.run_optimize_or_improve(query, domain, SD_PROMPT_KW_OPTIMIZE_TEMPLATE,
                                                  answer_count=answer_count)
|
||||
|
||||
|
||||
class SDPromptImprove(BaseModelAction):
    """Enhance the input prompt (recommended when the input prompt is long).

    FIXME: hook up the fine-tuned prompt-optimization model.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Initialize the action with the fixed description "PromptImprove"."""
        # The description is passed positionally: combining description=... with
        # a non-empty *args makes the first extra positional argument land in
        # the same parameter and raises TypeError (multiple values).
        super().__init__(name, "PromptImprove", *args, **kwargs)

    async def run(self, query: str, domain: str = "realistic", answer_count: int = 1) -> List[str]:
        """Run the improvement for ``query``.

        Args:
            query: The prompt to improve.
            domain: Style domain inserted into the template.
            answer_count: Number of answers requested.

        Returns:
            The result of ``run_optimize_or_improve`` for the improve template.
        """
        return await self.run_optimize_or_improve(query, domain, SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE,
                                                  answer_count=answer_count)
|
||||
|
||||
|
||||
class SDPromptExtend(BaseModelAction):
    """Action class to extend the prompt."""

    def __init__(self, name: str = "", tools: Optional[List[Tool]] = None, **kwargs):
        """Initialize the action.

        Args:
            name: Action name.
            tools: Tools offered to the model; ``None`` means no tools.
            **kwargs: Forwarded to ``BaseModelAction``.
        """
        super().__init__(name, description="Prompt Extend", **kwargs)
        # A fresh list per instance: a mutable default would be shared by every
        # instance, and None would break the iteration in _parse_tools.
        self.tools = tools if tools is not None else []
        logger.info(self.tools)

    def _parse_tools(self) -> Tuple[str, str]:
        """Build the tool listing used in the prompt.

        Returns:
            A ``(tool_names, formatted_tools)`` pair: names joined by ", " and
            one "name: description" line per tool joined by newlines.
        """
        tool_strings = [f"{tool.name}: {tool.description}" for tool in self.tools]
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join(tool.name for tool in self.tools)
        return tool_names, formatted_tools

    async def run(self, query: str, answer_count: int = 1, domain: str = "realistic",
                  model_name="realisticVisionV30_v30VAE") -> str:
        """Extend the prompt and return the "Final Action" entry of the parsed output.

        Args:
            query: The prompt to extend.
            answer_count: Accepted for interface parity; not used by this method.
            domain: Style domain inserted into the instructions.
            model_name: Model name inserted into the instructions.
        """
        tool_names, formatted_tools = self._parse_tools()
        msg = FORMAT_INSTRUCTIONS.format(query=query, tool_names=tool_names,
                                         tool_description=formatted_tools,
                                         model_name=model_name, domain=domain)

        resp = await self._aask(msg)
        output_block = OutputParser.parse_data_with_mapping(resp, PROMPT_OUTPUT_MAPPING)
        return output_block["Final Action"]
|
||||
|
||||
|
||||
94
metagpt/actions/ui_design.py
Normal file
94
metagpt/actions/ui_design.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/8/17 13:43
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from typing import List, Union
|
||||
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
from metagpt.actions.design import BaseModelAction
|
||||
from metagpt.prompts.sd_design import MODEL_SELECTION_PROMPT
|
||||
|
||||
|
||||
class SDPromptRanker(BaseModelAction):
    """
    Class responsible for ranking multiple prompts based on current requirements and
    the underlying model to determine the most suitable prompt.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Initialize the action with the fixed description "Prompt ranker"."""
        # The description is passed positionally: combining description=... with
        # a non-empty *args makes the first extra positional argument land in
        # the same parameter and raises TypeError (multiple values).
        super().__init__(name, "Prompt ranker", *args, **kwargs)
|
||||
|
||||
|
||||
class SDImgScorer(BaseModelAction):
    """
    Class responsible for aesthetically scoring multiple SD generated results and
    selecting the highest scoring image.
    """

    def __init__(self, name: str = "", *args, **kwargs):
        """Initialize the action with the fixed description "Image Scorer"."""
        # The description is passed positionally: combining description=... with
        # a non-empty *args makes the first extra positional argument land in
        # the same parameter and raises TypeError (multiple values).
        super().__init__(name, "Image Scorer", *args, **kwargs)
|
||||
|
||||
|
||||
class LoraSelection(BaseModelAction):
    """Pick the Lora that best fits the current model and requirements."""

    def __init__(self, name: str = "", *args, **kwargs):
        """Delegate construction to ``BaseModelAction`` without adding any state."""
        super().__init__(name, *args, **kwargs)
|
||||
|
||||
|
||||
class ModelSelection(BaseModelAction):
    """Ask the LLM which Stable Diffusion model and domain fit a query."""

    # Built-in catalogue: model name -> short description shown to the LLM.
    DEFAULT_MODEL_INFO = {
        "realisticVisionV30_v30VAE": "Real Effects, Real Photo/Photography, v3.0",
        "pixelmix_v10": "an anime model merge with finetuned lineart and eyes."
    }

    def __init__(self, name="ModelSelection", *args, **kwargs):
        """Initialize the action with the fixed description "Select models"."""
        # The description is passed positionally: combining description=... with
        # a non-empty *args makes the first extra positional argument land in
        # the same parameter and raises TypeError (multiple values).
        super().__init__(name, "Select models", *args, **kwargs)

    def add_models(self, model_name="", model_desc=""):
        """Return the default catalogue merged with one optional extra model.

        A new dict is returned; ``DEFAULT_MODEL_INFO`` itself is not modified.
        An empty ``model_name`` adds nothing.
        """
        updated_info = {model_name: model_desc} if model_name else {}
        return {**self.DEFAULT_MODEL_INFO, **updated_info}

    async def run(self, query: str, system_text: str = "model selection"):
        """Query the LLM and parse its reply into ``(model_name, domain)``.

        The reply is split on "||". The first segment, with "Model:" removed,
        is the model name; the last segment, with "Domain:" removed, is the domain.
        """
        prompt = MODEL_SELECTION_PROMPT.format(query=query, model_info=self.add_models())
        resp = await self._aask(prompt=prompt, system_msgs=[system_text])
        result = resp.split("||")
        model_name = result[0].replace("Model:", "").strip()
        domain = result[-1].replace("Domain:", "").strip()
        return model_name, domain
|
||||
|
||||
|
||||
class SDGeneration(BaseModelAction):
    """Generates an image via the sd t2i API."""

    def __init__(self, name: str = "", *args, **kwargs):
        """Create the SD engine and the per-model negative prompts."""
        # The description is passed positionally: combining description=... with
        # a non-empty *args makes the first extra positional argument land in
        # the same parameter and raises TypeError (multiple values).
        super().__init__(name, "Stable Diffusion Generator", *args, **kwargs)
        self.engine = SDEngine()
        # Negative prompt to use for each known model name.
        self.negative_prompts = {"realisticVisionV30_v30VAE": "worst quality, low quality, easynegative",
                                 "pixelmix_v10": ""}

    def _construct_prompt(self, query: str, model_name: str) -> dict:
        """Build an independent request payload for ``query`` and ``model_name``.

        Models without an entry in ``self.negative_prompts`` get an empty
        negative prompt.
        """
        import copy  # local import keeps this block self-contained

        negative_prompt = self.negative_prompts.get(model_name, "")
        payload = self.engine.construct_payload(query, negative_prompt=negative_prompt, sd_model=model_name)
        # The engine mutates and returns one shared dict; without a snapshot
        # every payload built for a batch would alias it and all requests
        # would carry the last query. Deep copy because the dict is nested.
        return copy.deepcopy(payload)

    async def _generate_image(self, queries: List[str], model_name: str, img_name: str) -> None:
        """Build one payload per query and hand them to the engine's t2i runner."""
        prompts = [self._construct_prompt(query, model_name) for query in queries]
        await self.engine.run_t2i(prompts, save_name=img_name)

    async def run(self, query: Union[str, List[str]], model_name: str, **kwargs) -> None:
        """
        Generate image via sd t2i API.

        Args:
            query: A single prompt or a list of prompts.
            model_name: SD model to generate with.
            **kwargs: ``image_name`` (default "") is used as the save name.
        """
        img_name = kwargs.get("image_name", "")

        queries = [query] if isinstance(query, str) else query
        await self._generate_image(queries, model_name, img_name)
|
||||
|
|
@ -27,7 +27,7 @@ payload = {
|
|||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 20,
|
||||
"cfg_scale": 7,
|
||||
"cfg_scale": 9,
|
||||
"width": 512,
|
||||
"height": 768,
|
||||
"restore_faces": False,
|
||||
|
|
@ -62,52 +62,54 @@ class SDEngine:
|
|||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
|
||||
def construct_payload(
    self,
    prompt,
    negative_prompt=default_negative_prompt,
    width=512,
    height=512,
    sd_model="galaxytimemachinesGTM_photoV20",
    **kwargs
):
    """Configure the engine payload and return a snapshot of it.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        width: Image width stored in the payload.
        height: Image height stored in the payload.
        sd_model: Checkpoint name written to ``override_settings``.
        **kwargs: Extra payload entries; they overwrite existing keys.

    Returns:
        A deep copy of the configured payload, so payloads built by
        successive calls do not alias one another.
    """
    import copy  # local import keeps this block self-contained

    # Configure the payload with provided inputs
    self.payload["prompt"] = prompt
    self.payload["negative_prompt"] = negative_prompt
    self.payload["width"] = width
    self.payload["height"] = height
    self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
    self.payload.update(**kwargs)
    logger.info(f"call sd payload is {self.payload}")
    # self.payload is reused on every call; hand out a copy so a caller that
    # collects several payloads does not end up with the same dict repeated.
    return copy.deepcopy(self.payload)
|
||||
|
||||
|
||||
def _save(self, imgs, save_name=""):
    """Decode ``imgs`` into the ``resources/SD_Output`` directory under the workspace root.

    Args:
        imgs: Images passed on to ``batch_decode_base64_to_image``.
        save_name: Save name passed on to ``batch_decode_base64_to_image``.
    """
    save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
    # exist_ok=True already covers the "directory exists" case, so no separate
    # existence check is needed.
    os.makedirs(save_dir, exist_ok=True)
    batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
|
||||
|
||||
async def run_t2i(self, prompts: List, save_name=""):
    """Run the SD text-to-image API once per payload and save each result.

    Args:
        prompts: Payloads to post, one request each, processed in order.
        save_name: Prefix for saved outputs; files are named
            ``{save_name}_output_{index}``.
    """
    session = ClientSession()
    try:
        for payload_idx, payload in enumerate(prompts):
            results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
            self._save(results, save_name=f"{save_name}_output_{payload_idx}")
    finally:
        # Close the session even when a request or save fails, so the
        # connection is not leaked.
        await session.close()
|
||||
|
||||
|
||||
async def run(self, url, payload, session):
    """POST ``payload`` as JSON to ``url`` and return the reply's "images" entry.

    The request uses a timeout of 600 and the response body is decoded with
    ``json.loads``.
    """
    async with session.post(url, json=payload, timeout=600) as rsp:
        raw_body = await rsp.read()

    decoded = json.loads(raw_body)
    images = decoded["images"]
    logger.info(f"callback rsp json is {decoded.keys()}")
    return images
|
||||
|
||||
|
||||
async def run_i2i(self):
    """Placeholder for image-to-image generation; always raises NotImplementedError."""
    # TODO: add the image-to-image API call
    raise NotImplementedError
|
||||
|
||||
|
||||
async def run_sam(self):
    """Placeholder for the SAM integration; always raises NotImplementedError."""
    # TODO: add the SAM API call
    raise NotImplementedError
|
||||
|
|
@ -128,8 +130,8 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
if __name__ == "__main__":
    # Manual smoke test: generate one image for a sample prompt.
    engine = SDEngine()
    prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"

    # run_t2i iterates over its argument and posts each item as a request body,
    # so it needs a list of payloads; a bare prompt string would be iterated
    # character by character.
    payload = engine.construct_payload(prompt)

    asyncio.run(engine.run_t2i([payload]))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue