mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
upgrade tiktoken to support azure
This commit is contained in:
parent
a617aab65b
commit
ad347e0717
6 changed files with 50 additions and 45 deletions
|
|
@ -58,8 +58,7 @@ class Config(metaclass=Singleton):
|
|||
self.openai_api_rpm = self._get("RPM", 3)
|
||||
self.openai_api_model = self._get("OPENAI_API_MODEL", "gpt-4")
|
||||
self.max_tokens_rsp = self._get("MAX_TOKENS", 2048)
|
||||
self.deployment_name = self._get("DEPLOYMENT_NAME")
|
||||
self.deployment_id = self._get("DEPLOYMENT_ID")
|
||||
self.deployment_name = self._get("DEPLOYMENT_NAME", "gpt-4")
|
||||
|
||||
self.spark_appid = self._get("SPARK_APPID")
|
||||
self.spark_api_secret = self._get("SPARK_API_SECRET")
|
||||
|
|
|
|||
|
|
@ -153,55 +153,63 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openai()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openai(self, config: Config):
|
||||
self._make_client(config)
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
@property
|
||||
def model(self):
|
||||
if self._is_azure():
|
||||
return self.config.deployment_name
|
||||
|
||||
def _make_client(self, config: Config):
|
||||
kwargs, async_kwargs = self._make_client_kwargs(config)
|
||||
return self.config.openai_api_model
|
||||
|
||||
if self._is_azure(config):
|
||||
def __init_openai(self):
|
||||
self._make_client()
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
|
||||
def _make_client(self):
|
||||
kwargs, async_kwargs = self._make_client_kwargs()
|
||||
|
||||
if self._is_azure():
|
||||
self.client = AzureOpenAI(**kwargs)
|
||||
self.async_client = AsyncAzureOpenAI(**async_kwargs)
|
||||
else:
|
||||
self.client = OpenAI(**kwargs)
|
||||
self.async_client = AsyncOpenAI(**async_kwargs)
|
||||
|
||||
def _make_client_kwargs(self, config: Config) -> (dict, dict):
|
||||
if self._is_azure(config):
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
if self._is_azure():
|
||||
kwargs = dict(
|
||||
api_key=config.openai_api_key,
|
||||
api_version=config.openai_api_version,
|
||||
azure_endpoint=config.openai_base_url,
|
||||
api_key=self.config.openai_api_key,
|
||||
api_version=self.config.openai_api_version,
|
||||
azure_endpoint=self.config.openai_base_url,
|
||||
)
|
||||
else:
|
||||
kwargs = dict(api_key=config.openai_api_key, base_url=config.openai_base_url)
|
||||
kwargs = dict(api_key=self.config.openai_api_key, base_url=self.config.openai_base_url)
|
||||
|
||||
async_kwargs = kwargs.copy()
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
proxy_params = self._get_proxy_params(config)
|
||||
proxy_params = self._get_proxy_params()
|
||||
if proxy_params:
|
||||
kwargs["http_client"] = httpx.Client(**proxy_params)
|
||||
async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params)
|
||||
|
||||
return kwargs, async_kwargs
|
||||
|
||||
def _is_azure(self, config: Config) -> bool:
|
||||
return config.openai_api_type == "azure"
|
||||
def _is_azure(self) -> bool:
|
||||
return self.config.openai_api_type == "azure"
|
||||
|
||||
def _get_proxy_params(self, config: Config) -> dict:
|
||||
def _get_proxy_params(self) -> dict:
|
||||
params = {}
|
||||
if config.openai_proxy:
|
||||
params = {"proxies": config.openai_proxy}
|
||||
if config.openai_base_url:
|
||||
params["base_url"] = config.openai_base_url
|
||||
if self.config.openai_proxy:
|
||||
params = {"proxies": self.config.openai_proxy}
|
||||
if self.config.openai_base_url:
|
||||
params["base_url"] = self.config.openai_base_url
|
||||
|
||||
return params
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
|
|
@ -235,21 +243,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"stop": None,
|
||||
"temperature": 0.3,
|
||||
"timeout": 3,
|
||||
"model": self.model,
|
||||
}
|
||||
if configs:
|
||||
kwargs.update(configs)
|
||||
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
if CONFIG.deployment_name and CONFIG.deployment_id:
|
||||
raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model")
|
||||
elif not CONFIG.deployment_name and not CONFIG.deployment_id:
|
||||
raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter")
|
||||
kwargs_mode = (
|
||||
{"model": CONFIG.deployment_name} if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id}
|
||||
)
|
||||
else:
|
||||
kwargs_mode = {"model": self.model}
|
||||
kwargs.update(kwargs_mode)
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> ChatCompletion:
|
||||
|
|
@ -382,7 +380,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
return usage
|
||||
except Exception as e:
|
||||
logger.error("usage calculation failed!", e)
|
||||
logger.error(f"usage calculation failed!: {e}")
|
||||
else:
|
||||
return usage
|
||||
|
||||
|
|
|
|||
|
|
@ -16,13 +16,15 @@ TOKEN_COSTS = {
|
|||
"gpt-3.5-turbo-0613": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004},
|
||||
"gpt-3.5-turbo-16k-0613": {"prompt": 0.003, "completion": 0.004},
|
||||
"gpt-35-turbo": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-35-turbo-16k": {"prompt": 0.003, "completion": 0.004},
|
||||
"gpt-4-0314": {"prompt": 0.03, "completion": 0.06},
|
||||
"gpt-4": {"prompt": 0.03, "completion": 0.06},
|
||||
"gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -32,13 +34,15 @@ TOKEN_MAX = {
|
|||
"gpt-3.5-turbo-0613": 4096,
|
||||
"gpt-3.5-turbo-16k": 16384,
|
||||
"gpt-3.5-turbo-16k-0613": 16384,
|
||||
"gpt-35-turbo": 4096,
|
||||
"gpt-35-turbo-16k": 16384,
|
||||
"gpt-4-0314": 8192,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0314": 32768,
|
||||
"gpt-4-0613": 8192,
|
||||
"text-embedding-ada-002": 8192,
|
||||
"chatglm_turbo": 32768
|
||||
"chatglm_turbo": 32768,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -52,6 +56,8 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
if model in {
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-35-turbo",
|
||||
"gpt-35-turbo-16k",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue