update code

This commit is contained in:
stellahsr 2023-09-12 22:14:30 +08:00
parent 8df7c2c02c
commit 74dc79a3dd
8 changed files with 629 additions and 12 deletions

View file

@ -0,0 +1,101 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/18 09:51
@Author : stellahong
@File : __init__.py
"""
# Prompt asking the LLM to pick a painting model for a scene.
# Placeholders: {model_info} (the model list) and {query} (the scene).
# The requested answer shape is "Model: model_name || Domain:xxx".
MODEL_SELECTION_PROMPT = """Please help me find a suitable model for painting in this scene.
Model list will be given in the format like:
'''
model_name: model desc,
'''
you should select the model and tell me the model name. answer it in the form like Model: model_name || Domain:xxx
###
Model List:
{model_info}
My scene is: {query}
"""
# Prompt asking the LLM to state the domain of an already chosen model.
# Placeholders: {model_name} and {model_info}; requested answer shape is "Domain: xxx".
DOMAIN_JUDGEMENT_TEMPLATE = '''
use model {model_name}, decide the domain, answer it in the form like Domain: xxx
###
Model Information:
{model_info}
'''
# Output field spec for the model-selection answer: the "Model:" marker maps to
# (str, ...). NOTE(review): looks like a (type, required) field definition for
# an output parser — confirm against the consumer of this mapping.
MODEL_SELECTION_OUTPUT_MAPPING = {
"Model:": (str, ...), }
# Prompt asking for extra keywords that are not already present in the input.
# Placeholders: {answer_count} and {messages}.
SD_PROMPT_KW_OPTIMIZE_TEMPLATE = '''
I want you to act as a prompt generator. Compose each answer as a visual sentence. Do not write explanations on replies. Format the answers as javascript json arrays with a single string per answer. Return exactly {answer_count} to my question. Answer the questions exactly, in the form like responses:xxx. Answer the following question:
Find 3 keywords related to the prompt "{messages}" that are not found in the prompt. The keywords should be related to each other. Each keyword is a single word.
'''
# Prompt asking the LLM to improve the input prompt, with a sticker-style suffix
# for anime/game-like domains. Placeholders: {answer_count}, {domain}, {messages}.
SD_PROMPT_IMPROVE_OPTIMIZE_TEMPLATE = '''
I want you to act as a prompt generator. Compose each answer as a visual sentence. Do not write explanations on replies. Format the answers as javascript json arrays with a single string per answer. Return exactly {answer_count} to my question. Answer the questions exactly, in the form like responses:xxx. Answer the following question:
domain is {domain}
if domain is anime or game like, Take the prompt "{messages}, Cute kawaii sticker , white background, vector, pastel colors" and improve it.
if domain is realistic like, Take the prompt "{messages}" and improve it.
'''
# NOTE(review): leftover style keywords, presumably alternatives for the sticker
# suffix above — confirm before reusing: Die-cut sticker, illustration minimalism,
# Tool-selection prompt. Placeholders: {query}, {model_name}, {domain},
# {tool_names}, {tool_description}. The quadruple braces collapse to "{{" / "}}"
# after a single str.format pass; presumably a second formatting pass happens
# downstream to yield literal braces — TODO confirm.
FORMAT_INSTRUCTIONS = """The problem is to make the user input a better text2image prompt, the input is {query}"
Let's first understand the problem and devise a plan to solve the problem.
Based on the text2image model selected {model_name} and domain {domain}
You have access to the following tools:
{tool_names}
{tool_description}
Use a json blob to specify a tool by providing an action key (tool name) and an Observation (tool description).
Valid "action" values: {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"Observation": $TOOL_DESCRIPTION
}}}}
```
Follow this format:
## Think Chain
```
Question: input question to answer
Thought: select a better method for the input by go through these two tools and its observations respectively
Action1:
```
$JSON_BLOB
```
Action2:
```
$JSON_BLOB
```
Thought:When evaluating a prompt's richness, I need to specify which tool to use and I can only select one tool . To finish this selection, in the form:
## Final Action:
TOOL_NAME
"""
# Output field spec for the tool-selection answer; the "Final Action:" marker
# corresponds to the "## Final Action:" heading requested in FORMAT_INSTRUCTIONS.
PROMPT_OUTPUT_MAPPING = {
"Final Action:": (str, ...),
}

View file

@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
# @Date : 2023/8/16 13:58
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from functools import wraps
import json5
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.actions.design import Tool, SDPromptExtend, SDPromptOptimize, SDPromptImprove
from metagpt.actions.ui_design import ModelSelection, SDGeneration
def retrieve(func):
    """Decorate a lookup that returns ``(content, keyword)``.

    The decorated callable returns ``content`` with every occurrence of
    ``keyword`` removed, i.e. the stored value without its label prefix.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The wrapped function yields the raw text together with the label
        # that should be stripped out of it.
        raw_text, label = func(*args, **kwargs)
        return raw_text.replace(label, "")

    return wrapper
class Designer(Role):
    """Class representing the UI designer Role.

    The role walks through its action list (by default model selection,
    prompt extension and image generation), storing the user input, the
    chosen model name and the domain in memory under labelled prefixes so
    that later steps can read them back.
    """

    def __init__(
            self,
            name="Catherine",
            profile="UI Design",
            goal="Generate UI icon",
            constraints="Give clear icon description and generate images to finish the design",
            actions=None):
        """Create the role and register its actions.

        Args:
            name: Role name.
            profile: Role profile; also used as the ``role`` of stored messages.
            goal: Role goal.
            constraints: Role constraints.
            actions: Action classes to run in order. ``None`` selects the
                default pipeline ``[ModelSelection, SDPromptExtend, SDGeneration]``.
        """
        super().__init__(name, profile, goal, constraints)
        # Build the default per instance: a list literal in the signature would
        # be one shared object, and this role inserts actions at runtime.
        if actions is None:
            actions = [ModelSelection, SDPromptExtend, SDGeneration]
        self._init_actions(actions)

    @property
    def memory_model_name(self):
        """Label prefix under which the selected model name is stored."""
        return "MODEL_NAME: "

    @property
    def memory_user_input(self):
        """Label prefix under which the original user query is stored."""
        return "User Input: "

    @property
    def memory_domain(self):
        """Label prefix under which the selected domain is stored."""
        return "Domain: "

    def memory_property(self, memory_keyword: str, memory_content: str):
        """Store ``memory_content`` in memory, prefixed with ``memory_keyword``."""
        self._rc.memory.add(Message(f"{memory_keyword}{memory_content}", role=self.profile))

    @retrieve
    def get_important_memory(self, keyword: str):
        """Return the first memory entry containing ``keyword`` with the keyword removed.

        The method itself returns ``(content, keyword)``; the ``retrieve``
        decorator strips the keyword from the content.

        Raises:
            IndexError: If no memory entry contains ``keyword``.
        """
        query_memory = self._rc.memory.get_by_content(keyword)[0]
        return query_memory.content, keyword

    async def _plan_and_select(self):
        """Insert the follow-up action named by the most recent memory entry.

        This implements a choose-one-of-two option; the action is selected
        here. In theory four choices could be supported (``&`` denotes
        sequential order), but currently only the first two are implemented:
        1) action1
        2) action2
        3) action1 & action2
        4) action2 & action1

        Returns:
            The current state index (unchanged by this method).
        """
        msg = self._rc.memory.get(k=1)[0]
        query = msg.content
        logger.info(query)
        # The latest message is expected to be the bare tool name.
        if query == "PromptImprove":
            self._actions.insert(self._rc.state + 1, SDPromptImprove())
        elif query == "PromptOptimize":
            self._actions.insert(self._rc.state + 1, SDPromptOptimize())
        return self._rc.state

    async def _think(self) -> None:
        """Advance the state machine by one step, or clear ``todo`` when finished."""
        logger.info(self._rc.state)
        if self._rc.todo is None:
            self._set_state(0)
            return
        if self._rc.state == 1:
            # State 1 may insert an extra action right after itself before
            # the state is advanced.
            await self._plan_and_select()
            self._set_state(self._rc.state + 1)
        elif self._rc.state + 1 < len(self._actions):
            self._set_state(self._rc.state + 1)
        else:
            self._rc.todo = None

    async def handle_model_selection(self, query, **kwargs):
        """Run model selection for ``query`` and remember the query, model and domain.

        Returns:
            The string ``"{model_name}||{domain}"``.
        """
        ms = ModelSelection()
        model_name, domain = await ms.run(query)
        logger.info(f"{model_name}, {domain}")
        self.memory_property(self.memory_user_input, query)
        self.memory_property(self.memory_model_name, model_name)
        self.memory_property(self.memory_domain, domain)
        return f"{model_name}||{domain}"

    async def handle_sd_prompt_extend(self, *args, **kwargs):
        """Run ``SDPromptExtend`` with the two prompt tools and return its response."""
        tools = [
            Tool(name="PromptOptimize",
                 func=SDPromptOptimize().run,
                 description="Find 3 keywords related to the prompt that are not found in the prompt. The keywords should be related to each other. Each keyword is a single word. useful for when you need to add extra keywords for input prompt, specially for long enough input"),
            Tool(name="PromptImprove",
                 func=SDPromptImprove().run,
                 description="Take the prompt and improve it. useful for when you need to add improve and extend the prompt for input prompt, specially for short input"),
        ]
        query = self.get_important_memory(self.memory_user_input)
        domain = self.get_important_memory(self.memory_domain)
        sd_exd = SDPromptExtend(tools=tools)
        resp = await sd_exd.run(query=query, domain=domain, answer_count=1)
        return resp

    async def handle_sd_prompt_improve(self, *args, **kwargs):
        """Run ``SDPromptImprove`` on the remembered user input and domain."""
        query = self.get_important_memory(self.memory_user_input)
        domain = self.get_important_memory(self.memory_domain)
        sd_pi = SDPromptImprove()
        resp = await sd_pi.run(query=query, domain=domain, answer_count=1)
        return resp

    async def handle_sd_prompt_optimize(self, *args, **kwargs):
        """Run ``SDPromptOptimize`` on the remembered user input and domain."""
        query = self.get_important_memory(self.memory_user_input)
        domain = self.get_important_memory(self.memory_domain)
        sd_op = SDPromptOptimize()
        resp = await sd_op.run(query=query, domain=domain, answer_count=1)
        return resp

    async def handle_sd_generation(self, *args, **kwargs):
        """Generate an image from the stored prompt result.

        Reads the first memory entry caused by ``SDPromptImprove``, parses its
        content with json5 and passes it to ``SDGeneration`` together with the
        remembered model name; the user input is used as the image name.

        Returns:
            The parsed prompt payload.
        """
        # Optimize results are stored with cause_by=SDPromptImprove as well
        # (see _act), so this lookup covers both prompt tools.
        msg = self._rc.memory.get_by_action(SDPromptImprove)[0]
        image_name = self.get_important_memory(self.memory_user_input)
        logger.info(type(msg.content))
        logger.info(msg.content)
        resp = json5.loads(msg.content)
        logger.info(resp)
        model_name = self.get_important_memory(self.memory_model_name)
        await SDGeneration().run(query=resp, model_name=model_name, image_name=image_name)
        return resp

    async def _act(self) -> Message:
        """Dispatch the current ``todo`` to its handler and store the result.

        Returns:
            The message holding the handler's response; it is also added to memory.

        Raises:
            ValueError: If the type of the current ``todo`` has no handler.
        """
        logger.info(f"{self._setting}: ready to {self._rc.todo}")
        todo = self._rc.todo
        msg = self._rc.memory.get(k=1)[0]
        query = msg.content
        logger.info(msg.cause_by)
        logger.info(query)
        logger.info(todo)
        handler_map = {
            ModelSelection: self.handle_model_selection,
            SDPromptExtend: self.handle_sd_prompt_extend,
            SDPromptImprove: self.handle_sd_prompt_improve,
            SDPromptOptimize: self.handle_sd_prompt_optimize,
            SDGeneration: self.handle_sd_generation,
        }
        todo_type = type(todo)
        handler = handler_map.get(todo_type)
        if not handler:
            raise ValueError(f"Unknown todo type: {type(todo)}")
        resp = await handler(query)
        # Both prompt tools are recorded as SDPromptImprove so that
        # handle_sd_generation can find the result with a single lookup.
        if todo_type in (SDPromptImprove, SDPromptOptimize):
            cause = SDPromptImprove
        else:
            cause = todo_type
        ret = Message(f"{resp}", role=self.profile, cause_by=cause)
        self._rc.memory.add(ret)
        return ret

    async def _react(self) -> Message:
        """Alternate think/act until no ``todo`` remains.

        Returns:
            The message produced by the last action, or ``None`` if the loop
            ended before any action ran.
        """
        # Pre-bind so an immediate exit returns None instead of raising
        # UnboundLocalError.
        msg = None
        while True:
            await self._think()
            if self._rc.todo is None:
                break
            msg = await self._act()
        return msg
if __name__ == "__main__":
    import asyncio
    import platform

    # Manual smoke run: one fresh Designer per query.
    test_queries = [
        "Flappy Bird",
        "Clash of Clans",
        "Subway Surfers",
        "Pokémon Go",
        "Super Mario",
        "Tetris",
        "Call of Duty",
    ]
    # The event-loop policy is process-wide, so it only needs to be set once
    # rather than on every iteration.
    if platform.system() == "Windows":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    for prompt in test_queries:
        designer = Designer()
        asyncio.run(designer.run(prompt))

View file

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
# @Date : 2023/8/22 22:18
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import json5
import re
def flatten_json_structure(json_array):
    """Flatten a parsed JSON value into a list of comma-joined strings.

    Single-element wrappers (a one-item list, or a dict with one value) whose
    content is not a string are unwrapped first. Each remaining entry (or each
    dict value) is flattened with ``flatten_json_object`` and its values are
    joined with ``", "``.
    """
    # Peel off single-element containers until something substantial is left.
    while True:
        if isinstance(json_array, list):
            if len(json_array) == 1 and not isinstance(json_array[0], str):
                json_array = json_array[0]
                continue
        elif isinstance(json_array, dict):
            if len(json_array.values()) == 1:
                only_value = list(json_array.values())[0]
                if not isinstance(only_value, str):
                    json_array = only_value
                    continue
        break

    # For a dict only the values matter; keys are dropped.
    entries = json_array.values() if isinstance(json_array, dict) else json_array
    return [
        ", ".join(str(v) for v in flatten_json_object(entry).values())
        for entry in entries
    ]
def flatten_json_object(obj, parent_key='', sep=', '):
    """Flatten a parsed JSON value into a single-level dict.

    Args:
        obj: A string, list, mapping or scalar taken from parsed JSON.
        parent_key: Key prefix accumulated while descending into nested mappings.
        sep: Separator used both between nested key parts and between list items.

    Returns:
        ``{"value": obj}`` for strings and other scalars, ``{"value": joined}``
        for lists, and for mappings a dict whose keys are the ``sep``-joined
        key paths and whose list values are joined into one string.
    """
    if isinstance(obj, str):
        return {"value": obj}
    if isinstance(obj, list):
        return {"value": sep.join(str(v) for v in obj)}
    if not hasattr(obj, "items"):
        # Numbers, booleans and None are valid JSON values but have no
        # .items(); treat them like strings instead of crashing.
        return {"value": obj}

    flattened = {}
    for key, value in obj.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            flattened.update(flatten_json_object(value, new_key, sep=sep))
        elif isinstance(value, list):
            flattened[new_key] = sep.join(str(v) for v in value)
        else:
            flattened[new_key] = value
    return flattened
def try_parse_json(input_text):
    """Extract and parse the JSON-like fragment embedded in ``input_text``.

    The fragment starts at the first ``[`` or ``{`` (whichever comes first) and
    ends at the last matching closer. Adjacent objects, arrays or strings that
    lack a separating comma are repaired before parsing with json5; if parsing
    still fails the fragment is retried wrapped in ``[...]``.

    Args:
        input_text: Text that may contain JSON surrounded by other content.

    Returns:
        The parsed Python object.

    Raises:
        Exception: If no opener/closer pair in the right order is found.
        ValueError: If the fragment cannot be parsed even after wrapping.
    """
    first_bracket = input_text.find('[')
    first_curly = input_text.find('{')
    # Use the curly pair when '{' appears before '[' or there is no '['.
    if first_curly != -1 and (first_curly < first_bracket or first_bracket < 0):
        start_index = first_curly
        end_index = input_text.rfind('}')
    else:
        start_index = first_bracket
        end_index = input_text.rfind(']')

    # The closer must come after the opener; otherwise the slice is empty and
    # would be silently parsed as an empty list by the fallback below.
    if start_index < 0 or end_index <= start_index:
        raise Exception("No JSON object found in input text.")

    json_string = input_text[start_index:end_index + 1]
    # Insert the commas that are missing between back-to-back items.
    json_string = re.sub(r'\}[\s]*\{', '}, {', json_string)
    json_string = re.sub(r'\][\s]*\[', '], [', json_string)
    json_string = re.sub(r'\"[\s]*\"', '", "', json_string)
    try:
        return json5.loads(json_string)
    except ValueError:
        # Several top-level items only form valid JSON as an array.
        return json5.loads(f"[{json_string}]")