use config

This commit is contained in:
geekan 2024-01-10 20:19:56 +08:00
parent 21cac0bffb
commit 479bbc9b2d
37 changed files with 102 additions and 276 deletions

View file

@ -13,7 +13,7 @@ import aiohttp
import requests
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
@ -47,7 +47,8 @@ class OpenAIText2Embedding:
"""
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
"""
self.openai_api_key = openai_api_key or CONFIG.OPENAI_API_KEY
self.openai_llm = config.get_openai_llm()
self.openai_api_key = openai_api_key or self.openai_llm.api_key
async def text_2_embedding(self, text, model="text-embedding-ada-002"):
"""Text to embedding
@ -57,7 +58,7 @@ class OpenAIText2Embedding:
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
"""
proxies = {"proxy": CONFIG.openai_proxy} if CONFIG.openai_proxy else {}
proxies = {"proxy": self.openai_llm.proxy} if self.openai_llm.proxy else {}
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
data = {"input": text, "model": model}
url = "https://api.openai.com/v1/embeddings"
@ -83,5 +84,5 @@ async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", op
if not text:
return ""
if not openai_api_key:
openai_api_key = CONFIG.OPENAI_API_KEY
openai_api_key = config.get_openai_llm().api_key
return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model)

View file

@ -1,133 +0,0 @@
# -*- coding: utf-8 -*-
# @Date : 2023/7/19 16:28
# @Author : stellahong (stellahong@deepwisdom.ai)
# @Desc :
import asyncio
import base64
import io
import json
from os.path import join
from typing import List
from aiohttp import ClientSession
from PIL import Image, PngImagePlugin
from metagpt.config import CONFIG
from metagpt.const import SD_OUTPUT_FILE_REPO
from metagpt.logs import logger
payload = {
"prompt": "",
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
"seed": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 20,
"cfg_scale": 7,
"width": 512,
"height": 768,
"restore_faces": False,
"tiling": False,
"do_not_save_samples": False,
"do_not_save_grid": False,
"enable_hr": False,
"hr_scale": 2,
"hr_upscaler": "Latent",
"hr_second_pass_steps": 0,
"hr_resize_x": 0,
"hr_resize_y": 0,
"hr_upscale_to_x": 0,
"hr_upscale_to_y": 0,
"truncate_x": 0,
"truncate_y": 0,
"applied_old_hires_behavior_to": None,
"eta": None,
"sampler_index": "DPM++ SDE Karras",
"alwayson_scripts": {},
}
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
class SDEngine:
def __init__(self):
# Initialize the SDEngine with configuration
self.sd_url = CONFIG.get("SD_URL")
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
# Define default payload settings for SD API
self.payload = payload
logger.info(self.sd_t2i_url)
def construct_payload(
self,
prompt,
negtive_prompt=default_negative_prompt,
width=512,
height=512,
sd_model="galaxytimemachinesGTM_photoV20",
):
# Configure the payload with provided inputs
self.payload["prompt"] = prompt
self.payload["negtive_prompt"] = negtive_prompt
self.payload["width"] = width
self.payload["height"] = height
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
logger.info(f"call sd payload is {self.payload}")
return self.payload
def _save(self, imgs, save_name=""):
save_dir = CONFIG.path / SD_OUTPUT_FILE_REPO
if not save_dir.exists():
save_dir.mkdir(parents=True, exist_ok=True)
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
async def run_t2i(self, prompts: List):
# Asynchronously run the SD API for multiple prompts
session = ClientSession()
for payload_idx, payload in enumerate(prompts):
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
self._save(results, save_name=f"output_{payload_idx}")
await session.close()
async def run(self, url, payload, session):
# Perform the HTTP POST request to the SD API
async with session.post(url, json=payload, timeout=600) as rsp:
data = await rsp.read()
rsp_json = json.loads(data)
imgs = rsp_json["images"]
logger.info(f"callback rsp json is {rsp_json.keys()}")
return imgs
async def run_i2i(self):
# todo: 添加图生图接口调用
raise NotImplementedError
async def run_sam(self):
# todo添加SAM接口调用
raise NotImplementedError
def decode_base64_to_image(img, save_name):
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
pnginfo = PngImagePlugin.PngInfo()
logger.info(save_name)
image.save(f"{save_name}.png", pnginfo=pnginfo)
return pnginfo, image
def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
for idx, _img in enumerate(imgs):
save_name = join(save_dir, save_name)
decode_base64_to_image(_img, save_name=save_name)
if __name__ == "__main__":
engine = SDEngine()
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
engine.construct_payload(prompt)
event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(engine.run_t2i(prompt))

View file

@ -7,6 +7,8 @@ import json
from concurrent import futures
from typing import Literal, overload
from metagpt.config2 import config
try:
from duckduckgo_search import DDGS
except ImportError:
@ -15,8 +17,6 @@ except ImportError:
"You can install it by running the command: `pip install -e.[search-ddg]`"
)
from metagpt.config import CONFIG
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
@ -31,8 +31,8 @@ class DDGAPIWrapper:
executor: futures.Executor | None = None,
):
kwargs = {}
if CONFIG.global_proxy:
kwargs["proxies"] = CONFIG.global_proxy
if config.proxy:
kwargs["proxies"] = config.proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)

View file

@ -11,7 +11,7 @@ from urllib.parse import urlparse
import httplib2
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
try:
@ -35,7 +35,7 @@ class GoogleAPIWrapper(BaseModel):
@field_validator("google_api_key", mode="before")
@classmethod
def check_google_api_key(cls, val: str):
val = val or CONFIG.google_api_key
val = val or config.search["google"].api_key
if not val:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
@ -47,7 +47,7 @@ class GoogleAPIWrapper(BaseModel):
@field_validator("google_cse_id", mode="before")
@classmethod
def check_google_cse_id(cls, val: str):
val = val or CONFIG.google_cse_id
val = val or config.search["google"].cse_id
if not val:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
@ -59,8 +59,8 @@ class GoogleAPIWrapper(BaseModel):
@property
def google_api_client(self):
build_kwargs = {"developerKey": self.google_api_key}
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
if config.proxy:
parse_result = urlparse(config.proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"

View file

@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
from metagpt.config2 import config
class SerpAPIWrapper(BaseModel):
@ -32,7 +32,7 @@ class SerpAPIWrapper(BaseModel):
@field_validator("serpapi_api_key", mode="before")
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or CONFIG.serpapi_api_key
val = val or config.search["serpapi"].api_key
if not val:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "

View file

@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
from metagpt.config2 import config
class SerperWrapper(BaseModel):
@ -25,7 +25,7 @@ class SerperWrapper(BaseModel):
@field_validator("serper_api_key", mode="before")
@classmethod
def check_serper_api_key(cls, val: str):
val = val or CONFIG.serper_api_key
val = val or config.search["serper"].api_key
if not val:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "

View file

@ -8,7 +8,6 @@ from __future__ import annotations
import importlib
from typing import Any, Callable, Coroutine, overload
from metagpt.config import CONFIG
from metagpt.tools import WebBrowserEngineType
from metagpt.utils.parse_html import WebPage
@ -19,7 +18,6 @@ class WebBrowserEngine:
engine: WebBrowserEngineType | None = None,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
engine = engine or CONFIG.web_browser_engine
if engine is None:
raise NotImplementedError

View file

@ -12,7 +12,7 @@ from typing import Literal
from playwright.async_api import async_playwright
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.parse_html import WebPage
@ -33,13 +33,13 @@ class PlaywrightWrapper:
**kwargs,
) -> None:
if browser_type is None:
browser_type = CONFIG.playwright_browser_type
browser_type = config.browser["playwright"].driver
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if CONFIG.global_proxy and "proxy" not in launch_kwargs:
if config.proxy and "proxy" not in launch_kwargs:
args = launch_kwargs.get("args", [])
if not any(str.startswith(i, "--proxy-server=") for i in args):
launch_kwargs["proxy"] = {"server": CONFIG.global_proxy}
launch_kwargs["proxy"] = {"server": config.proxy}
self.launch_kwargs = launch_kwargs
context_kwargs = {}
if "ignore_https_errors" in kwargs:
@ -79,8 +79,8 @@ class PlaywrightWrapper:
executable_path = Path(browser_type.executable_path)
if not executable_path.exists() and "executable_path" not in self.launch_kwargs:
kwargs = {}
if CONFIG.global_proxy:
kwargs["env"] = {"ALL_PROXY": CONFIG.global_proxy}
if config.proxy:
kwargs["env"] = {"ALL_PROXY": config.proxy}
await _install_browsers(self.browser_type, **kwargs)
if self._has_run_precheck:

View file

@ -17,7 +17,7 @@ from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.core.download_manager import WDMDownloadManager
from webdriver_manager.core.http import WDMHttpClient
from metagpt.config import CONFIG
from metagpt.config2 import config
from metagpt.utils.parse_html import WebPage
@ -41,12 +41,10 @@ class SeleniumWrapper:
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
) -> None:
if browser_type is None:
browser_type = CONFIG.selenium_browser_type
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if CONFIG.global_proxy and "proxy-server" not in launch_kwargs:
launch_kwargs["proxy-server"] = CONFIG.global_proxy
if config.proxy and "proxy-server" not in launch_kwargs:
launch_kwargs["proxy-server"] = config.proxy
self.executable_path = launch_kwargs.pop("executable_path", None)
self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()]
@ -97,8 +95,8 @@ _webdriver_manager_types = {
class WDMHttpProxyClient(WDMHttpClient):
def get(self, url, **kwargs):
if "proxies" not in kwargs and CONFIG.global_proxy:
kwargs["proxies"] = {"all_proxy": CONFIG.global_proxy}
if "proxies" not in kwargs and config.proxy:
kwargs["proxies"] = {"all_proxy": config.proxy}
return super().get(url, **kwargs)