azure client

This commit is contained in:
seehi 2023-12-06 11:58:13 +08:00
parent f03a6d8029
commit a617aab65b
3 changed files with 98 additions and 67 deletions

View file

@ -10,7 +10,14 @@ import time
from typing import NamedTuple, Union
import httpx
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
from openai import (
APIConnectionError,
AsyncAzureOpenAI,
AsyncOpenAI,
AsyncStream,
AzureOpenAI,
OpenAI,
)
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from tenacity import (
@ -26,7 +33,6 @@ from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.schema import Message
from metagpt.utils.common import ensure_trailing_slash
from metagpt.utils.singleton import Singleton
from metagpt.utils.token_counter import (
TOKEN_COSTS,
@ -154,40 +160,49 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self, config: Config):
client_kwargs, async_client_kwargs = self._make_client_kwargs(config)
self.client = OpenAI(**client_kwargs)
self.async_client = AsyncOpenAI(**async_client_kwargs)
self._make_client(config)
self.rpm = int(config.get("RPM", 10))
def _make_client_kwargs(self, config: Config) -> (dict, dict):
mapping = {
"api_key": "openai_api_key",
"base_url": "openai_base_url",
}
kwargs = {}
for key, attr in mapping.items():
value = getattr(config, attr, None)
if value:
kwargs[key] = value
def _make_client(self, config: Config):
kwargs, async_kwargs = self._make_client_kwargs(config)
# OpenAI v1 requires the base_url to end with /
if config.openai_base_url:
kwargs["base_url"] = ensure_trailing_slash(config.openai_base_url)
if self._is_azure(config):
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
else:
self.client = OpenAI(**kwargs)
self.async_client = AsyncOpenAI(**async_kwargs)
def _make_client_kwargs(self, config: Config) -> (dict, dict):
if self._is_azure(config):
kwargs = dict(
api_key=config.openai_api_key,
api_version=config.openai_api_version,
azure_endpoint=config.openai_base_url,
)
else:
kwargs = dict(api_key=config.openai_api_key, base_url=config.openai_base_url)
async_kwargs = kwargs.copy()
# Create http_client if proxy is specified
# to use proxy, openai v1 needs http_client
proxy_params = self._get_proxy_params(config)
if proxy_params:
kwargs["http_client"] = httpx.Client(**proxy_params)
async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params)
return kwargs, async_kwargs
def _is_azure(self, config: Config) -> bool:
return config.openai_api_type == "azure"
def _get_proxy_params(self, config: Config) -> dict:
params = {}
if config.openai_proxy:
params = {"proxies": config.openai_proxy}
if config.openai_base_url:
params["base_url"] = config.openai_base_url
kwargs["http_client"] = httpx.Client(**params)
async_kwargs["http_client"] = httpx.AsyncClient(**params)
return kwargs, async_kwargs
return params
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
@ -230,9 +245,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
elif not CONFIG.deployment_name and not CONFIG.deployment_id:
raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter")
kwargs_mode = (
{"engine": CONFIG.deployment_name}
if CONFIG.deployment_name
else {"deployment_id": CONFIG.deployment_id}
{"model": CONFIG.deployment_name} if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id}
)
else:
kwargs_mode = {"model": self.model}

View file

@ -305,9 +305,3 @@ def parse_recipient(text):
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
recipient = re.search(pattern, text)
return recipient.group(1) if recipient else ""
def ensure_trailing_slash(url):
if not url:
return url
return url if url.endswith("/") else url + "/"