feat: Unified Cost Calculation Logic.

This commit is contained in:
莘权 马 2024-02-02 16:26:55 +08:00
parent e8d2819031
commit fa622c2f97
3 changed files with 35 additions and 14 deletions

View file

@@ -6,16 +6,20 @@
@File : base_llm.py
@Desc : mashenquan, 2023/8/22. + try catch
"""
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Dict, Optional, Union
from openai import AsyncOpenAI
from openai.types import CompletionUsage
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
class BaseLLM(ABC):
@@ -29,6 +33,7 @@ class BaseLLM(ABC):
aclient: Optional[Union[AsyncOpenAI]] = None
cost_manager: Optional[CostManager] = None
model: Optional[str] = None
pricing_plan: Optional[str] = None
@abstractmethod
def __init__(self, config: LLMConfig):
@@ -149,3 +154,30 @@ class BaseLLM(ABC):
{'language': 'python', 'code': "print('Hello, World!')"}
"""
return json.loads(self.get_choice_function(rsp)["arguments"])
@handle_exception
def _update_costs(self, usage: CompletionUsage | dict):
    """Accumulate token costs from a completion's usage record.

    Extracts prompt/completion token counts from ``usage`` and forwards
    them to the attached cost manager under the current pricing plan.

    Args:
        usage: Token-usage info — either an OpenAI ``CompletionUsage``
            object or a dict with ``prompt_tokens`` / ``completion_tokens``
            keys (missing keys default to 0).

    Returns:
        None.
    """
    # No-op when usage tracking is disabled, usage is empty/None,
    # or no cost manager has been attached to this LLM.
    if not (self.config.calc_usage and usage and self.cost_manager):
        return
    # NOTE: isinstance against typing.Dict is deprecated; test against
    # the builtin dict instead.
    if isinstance(usage, dict):
        prompt_tokens = int(usage.get("prompt_tokens", 0))
        completion_tokens = int(usage.get("completion_tokens", 0))
    else:
        prompt_tokens = usage.prompt_tokens
        completion_tokens = usage.completion_tokens
    self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)

View file

@@ -9,7 +9,7 @@
from __future__ import annotations
import json
from typing import AsyncIterator, Dict, Optional, Union
from typing import AsyncIterator, Optional, Union
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
@@ -218,17 +218,6 @@ class OpenAILLM(BaseLLM):
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage | dict):
    """Accumulate token costs from a completion's usage record.

    Args:
        usage: Either an OpenAI ``CompletionUsage`` object or a dict with
            ``prompt_tokens`` / ``completion_tokens`` keys.

    Returns:
        None.
    """
    # Skip when usage tracking is disabled, usage is empty, or no
    # cost manager is attached.
    if not (self.config.calc_usage and usage and self.cost_manager):
        return
    # isinstance against typing.Dict is deprecated — use the builtin dict.
    if isinstance(usage, dict):
        prompt_tokens = int(usage.get("prompt_tokens", 0))
        completion_tokens = int(usage.get("completion_tokens", 0))
    else:
        prompt_tokens = usage.prompt_tokens
        completion_tokens = usage.completion_tokens
    self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan)
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)

View file

@@ -41,11 +41,11 @@ class ZhiPuAILLM(BaseLLM):
def __init__(self, config: LLMConfig):
    """Initialize the ZhiPuAI LLM wrapper from the given config.

    Args:
        config: LLM configuration; ``config.pricing_plan`` (when set)
            overrides the model name as the billing plan key.
    """
    self.__init_zhipuai(config)
    # Fix: the original assigned self.config twice — once here and once
    # after use_system_prompt; the redundant second assignment is removed.
    self.config = config
    self.llm = ZhiPuModelAPI
    self.model = "chatglm_turbo"  # so far only one model, just use it
    self.pricing_plan = self.config.pricing_plan or self.model
    self.use_system_prompt: bool = False  # zhipuai has no system prompt when use api
    self.cost_manager: Optional[CostManager] = None
def __init_zhipuai(self, config: LLMConfig):