mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 09:16:21 +02:00
use config
This commit is contained in:
parent
21cac0bffb
commit
479bbc9b2d
37 changed files with 102 additions and 276 deletions
|
|
@ -13,7 +13,7 @@ import aiohttp
|
|||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -47,7 +47,8 @@ class OpenAIText2Embedding:
|
|||
"""
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
"""
|
||||
self.openai_api_key = openai_api_key or CONFIG.OPENAI_API_KEY
|
||||
self.openai_llm = config.get_openai_llm()
|
||||
self.openai_api_key = openai_api_key or self.openai_llm.api_key
|
||||
|
||||
async def text_2_embedding(self, text, model="text-embedding-ada-002"):
|
||||
"""Text to embedding
|
||||
|
|
@ -57,7 +58,7 @@ class OpenAIText2Embedding:
|
|||
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
|
||||
"""
|
||||
|
||||
proxies = {"proxy": CONFIG.openai_proxy} if CONFIG.openai_proxy else {}
|
||||
proxies = {"proxy": self.openai_llm.proxy} if self.openai_llm.proxy else {}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
|
||||
data = {"input": text, "model": model}
|
||||
url = "https://api.openai.com/v1/embeddings"
|
||||
|
|
@ -83,5 +84,5 @@ async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", op
|
|||
if not text:
|
||||
return ""
|
||||
if not openai_api_key:
|
||||
openai_api_key = CONFIG.OPENAI_API_KEY
|
||||
openai_api_key = config.get_openai_llm().api_key
|
||||
return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model)
|
||||
|
|
|
|||
|
|
@ -1,133 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from os.path import join
|
||||
from typing import List
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
|
||||
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
|
||||
"seed": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 20,
|
||||
"cfg_scale": 7,
|
||||
"width": 512,
|
||||
"height": 768,
|
||||
"restore_faces": False,
|
||||
"tiling": False,
|
||||
"do_not_save_samples": False,
|
||||
"do_not_save_grid": False,
|
||||
"enable_hr": False,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent",
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_upscale_to_x": 0,
|
||||
"hr_upscale_to_y": 0,
|
||||
"truncate_x": 0,
|
||||
"truncate_y": 0,
|
||||
"applied_old_hires_behavior_to": None,
|
||||
"eta": None,
|
||||
"sampler_index": "DPM++ SDE Karras",
|
||||
"alwayson_scripts": {},
|
||||
}
|
||||
|
||||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
||||
|
||||
class SDEngine:
|
||||
def __init__(self):
|
||||
# Initialize the SDEngine with configuration
|
||||
self.sd_url = CONFIG.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negtive_prompt"] = negtive_prompt
|
||||
self.payload["width"] = width
|
||||
self.payload["height"] = height
|
||||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = CONFIG.path / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
for payload_idx, payload in enumerate(prompts):
|
||||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self._save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def decode_base64_to_image(img, save_name):
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
logger.info(save_name)
|
||||
image.save(f"{save_name}.png", pnginfo=pnginfo)
|
||||
return pnginfo, image
|
||||
|
||||
|
||||
def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
||||
for idx, _img in enumerate(imgs):
|
||||
save_name = join(save_dir, save_name)
|
||||
decode_base64_to_image(_img, save_name=save_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(engine.run_t2i(prompt))
|
||||
|
|
@ -7,6 +7,8 @@ import json
|
|||
from concurrent import futures
|
||||
from typing import Literal, overload
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
|
|
@ -15,8 +17,6 @@ except ImportError:
|
|||
"You can install it by running the command: `pip install -e.[search-ddg]`"
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class DDGAPIWrapper:
|
||||
"""Wrapper around duckduckgo_search API.
|
||||
|
|
@ -31,8 +31,8 @@ class DDGAPIWrapper:
|
|||
executor: futures.Executor | None = None,
|
||||
):
|
||||
kwargs = {}
|
||||
if CONFIG.global_proxy:
|
||||
kwargs["proxies"] = CONFIG.global_proxy
|
||||
if config.proxy:
|
||||
kwargs["proxies"] = config.proxy
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
self.ddgs = DDGS(**kwargs)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from urllib.parse import urlparse
|
|||
import httplib2
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
|
||||
try:
|
||||
|
|
@ -35,7 +35,7 @@ class GoogleAPIWrapper(BaseModel):
|
|||
@field_validator("google_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_google_api_key(cls, val: str):
|
||||
val = val or CONFIG.google_api_key
|
||||
val = val or config.search["google"].api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
|
||||
|
|
@ -47,7 +47,7 @@ class GoogleAPIWrapper(BaseModel):
|
|||
@field_validator("google_cse_id", mode="before")
|
||||
@classmethod
|
||||
def check_google_cse_id(cls, val: str):
|
||||
val = val or CONFIG.google_cse_id
|
||||
val = val or config.search["google"].cse_id
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
|
||||
|
|
@ -59,8 +59,8 @@ class GoogleAPIWrapper(BaseModel):
|
|||
@property
|
||||
def google_api_client(self):
|
||||
build_kwargs = {"developerKey": self.google_api_key}
|
||||
if CONFIG.global_proxy:
|
||||
parse_result = urlparse(CONFIG.global_proxy)
|
||||
if config.proxy:
|
||||
parse_result = urlparse(config.proxy)
|
||||
proxy_type = parse_result.scheme
|
||||
if proxy_type == "https":
|
||||
proxy_type = "http"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
|
||||
|
|
@ -32,7 +32,7 @@ class SerpAPIWrapper(BaseModel):
|
|||
@field_validator("serpapi_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serpapi_api_key(cls, val: str):
|
||||
val = val or CONFIG.serpapi_api_key
|
||||
val = val or config.search["serpapi"].api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
class SerperWrapper(BaseModel):
|
||||
|
|
@ -25,7 +25,7 @@ class SerperWrapper(BaseModel):
|
|||
@field_validator("serper_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serper_api_key(cls, val: str):
|
||||
val = val or CONFIG.serper_api_key
|
||||
val = val or config.search["serper"].api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from __future__ import annotations
|
|||
import importlib
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
|
@ -19,7 +18,6 @@ class WebBrowserEngine:
|
|||
engine: WebBrowserEngineType | None = None,
|
||||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
engine = engine or CONFIG.web_browser_engine
|
||||
if engine is None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import Literal
|
|||
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
|
@ -33,13 +33,13 @@ class PlaywrightWrapper:
|
|||
**kwargs,
|
||||
) -> None:
|
||||
if browser_type is None:
|
||||
browser_type = CONFIG.playwright_browser_type
|
||||
browser_type = config.browser["playwright"].driver
|
||||
self.browser_type = browser_type
|
||||
launch_kwargs = launch_kwargs or {}
|
||||
if CONFIG.global_proxy and "proxy" not in launch_kwargs:
|
||||
if config.proxy and "proxy" not in launch_kwargs:
|
||||
args = launch_kwargs.get("args", [])
|
||||
if not any(str.startswith(i, "--proxy-server=") for i in args):
|
||||
launch_kwargs["proxy"] = {"server": CONFIG.global_proxy}
|
||||
launch_kwargs["proxy"] = {"server": config.proxy}
|
||||
self.launch_kwargs = launch_kwargs
|
||||
context_kwargs = {}
|
||||
if "ignore_https_errors" in kwargs:
|
||||
|
|
@ -79,8 +79,8 @@ class PlaywrightWrapper:
|
|||
executable_path = Path(browser_type.executable_path)
|
||||
if not executable_path.exists() and "executable_path" not in self.launch_kwargs:
|
||||
kwargs = {}
|
||||
if CONFIG.global_proxy:
|
||||
kwargs["env"] = {"ALL_PROXY": CONFIG.global_proxy}
|
||||
if config.proxy:
|
||||
kwargs["env"] = {"ALL_PROXY": config.proxy}
|
||||
await _install_browsers(self.browser_type, **kwargs)
|
||||
|
||||
if self._has_run_precheck:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from selenium.webdriver.support.wait import WebDriverWait
|
|||
from webdriver_manager.core.download_manager import WDMDownloadManager
|
||||
from webdriver_manager.core.http import WDMHttpClient
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
|
|
@ -41,12 +41,10 @@ class SeleniumWrapper:
|
|||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
) -> None:
|
||||
if browser_type is None:
|
||||
browser_type = CONFIG.selenium_browser_type
|
||||
self.browser_type = browser_type
|
||||
launch_kwargs = launch_kwargs or {}
|
||||
if CONFIG.global_proxy and "proxy-server" not in launch_kwargs:
|
||||
launch_kwargs["proxy-server"] = CONFIG.global_proxy
|
||||
if config.proxy and "proxy-server" not in launch_kwargs:
|
||||
launch_kwargs["proxy-server"] = config.proxy
|
||||
|
||||
self.executable_path = launch_kwargs.pop("executable_path", None)
|
||||
self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()]
|
||||
|
|
@ -97,8 +95,8 @@ _webdriver_manager_types = {
|
|||
|
||||
class WDMHttpProxyClient(WDMHttpClient):
|
||||
def get(self, url, **kwargs):
|
||||
if "proxies" not in kwargs and CONFIG.global_proxy:
|
||||
kwargs["proxies"] = {"all_proxy": CONFIG.global_proxy}
|
||||
if "proxies" not in kwargs and config.proxy:
|
||||
kwargs["proxies"] = {"all_proxy": config.proxy}
|
||||
return super().get(url, **kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue