resolved current conflicts

This commit is contained in:
brucemeek 2023-08-15 06:51:39 -05:00
parent 730e2f912f
commit 193178b7d1
11 changed files with 290 additions and 269 deletions

View file

@ -2,29 +2,27 @@
# @Date : 2023/7/19 16:28
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import os
import asyncio
import base64
import io
import json
import os
from os.path import join
from typing import List
import json
import io
import base64
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
from metagpt.logs import logger
from metagpt.config import Config
from metagpt.const import WORKSPACE_ROOT
from metagpt.logs import logger
config = Config()
payload = {
"prompt": "",
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
"override_settings": {
"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"
},
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
"seed": -1,
"batch_size": 1,
"n_iter": 1,
@ -36,21 +34,20 @@ payload = {
"tiling": False,
"do_not_save_samples": False,
"do_not_save_grid": False,
'enable_hr': False,
'hr_scale': 2,
'hr_upscaler': 'Latent',
'hr_second_pass_steps': 0,
'hr_resize_x': 0,
'hr_resize_y': 0,
'hr_upscale_to_x': 0,
'hr_upscale_to_y': 0,
'truncate_x': 0,
'truncate_y': 0,
'applied_old_hires_behavior_to': None,
"enable_hr": False,
"hr_scale": 2,
"hr_upscaler": "Latent",
"hr_second_pass_steps": 0,
"hr_resize_x": 0,
"hr_resize_y": 0,
"hr_upscale_to_x": 0,
"hr_upscale_to_y": 0,
"truncate_x": 0,
"truncate_y": 0,
"applied_old_hires_behavior_to": None,
"eta": None,
"sampler_index": "DPM++ SDE Karras",
"alwayson_scripts": {}
"alwayson_scripts": {},
}
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
@ -60,14 +57,20 @@ class SDEngine:
def __init__(self):
# Initialize the SDEngine with configuration
self.config = Config()
self.sd_url = self.config.get('SD_URL')
self.sd_url = self.config.get("SD_URL")
self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}"
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512,
sd_model="galaxytimemachinesGTM_photoV20"):
def construct_payload(
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
self.payload["negtive_prompt"] = negtive_prompt
@ -76,13 +79,13 @@ class SDEngine:
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
logger.info(f"call sd payload is {self.payload}")
return self.payload
def _save(self, imgs, save_name=""):
save_dir = WORKSPACE_ROOT / "resources"/"SD_Output"
save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
async def run_t2i(self, prompts: List):
# Asynchronously run the SD API for multiple prompts
session = ClientSession()
@ -90,24 +93,25 @@ class SDEngine:
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
self._save(results, save_name=f"output_{payload_idx}")
await session.close()
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
rsp_json = json.loads(data)
imgs = rsp_json['images']
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: Add image-to-image interface call
raise NotImplementedError
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
rsp_json = json.loads(data)
imgs = rsp_json["images"]
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
async def run_sam(self):
# todo添加SAM接口调用
raise NotImplementedError
async def run_sam(self):
# todo: Add SAM interface call
raise NotImplementedError
def decode_base64_to_image(img, save_name):
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
@ -122,12 +126,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
decode_base64_to_image(_img, save_name=save_name)
if __name__ == "__main__":
import asyncio
engine = SDEngine()
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
engine.construct_payload(prompt)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(engine.run_t2i(prompt))

View file

@ -7,118 +7,76 @@
"""
from __future__ import annotations
import json
import importlib
from typing import Callable, Coroutine, Literal, overload
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
from metagpt.tools.search_engine_serper import SerperWrapper
config = Config()
from metagpt.config import CONFIG
from metagpt.tools import SearchEngineType
class SearchEngine:
"""
TODO: Integrate Google Search and reverse proxy.
Note: Google here requires a Proxifier or similar global proxy.
- DDG: https://pypi.org/project/duckduckgo-search/
- GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
"""
def __init__(self, engine=None, run_func=None):
self.config = Config()
self.run_func = run_func
self.engine = engine or self.config.search_engine
"""Class representing a search engine.
@classmethod
def run_google(cls, query, max_results=8):
# results = ddg(query, max_results=max_results)
results = google_official_search(query, num_results=max_results)
logger.info(results)
return results
Args:
engine: The search engine type. Defaults to the search engine specified in the config.
run_func: The function to run the search. Defaults to None.
async def run(self, query: str, max_results=8):
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
api = SerpAPIWrapper()
rsp = await api.run(query)
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
rsp = SearchEngine.run_google(query, max_results)
elif self.engine == SearchEngineType.SERPER_GOOGLE:
api = SerperWrapper()
rsp = await api.run(query)
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
rsp = self.run_func(query)
Attributes:
run_func: The function to run the search.
engine: The search engine type.
"""
def __init__(
self,
engine: SearchEngineType | None = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
):
engine = engine or CONFIG.search_engine
if engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper().run
elif engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper().run
elif engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper().run
elif engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper().run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
else:
raise NotImplementedError
return rsp
self.engine = engine
self.run_func = run_func
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
"""Return the results of a Google search using the official Google API
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
) -> str:
...
Args:
query (str): The search query.
num_results (int): The number of results to return.
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
) -> list[dict[str, str]]:
...
Returns:
str: The results of the search.
"""
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
"""Run a search query.
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
Args:
query: The search query.
max_results: The maximum number of results to return. Defaults to 8.
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
try:
api_key = config.google_api_key
custom_search_engine_id = config.google_cse_id
with build("customsearch", "v1", developerKey=api_key) as service:
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)
.execute()
)
logger.info(result)
# Extract the search result items from the response
search_results = result.get("items", [])
# Create a list of only the URLs from the search results
search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
except HttpError as e:
# Handle errors in the API call
error_details = json.loads(e.content.decode())
# Check if the error is related to an invalid or missing API key
if error_details.get("error", {}).get(
"code"
) == 403 and "invalid API key" in error_details.get("error", {}).get(
"message", ""
):
return "Error: The provided Google API key is invalid or missing."
else:
return f"Error: {e}"
# Return the list of search result URLs
return search_results_details
def safe_google_results(results: str | list) -> str:
"""
Return the results of a google search in a safe format.
Args:
results (str | list): The search results.
Returns:
str: The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps(
[result for result in results]
)
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message
if __name__ == '__main__':
SearchEngine.run(query='wtf')
Returns:
The search results as a string or a list of dictionaries.
"""
return await self.run_func(query, max_results=max_results, as_string=as_string)

View file

@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
async def run(self, query: str, **kwargs: Any) -> str:
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
return self._process_response(await self.results(query))
return self._process_response(await self.results(query, max_results), as_string=as_string)
async def results(self, query: str) -> dict:
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
if self.serpapi_api_key:
params["serp_api_key"] = self.serpapi_api_key
params["output"] = "json"
@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
return params
@staticmethod
def _process_response(res: dict) -> str:
def _process_response(res: dict, as_string: bool) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
focus = ["title", "snippet", "link"]
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
if "error" in res.keys():
@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
toret = res["sports_results"]["game_spotlight"]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
toret = res["knowledge_graph"]["description"]
elif "snippet" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["snippet"]
@ -112,5 +104,10 @@ class SerpAPIWrapper(BaseModel):
if res.get("organic_results"):
toret_l += [get_focused(i) for i in res.get("organic_results")]
return str(toret) + '\n' + str(toret_l)
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
if __name__ == "__main__":
import fire
fire.Fire(SerpAPIWrapper().run)

View file

@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
async def run(self, query: str, **kwargs: Any) -> str:
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
queries = query.split("\n")
return "\n".join([self._process_response(res) for res in await self.results(queries)])
if isinstance(query, str):
return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
else:
results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
return "\n".join(results) if as_string else results
async def results(self, queries: list[str]) -> dict:
async def results(self, queries: list[str], max_results: int = 8) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
payloads = self.get_payloads(queries)
payloads = self.get_payloads(queries, max_results)
url = "https://google.serper.dev/search"
headers = self.get_headers()
return url, payloads, headers
@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
return res
def get_payloads(self, queries: list[str]) -> Dict[str, str]:
def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
"""Get payloads for Serper."""
payloads = []
for query in queries:
_payload = {
"q": query,
"num": max_results,
}
payloads.append({**self.payload, **_payload})
return json.dumps(payloads, sort_keys=True)
@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
return headers
@staticmethod
def _process_response(res: dict) -> str:
def _process_response(res: dict, as_string: bool = False) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
@ -117,5 +121,10 @@ class SerperWrapper(BaseModel):
if res.get("organic"):
toret_l += [get_focused(i) for i in res.get("organic")]
return str(toret) + '\n' + str(toret_l)
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
if __name__ == "__main__":
import fire
fire.Fire(SerperWrapper().run)

View file

@ -1,22 +1,20 @@
#!/usr/bin/env python
from __future__ import annotations
import asyncio
import importlib
from typing import Any, Callable, Coroutine, overload
import importlib
from typing import Any, Callable, Coroutine, Literal, overload
from metagpt.config import CONFIG
from metagpt.tools import WebBrowserEngineType
from bs4 import BeautifulSoup
from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
engine: WebBrowserEngineType | None = None,
run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None,
parse_func: Callable[[str], str] | None = None,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
engine = engine or CONFIG.web_browser_engine
@ -30,31 +28,25 @@ class WebBrowserEngine:
run_func = run_func
else:
raise NotImplementedError
self.parse_func = parse_func or get_page_content
self.run_func = run_func
self.engine = engine
@overload
async def run(self, url: str) -> str:
async def run(self, url: str) -> WebPage:
...
@overload
async def run(self, url: str, *urls: str) -> list[str]:
async def run(self, url: str, *urls: str) -> list[WebPage]:
...
async def run(self, url: str, *urls: str) -> str | list[str]:
page = await self.run_func(url, *urls)
if isinstance(page, str):
return self.parse_func(page)
return [self.parse_func(i) for i in page]
def get_page_content(page: str):
soup = BeautifulSoup(page, "html.parser")
return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"]))
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
return await self.run_func(url, *urls)
if __name__ == "__main__":
text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/"))
print(text)
import fire
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
fire.Fire(main)

View file

@ -2,16 +2,17 @@
from __future__ import annotations
import asyncio
from copy import deepcopy
import importlib
from concurrent import futures
from copy import deepcopy
from typing import Literal
from metagpt.config import CONFIG
import asyncio
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from concurrent import futures
from metagpt.config import CONFIG
from metagpt.utils.parse_html import WebPage
class SeleniumWrapper:
@ -48,7 +49,7 @@ class SeleniumWrapper:
self.loop = loop
self.executor = executor
async def run(self, url: str, *urls: str) -> str | list[str]:
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
await self._run_precheck()
_scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url)
@ -69,9 +70,15 @@ class SeleniumWrapper:
def _scrape_website(self, url):
with self._get_driver() as driver:
driver.get(url)
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
return driver.page_source
try:
driver.get(url)
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
inner_text = driver.execute_script("return document.body.innerText;")
html = driver.page_source
except Exception as e:
inner_text = f"Fail to load page content for {e}"
html = ""
return WebPage(inner_text=inner_text, html=html, url=url)
_webdriver_manager_types = {
@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
def _get_driver():
options = Options()
options.add_argument("--headless")
options.add_argument("--enable-javascript")
if browser_type == "chrome":
options.add_argument("--no-sandbox")
for i in args:
@ -107,6 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
if __name__ == "__main__":
text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/"))
print(text)
import fire
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
fire.Fire(main)