upgrade tiktoken to support azure

This commit is contained in:
seehi 2023-12-06 16:06:17 +08:00
parent a617aab65b
commit ad347e0717
6 changed files with 50 additions and 45 deletions

View file

@ -58,8 +58,7 @@ class Config(metaclass=Singleton):
self.openai_api_rpm = self._get("RPM", 3)
self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4")
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
self.deployment_name = self._get("DEPLOYMENT_NAME")
self.deployment_id = self._get("DEPLOYMENT_ID")
self.deployment_name = self._get("DEPLOYMENT_NAME", "gpt-4")
self.spark_appid = self._get("SPARK_APPID")
self.spark_api_secret = self._get("SPARK_API_SECRET")

View file

@ -153,55 +153,63 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
def __init__(self):
self.__init_openai(CONFIG)
self.model = CONFIG.openai_api_model
self.config: Config = CONFIG
self.__init_openai()
self.auto_max_tokens = False
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self, config: Config):
self._make_client(config)
self.rpm = int(config.get("RPM", 10))
@property
def model(self):
if self._is_azure():
return self.config.deployment_name
def _make_client(self, config: Config):
kwargs, async_kwargs = self._make_client_kwargs(config)
return self.config.openai_api_model
if self._is_azure(config):
def __init_openai(self):
self._make_client()
self.rpm = int(self.config.get("RPM", 10))
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
if self._is_azure():
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
else:
self.client = OpenAI(**kwargs)
self.async_client = AsyncOpenAI(**async_kwargs)
def _make_client_kwargs(self, config: Config) -> (dict, dict):
if self._is_azure(config):
def _make_client_kwargs(self) -> (dict, dict):
if self._is_azure():
kwargs = dict(
api_key=config.openai_api_key,
api_version=config.openai_api_version,
azure_endpoint=config.openai_base_url,
api_key=self.config.openai_api_key,
api_version=self.config.openai_api_version,
azure_endpoint=self.config.openai_base_url,
)
else:
kwargs = dict(api_key=config.openai_api_key, base_url=config.openai_base_url)
kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url)
async_kwargs = kwargs.copy()
# to use proxy, openai v1 needs http_client
proxy_params = self._get_proxy_params(config)
proxy_params = self._get_proxy_params()
if proxy_params:
kwargs["http_client"] = httpx.Client(**proxy_params)
async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params)
return kwargs, async_kwargs
def _is_azure(self, config: Config) -> bool:
return config.openai_api_type == "azure"
def _is_azure(self) -> bool:
return self.config.openai_api_type == "azure"
def _get_proxy_params(self, config: Config) -> dict:
def _get_proxy_params(self) -> dict:
params = {}
if config.openai_proxy:
params = {"proxies": config.openai_proxy}
if config.openai_base_url:
params["base_url"] = config.openai_base_url
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
return params
async def _achat_completion_stream(self, messages: list[dict]) -> str:
@ -235,21 +243,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"stop": None,
"temperature": 0.3,
"timeout": 3,
"model": self.model,
}
if configs:
kwargs.update(configs)
if CONFIG.openai_api_type == "azure":
if CONFIG.deployment_name and CONFIG.deployment_id:
raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model")
elif not CONFIG.deployment_name and not CONFIG.deployment_id:
raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter")
kwargs_mode = (
{"model": CONFIG.deployment_name} if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id}
)
else:
kwargs_mode = {"model": self.model}
kwargs.update(kwargs_mode)
return kwargs
async def _achat_completion(self, messages: list[dict]) -> ChatCompletion:
@ -382,7 +380,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
usage.completion_tokens = count_string_tokens(rsp, self.model)
return usage
except Exception as e:
logger.error("usage calculation failed!", e)
logger.error(f"usage calculation failed!: {e}")
else:
return usage

View file

@ -16,13 +16,15 @@ TOKEN_COSTS = {
"gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
"gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004},
"gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004},
"gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002},
"gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004},
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
"gpt-4": {"prompt": 0.03, "completion": 0.06},
"gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
}
@ -32,13 +34,15 @@ TOKEN_MAX = {
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-3.5-turbo-16k-0613": 16384,
"gpt-35-turbo": 4096,
"gpt-35-turbo-16k": 16384,
"gpt-4-0314": 8192,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-4-0613": 8192,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768
"chatglm_turbo": 32768,
}
@ -52,6 +56,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-35-turbo",
"gpt-35-turbo-16k",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",