mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
azure client
This commit is contained in:
parent
f03a6d8029
commit
a617aab65b
3 changed files with 98 additions and 67 deletions
|
|
@ -10,7 +10,14 @@ import time
|
|||
from typing import NamedTuple, Union
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
AsyncAzureOpenAI,
|
||||
AsyncOpenAI,
|
||||
AsyncStream,
|
||||
AzureOpenAI,
|
||||
OpenAI,
|
||||
)
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from tenacity import (
|
||||
|
|
@ -26,7 +33,6 @@ from metagpt.logs import logger
|
|||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import ensure_trailing_slash
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.utils.token_counter import (
|
||||
TOKEN_COSTS,
|
||||
|
|
@ -154,40 +160,49 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openai(self, config: Config):
|
||||
client_kwargs, async_client_kwargs = self._make_client_kwargs(config)
|
||||
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
self.async_client = AsyncOpenAI(**async_client_kwargs)
|
||||
|
||||
self._make_client(config)
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
|
||||
def _make_client_kwargs(self, config: Config) -> (dict, dict):
|
||||
mapping = {
|
||||
"api_key": "openai_api_key",
|
||||
"base_url": "openai_base_url",
|
||||
}
|
||||
kwargs = {}
|
||||
for key, attr in mapping.items():
|
||||
value = getattr(config, attr, None)
|
||||
if value:
|
||||
kwargs[key] = value
|
||||
def _make_client(self, config: Config):
|
||||
kwargs, async_kwargs = self._make_client_kwargs(config)
|
||||
|
||||
# OpenAI v1 requires the base_url to end with /
|
||||
if config.openai_base_url:
|
||||
kwargs["base_url"] = ensure_trailing_slash(config.openai_base_url)
|
||||
if self._is_azure(config):
|
||||
self.client = AzureOpenAI(**kwargs)
|
||||
self.async_client = AsyncAzureOpenAI(**async_kwargs)
|
||||
else:
|
||||
self.client = OpenAI(**kwargs)
|
||||
self.async_client = AsyncOpenAI(**async_kwargs)
|
||||
|
||||
def _make_client_kwargs(self, config: Config) -> (dict, dict):
|
||||
if self._is_azure(config):
|
||||
kwargs = dict(
|
||||
api_key=config.openai_api_key,
|
||||
api_version=config.openai_api_version,
|
||||
azure_endpoint=config.openai_base_url,
|
||||
)
|
||||
else:
|
||||
kwargs = dict(api_key=config.openai_api_key, base_url=config.openai_base_url)
|
||||
|
||||
async_kwargs = kwargs.copy()
|
||||
|
||||
# Create http_client if proxy is specified
|
||||
# to use proxy, openai v1 needs http_client
|
||||
proxy_params = self._get_proxy_params(config)
|
||||
if proxy_params:
|
||||
kwargs["http_client"] = httpx.Client(**proxy_params)
|
||||
async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params)
|
||||
|
||||
return kwargs, async_kwargs
|
||||
|
||||
def _is_azure(self, config: Config) -> bool:
|
||||
return config.openai_api_type == "azure"
|
||||
|
||||
def _get_proxy_params(self, config: Config) -> dict:
|
||||
params = {}
|
||||
if config.openai_proxy:
|
||||
params = {"proxies": config.openai_proxy}
|
||||
if config.openai_base_url:
|
||||
params["base_url"] = config.openai_base_url
|
||||
|
||||
kwargs["http_client"] = httpx.Client(**params)
|
||||
async_kwargs["http_client"] = httpx.AsyncClient(**params)
|
||||
|
||||
return kwargs, async_kwargs
|
||||
return params
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
|
||||
|
|
@ -230,9 +245,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
elif not CONFIG.deployment_name and not CONFIG.deployment_id:
|
||||
raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter")
|
||||
kwargs_mode = (
|
||||
{"engine": CONFIG.deployment_name}
|
||||
if CONFIG.deployment_name
|
||||
else {"deployment_id": CONFIG.deployment_id}
|
||||
{"model": CONFIG.deployment_name} if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id}
|
||||
)
|
||||
else:
|
||||
kwargs_mode = {"model": self.model}
|
||||
|
|
|
|||
|
|
@ -305,9 +305,3 @@ def parse_recipient(text):
|
|||
pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now
|
||||
recipient = re.search(pattern, text)
|
||||
return recipient.group(1) if recipient else ""
|
||||
|
||||
|
||||
def ensure_trailing_slash(url):
|
||||
if not url:
|
||||
return url
|
||||
return url if url.endswith("/") else url + "/"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue