mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-06 06:12:39 +02:00
feat: Unified Cost Calculation Logic.
This commit is contained in:
parent
28727a03d8
commit
e8d2819031
11 changed files with 30 additions and 106 deletions
|
|
@ -42,6 +42,7 @@ class LLMConfig(YamlModel):
|
|||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
|
||||
pricing_plan: Optional[str] = None # Cost Settlement Plan Parameters.
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
app_id: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -8,14 +8,10 @@
|
|||
"""
|
||||
from openai import AsyncAzureOpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils import TOKEN_COSTS, count_message_tokens, count_string_tokens
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
@register_provider(LLMType.AZURE)
|
||||
|
|
@ -29,6 +25,7 @@ class AzureOpenAILLM(OpenAILLM):
|
|||
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
|
||||
self.aclient = AsyncAzureOpenAI(**kwargs)
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
self.pricing_plan = self.config.pricing_plan
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(
|
||||
|
|
@ -43,38 +40,3 @@ class AzureOpenAILLM(OpenAILLM):
|
|||
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
|
||||
|
||||
return kwargs
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
    """Estimate token usage for a prompt/response pair with tiktoken counters.

    Returns a zeroed ``CompletionUsage`` when usage accounting is disabled
    or when token counting fails.
    """
    usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    if not self.config.calc_usage:
        return usage

    # Azure deployment names differ from tiktoken model names: map any
    # GPT-3.x deployment to "gpt-35-turbo", everything else to the
    # GPT-4 turbo preview counter.
    counting_model = "gpt-4-turbo-preview"
    if "gpt-3" in self.model.lower():
        counting_model = "gpt-35-turbo"

    try:
        usage.prompt_tokens = count_message_tokens(messages, counting_model)
        usage.completion_tokens = count_string_tokens(rsp, counting_model)
    except Exception as e:
        logger.error(f"usage calculation failed: {e}")

    return usage
|
||||
|
||||
@handle_exception
def _update_costs(self, usage: CompletionUsage):
    """Settle this request's token cost against the shared cost manager."""
    if not (self.config.calc_usage and usage and self.cost_manager):
        return
    # More about pricing: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
    azure_model = self._get_azure_model()
    self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, azure_model)
|
||||
|
||||
def _get_azure_model(self) -> str:
    """Pick the ``azure``-prefixed TOKEN_COSTS key that best matches self.model.

    Similarity is the count of '-'-separated words of ``self.model`` that
    also appear in the candidate pricing key; ties keep dict order.
    """
    candidates = [name.lower() for name in TOKEN_COSTS.keys() if "azure" in name]
    model_words = self.model.lower().split("-")

    def overlap(candidate: str) -> int:
        # How many of the model's words appear in this pricing key.
        parts = set(candidate.split("-"))
        return sum(1 for word in model_words if word in parts)

    scored = [(overlap(candidate), candidate) for candidate in candidates]
    # Stable sort: the first candidate with the highest score wins.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[0][1]
|
||||
|
|
|
|||
|
|
@ -81,14 +81,6 @@ class FireworksLLM(OpenAILLM):
|
|||
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
    """Accumulate this request's token cost on the Fireworks-specific manager."""
    if not (self.config.calc_usage and usage):
        return
    try:
        # use FireworksCostManager not context.cost_manager
        self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
    except Exception as e:
        logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
    """Expose the totals accumulated by the provider-local cost manager."""
    return self.cost_manager.get_costs()
|
||||
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ class GeminiLLM(BaseLLM):
|
|||
self.__init_gemini(config)
|
||||
self.config = config
|
||||
self.model = "gemini-pro" # so far only one model
|
||||
self.pricing_plan = self.config.pricing_plan or self.model
|
||||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
|
||||
def __init_gemini(self, config: LLMConfig):
|
||||
|
|
@ -70,16 +71,6 @@ class GeminiLLM(BaseLLM):
|
|||
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
    """Update each request's token cost.

    Args:
        usage: mapping that may contain "prompt_tokens" and
            "completion_tokens" counts; missing keys default to 0.
    """
    if not self.config.calc_usage:
        return
    try:
        prompt_tokens = int(usage.get("prompt_tokens", 0))
        completion_tokens = int(usage.get("completion_tokens", 0))
        self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
    except Exception as e:
        # Fix: "updats" -> "updates" in the log message.
        logger.error(f"google gemini updates costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
    """Return the aggregated text content of a Gemini response."""
    return resp.text
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,9 @@ from metagpt.utils.exceptions import handle_exception
|
|||
@register_provider(LLMType.METAGPT)
|
||||
class MetaGPTLLM(OpenAILLM):
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
    """Return a zeroed usage record.

    The current billing is based on usage frequency. If there is a future
    billing logic based on the number of tokens, please refine the logic
    here accordingly.
    """
    return CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
|
|
|
|||
|
|
@ -36,26 +36,17 @@ class OllamaLLM(BaseLLM):
|
|||
self.suffix_url = "/chat"
|
||||
self.http_method = "post"
|
||||
self.use_system_prompt = False
|
||||
self._cost_manager = TokenCostManager()
|
||||
self.cost_manager = TokenCostManager()
|
||||
|
||||
def __init_ollama(self, config: LLMConfig):
|
||||
assert config.base_url, "ollama base url is required!"
|
||||
self.model = config.model
|
||||
self.pricing_plan = self.model
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
    """Update each request's token cost.

    Args:
        usage: mapping that may contain "prompt_tokens" and
            "completion_tokens" counts; missing keys default to 0.
    """
    if not self.config.calc_usage:
        return
    try:
        prompt_tokens = int(usage.get("prompt_tokens", 0))
        completion_tokens = int(usage.get("completion_tokens", 0))
        self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
    except Exception as e:
        # Fix: "updats" -> "updates" in the log message.
        logger.error(f"ollama updates costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the resp content from llm response"""
|
||||
assist_msg = resp.get("message", {})
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from metagpt.utils.token_counter import count_message_tokens, count_string_token
|
|||
class OpenLLM(OpenAILLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
super().__init__(config)
|
||||
self._cost_manager = TokenCostManager()
|
||||
self.cost_manager = TokenCostManager()
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
|
||||
|
|
@ -35,13 +35,5 @@ class OpenLLM(OpenAILLM):
|
|||
|
||||
return usage
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
    """Accumulate this request's token cost on the provider-local manager."""
    if not (self.config.calc_usage and usage):
        return
    try:
        # use OpenLLMCostManager not CONFIG.cost_manager
        self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
    except Exception as e:
        logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
return self.cost_manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@
|
|||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import AsyncIterator, Optional, Union
|
||||
from typing import AsyncIterator, Dict, Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
|
@ -61,6 +62,7 @@ class OpenAILLM(BaseLLM):
|
|||
|
||||
def _init_model(self):
|
||||
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
|
||||
self.pricing_plan = self.config.pricing_plan or self.model
|
||||
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
|
|
@ -209,17 +211,23 @@ class OpenAILLM(BaseLLM):
|
|||
return usage
|
||||
|
||||
try:
|
||||
usage.prompt_tokens = count_message_tokens(messages, self.model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
usage.prompt_tokens = count_message_tokens(messages, self.pricing_plan)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.pricing_plan)
|
||||
except Exception as e:
|
||||
logger.error(f"usage calculation failed: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
@handle_exception
def _update_costs(self, usage: CompletionUsage | Dict):
    """Settle the token cost of one request against the pricing plan.

    Accepts either an OpenAI ``CompletionUsage`` object or a plain dict
    with "prompt_tokens"/"completion_tokens" keys (missing keys count as 0).
    Does nothing when usage accounting is disabled, usage is empty, or no
    cost manager is configured.
    """
    if not (self.config.calc_usage and usage and self.cost_manager):
        return
    # Fix: isinstance against typing.Dict is deprecated; plain dict is equivalent.
    if isinstance(usage, dict):
        prompt_tokens = int(usage.get("prompt_tokens", 0))
        completion_tokens = int(usage.get("completion_tokens", 0))
    else:
        prompt_tokens = usage.prompt_tokens
        completion_tokens = usage.completion_tokens
    self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
if not self.cost_manager:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
import zhipuai
|
||||
|
|
@ -21,6 +22,7 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
||||
class ZhiPuEvent(Enum):
|
||||
|
|
@ -41,8 +43,10 @@ class ZhiPuAILLM(BaseLLM):
|
|||
self.__init_zhipuai(config)
|
||||
self.llm = ZhiPuModelAPI
|
||||
self.model = "chatglm_turbo" # so far only one model, just use it
|
||||
self.pricing_plan = self.config.pricing_plan or self.model
|
||||
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
|
||||
self.config = config
|
||||
self.cost_manager: Optional[CostManager] = None
|
||||
|
||||
def __init_zhipuai(self, config: LLMConfig):
|
||||
assert config.api_key
|
||||
|
|
@ -57,16 +61,6 @@ class ZhiPuAILLM(BaseLLM):
|
|||
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
    """Update each request's token cost.

    Fix: settle costs on ``self.cost_manager`` (initialized on the LLM
    instance), not on ``self.config`` — LLMConfig carries no cost manager,
    so the previous attribute access always raised and was swallowed by
    the except clause.

    Args:
        usage: mapping that may contain "prompt_tokens" and
            "completion_tokens" counts; missing keys default to 0.
    """
    if not self.config.calc_usage:
        return
    try:
        prompt_tokens = int(usage.get("prompt_tokens", 0))
        completion_tokens = int(usage.get("completion_tokens", 0))
        self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
    except Exception as e:
        # Fix: "updats" -> "updates" in the log message.
        logger.error(f"zhipuai updates costs failed! exp: {e}")
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
|
||||
usage = resp.usage.model_dump()
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class CostManager(BaseModel):
|
|||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
if prompt_tokens + completion_tokens == 0:
|
||||
if prompt_tokens + completion_tokens == 0 or not model:
|
||||
return
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
|
|
|
|||
|
|
@ -34,14 +34,10 @@ TOKEN_COSTS = {
|
|||
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
|
||||
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
|
||||
# Azure
|
||||
"azure-gpt-3.5-turbo-4k": {"prompt": 0.0015, "completion": 0.002},
|
||||
"azure-gpt-3.5-turbo-16k": {"prompt": 0.003, "completion": 0.004},
|
||||
"azure-gpt-3.5-turbo-1106": {"prompt": 0.001, "completion": 0.002},
|
||||
"azure-gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"azure-gpt-4-turbo-vision": {"prompt": 0.01, "completion": 0.03},
|
||||
"azure-gpt-4-8k": {"prompt": 0.03, "completion": 0.06},
|
||||
"azure-gpt-4-32k": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-3.5-turbo-4k": {"prompt": 0.0015, "completion": 0.002},
|
||||
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo-vision": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-8k": {"prompt": 0.03, "completion": 0.06},
|
||||
}
|
||||
|
||||
TOKEN_MAX = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue