merge main

This commit is contained in:
leonzh0u 2023-07-23 17:02:02 -04:00
commit 242bbbc13a
110 changed files with 959 additions and 253 deletions

View file

@ -5,7 +5,6 @@
@Author : alexanderwu
@File : prompt_writer.py
"""
from abc import ABC
from typing import Union

135
metagpt/tools/sd_engine.py Normal file
View file

@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/19 16:28
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import os
import asyncio
from os.path import join
from typing import List
import json
import io
import base64
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
from metagpt.logs import logger
from metagpt.config import Config
from metagpt.const import WORKSPACE_ROOT
config = Config()
payload = {
"prompt": "",
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
"override_settings": {
"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"
},
"seed": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 20,
"cfg_scale": 7,
"width": 512,
"height": 768,
"restore_faces": False,
"tiling": False,
"do_not_save_samples": False,
"do_not_save_grid": False,
'enable_hr': False,
'hr_scale': 2,
'hr_upscaler': 'Latent',
'hr_second_pass_steps': 0,
'hr_resize_x': 0,
'hr_resize_y': 0,
'hr_upscale_to_x': 0,
'hr_upscale_to_y': 0,
'truncate_x': 0,
'truncate_y': 0,
'applied_old_hires_behavior_to': None,
"eta": None,
"sampler_index": "DPM++ SDE Karras",
"alwayson_scripts": {}
}
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
class SDEngine:
def __init__(self):
# Initialize the SDEngine with configuration
self.config = Config()
self.sd_url = self.config.get('SD_URL')
self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}"
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512,
sd_model="galaxytimemachinesGTM_photoV20"):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
self.payload["negtive_prompt"] = negtive_prompt
self.payload["width"] = width
self.payload["height"] = height
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
logger.info(f"call sd payload is {self.payload}")
return self.payload
def _save(self, imgs, save_name=""):
save_dir = WORKSPACE_ROOT / "resources"/"SD_Output"
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
async def run_t2i(self, prompts: List):
# Asynchronously run the SD API for multiple prompts
session = ClientSession()
for payload_idx, payload in enumerate(prompts):
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
self._save(results, save_name=f"output_{payload_idx}")
await session.close()
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
rsp_json = json.loads(data)
imgs = rsp_json['images']
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
async def run_sam(self):
# todo添加SAM接口调用
raise NotImplementedError
def decode_base64_to_image(img, save_name):
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
pnginfo = PngImagePlugin.PngInfo()
logger.info(save_name)
image.save(f"{save_name}.png", pnginfo=pnginfo)
return pnginfo, image
def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
for idx, _img in enumerate(imgs):
save_name = join(save_dir, save_name)
decode_base64_to_image(_img, save_name=save_name)
if __name__ == "__main__":
import asyncio
engine = SDEngine()
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
engine.construct_payload(prompt)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(engine.run_t2i(prompt))

View file

@ -9,10 +9,8 @@ from __future__ import annotations
import json
from metagpt.logs import logger
from duckduckgo_search import ddg
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
from metagpt.tools.search_engine_serper import SerperWrapper
@ -55,11 +53,11 @@ class SearchEngine:
return rsp
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
"""Return the results of a Google search using the official Google API
def google_official_search(queries: list[str], num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
"""Return the results of a batch of Google search using the official Google API
Args:
query (str): The search query.
queries (list[str]): A batch of search queries.
num_results (int): The number of results to return.
Returns:
@ -73,8 +71,12 @@ def google_official_search(query: str, num_results: int = 8, focus=['snippet', '
api_key = config.google_api_key
custom_search_engine_id = config.google_cse_id
service = build("customsearch", "v1", developerKey=api_key)
service = build("customsearch", "v2", developerKey=api_key)
batch = service.new_batch_http_request()
for query in queries:
batch.add(service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results))
batch.execute()
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)

View file

@ -6,10 +6,10 @@
@File : search_engine_meilisearch.py
"""
from metagpt.logs import logger
from typing import List
import meilisearch
from meilisearch.index import Index
from typing import List
class DataSource:

View file

@ -6,7 +6,7 @@
@File : search_engine_serpapi.py
"""
from typing import Any, Dict, Optional, Tuple
from metagpt.logs import logger
import aiohttp
from pydantic import BaseModel, Field

View file

@ -5,10 +5,10 @@
@Author : alexanderwu
@File : search_engine_serpapi.py
"""
from typing import Any, Dict, Optional, Tuple
from metagpt.logs import logger
import aiohttp
import json
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field
from metagpt.config import Config
@ -55,7 +55,6 @@ class SerperWrapper(BaseModel):
async with aiohttp.ClientSession() as session:
async with session.post(url, data=payloads, headers=headers) as response:
res = await response.json()
else:
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
res = await response.json()

View file

@ -6,7 +6,6 @@ from pathlib import Path
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
ICL_SAMPLE = '''接口定义:
```text
接口名称元素打标签