support new openai package

This commit is contained in:
seehi 2023-12-05 15:27:57 +08:00
parent eaf531e0ac
commit 09134c9c72
8 changed files with 73 additions and 15 deletions

View file

@ -26,6 +26,7 @@ from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.schema import Message
from metagpt.utils.common import ensure_trailing_slash
from metagpt.utils.singleton import Singleton
from metagpt.utils.token_counter import (
TOKEN_COSTS,
@ -153,27 +154,37 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self, config: Config):
client_kwargs, async_client_kwargs = self.__make_client_args(config)
client_kwargs, async_client_kwargs = self._make_client_kwargs(config)
self.client = OpenAI(**client_kwargs)
self.async_client = AsyncOpenAI(**async_client_kwargs)
self.rpm = int(config.get("RPM", 10))
def __make_client_args(self, config: Config):
def _make_client_kwargs(self, config: Config) -> (dict, dict):
mapping = {
"api_key": "openai_api_key",
"base_url": "openai_base_url",
}
kwargs = {}
for key, attr in mapping.items():
value = getattr(config, attr, None)
if value:
kwargs[key] = value
if config.openai_base_url:
kwargs["base_url"] = ensure_trailing_slash(config.openai_base_url)
kwargs = {key: getattr(config, mapping[key]) for key in mapping if getattr(config, mapping[key], None)}
async_kwargs = kwargs.copy()
# need http_client to support proxy
# Create http_client if proxy is specified
if config.openai_proxy:
httpx_args = dict(base_url=kwargs["base_url"], proxies=config.openai_proxy)
kwargs["http_client"] = httpx.Client(**httpx_args)
async_kwargs["http_client"] = httpx.AsyncClient(**httpx_args)
params = {"proxies": config.openai_proxy}
if config.openai_base_url:
params["base_url"] = config.openai_base_url
kwargs["http_client"] = httpx.Client(**params)
async_kwargs["http_client"] = httpx.AsyncClient(**params)
return kwargs, async_kwargs

View file

@ -305,3 +305,9 @@ def parse_recipient(text):
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
recipient = re.search(pattern, text)
return recipient.group(1) if recipient else ""
def ensure_trailing_slash(url):
if not url:
return url
return url if url.endswith("/") else url + "/"