feat: Unified Cost Calculation Logic.

This commit is contained in:
莘权 马 2024-02-02 15:54:11 +08:00
parent 28727a03d8
commit e8d2819031
11 changed files with 30 additions and 106 deletions

View file

@ -42,6 +42,7 @@ class LLMConfig(YamlModel):
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
pricing_plan: Optional[str] = None # Cost Settlement Plan Parameters.
# For Spark(Xunfei), maybe remove later
app_id: Optional[str] = None

View file

@ -8,14 +8,10 @@
"""
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper
from openai.types import CompletionUsage
from metagpt.configs.llm_config import LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils import TOKEN_COSTS, count_message_tokens, count_string_tokens
from metagpt.utils.exceptions import handle_exception
@register_provider(LLMType.AZURE)
@ -29,6 +25,7 @@ class AzureOpenAILLM(OpenAILLM):
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
self.pricing_plan = self.config.pricing_plan
def _make_client_kwargs(self) -> dict:
kwargs = dict(
@ -43,38 +40,3 @@ class AzureOpenAILLM(OpenAILLM):
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not self.config.calc_usage:
return usage
model_name = "gpt-35-turbo" if "gpt-3" in self.model.lower() else "gpt-4-turbo-preview"
try:
usage.prompt_tokens = count_message_tokens(messages, model_name)
usage.completion_tokens = count_string_tokens(rsp, model_name)
except Exception as e:
logger.error(f"usage calculation failed: {e}")
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage and self.cost_manager:
model_name = self._get_azure_model()
# More about pricing: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, model_name)
def _get_azure_model(self) -> str:
models = [i.lower() for i in TOKEN_COSTS.keys() if "azure" in i]
mappings = {i: set(i.split("-")) for i in models}
words = self.model.lower().split("-")
weights = []
for k, v in mappings.items():
count = 0
for i in words:
if i in v:
count += 1
weights.append((k, count))
sorted_list = sorted(weights, key=lambda x: x[1], reverse=True)
return sorted_list[0][0]

View file

@ -81,14 +81,6 @@ class FireworksLLM(OpenAILLM):
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self.cost_manager.get_costs()

View file

@ -53,6 +53,7 @@ class GeminiLLM(BaseLLM):
self.__init_gemini(config)
self.config = config
self.model = "gemini-pro" # so far only one model
self.pricing_plan = self.config.pricing_plan or self.model
self.llm = GeminiGenerativeModel(model_name=self.model)
def __init_gemini(self, config: LLMConfig):
@ -70,16 +71,6 @@ class GeminiLLM(BaseLLM):
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text

View file

@ -16,12 +16,9 @@ from metagpt.utils.exceptions import handle_exception
@register_provider(LLMType.METAGPT)
class MetaGPTLLM(OpenAILLM):
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
# The current billing is based on usage frequency. If there is a future billing logic based on the
# number of tokens, please refine the logic here accordingly.
return usage
return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
@handle_exception
def _update_costs(self, usage: CompletionUsage):

View file

@ -36,26 +36,17 @@ class OllamaLLM(BaseLLM):
self.suffix_url = "/chat"
self.http_method = "post"
self.use_system_prompt = False
self._cost_manager = TokenCostManager()
self.cost_manager = TokenCostManager()
def __init_ollama(self, config: LLMConfig):
assert config.base_url, "ollama base url is required!"
self.model = config.model
self.pricing_plan = self.model
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"ollama updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})

View file

@ -16,7 +16,7 @@ from metagpt.utils.token_counter import count_message_tokens, count_string_token
class OpenLLM(OpenAILLM):
def __init__(self, config: LLMConfig):
super().__init__(config)
self._cost_manager = TokenCostManager()
self.cost_manager = TokenCostManager()
def _make_client_kwargs(self) -> dict:
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
@ -35,13 +35,5 @@ class OpenLLM(OpenAILLM):
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
return self.cost_manager.get_costs()

View file

@ -6,9 +6,10 @@
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
from __future__ import annotations
import json
from typing import AsyncIterator, Optional, Union
from typing import AsyncIterator, Dict, Optional, Union
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
@ -61,6 +62,7 @@ class OpenAILLM(BaseLLM):
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
self.pricing_plan = self.config.pricing_plan or self.model
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
@ -209,17 +211,23 @@ class OpenAILLM(BaseLLM):
return usage
try:
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
usage.prompt_tokens = count_message_tokens(messages, self.pricing_plan)
usage.completion_tokens = count_string_tokens(rsp, self.pricing_plan)
except Exception as e:
logger.error(f"usage calculation failed: {e}")
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
def _update_costs(self, usage: CompletionUsage | Dict):
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
if isinstance(usage, Dict):
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
else:
prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
def get_costs(self) -> Costs:
if not self.cost_manager:

View file

@ -3,6 +3,7 @@
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
from enum import Enum
from typing import Optional
import openai
import zhipuai
@ -21,6 +22,7 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
from metagpt.utils.cost_manager import CostManager
class ZhiPuEvent(Enum):
@ -41,8 +43,10 @@ class ZhiPuAILLM(BaseLLM):
self.__init_zhipuai(config)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self.pricing_plan = self.config.pricing_plan or self.model
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
self.config = config
self.cost_manager: Optional[CostManager] = None
def __init_zhipuai(self, config: LLMConfig):
assert config.api_key
@ -57,16 +61,6 @@ class ZhiPuAILLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()

View file

@ -39,7 +39,7 @@ class CostManager(BaseModel):
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
if prompt_tokens + completion_tokens == 0:
if prompt_tokens + completion_tokens == 0 or not model:
return
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens

View file

@ -34,14 +34,10 @@ TOKEN_COSTS = {
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
# Azure
"azure-gpt-3.5-turbo-4k": {"prompt": 0.0015, "completion": 0.002},
"azure-gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004},
"azure-gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002},
"azure-gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
"azure-gpt-4-turbo-vision": {"prompt": 0.01, "completion": 0.03},
"azure-gpt-4-8k": {"prompt": 0.03, "completion": 0.06},
"azure-gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
"gpt-3.5-turbo-4k": {"prompt": 0.0015, "completion": 0.002},
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
"gpt-4-turbo-vision": {"prompt": 0.01, "completion": 0.03},
"gpt-4-8k": {"prompt": 0.03, "completion": 0.06},
}
TOKEN_MAX = {