upgrade tiktoken to support azure

This commit is contained in:
seehi 2023-12-06 16:23:43 +08:00
parent ad347e0717
commit f4505d0e39

View file

@ -159,21 +159,16 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
@property
def model(self):
if self._is_azure():
return self.config.deployment_name
return self.config.openai_api_model
def __init_openai(self):
self._make_client()
self.is_azure = self.config.openai_api_type == "azure"
self.model = self.config.deployment_name if self.is_azure else self.config.openai_api_model
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
if self._is_azure():
if self.is_azure:
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
else:
@ -181,7 +176,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
self.async_client = AsyncOpenAI(**async_kwargs)
def _make_client_kwargs(self) -> (dict, dict):
if self._is_azure():
if self.is_azure:
kwargs = dict(
api_key=self.config.openai_api_key,
api_version=self.config.openai_api_version,
@ -200,9 +195,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return kwargs, async_kwargs
def _is_azure(self) -> bool:
return self.config.openai_api_type == "azure"
def _get_proxy_params(self) -> dict:
params = {}
if self.config.openai_proxy: