diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 529c9b16d..5c8970d10 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -6,16 +6,20 @@ @File : base_llm.py @Desc : mashenquan, 2023/8/22. + try catch """ +from __future__ import annotations + import json from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import Dict, Optional, Union from openai import AsyncOpenAI +from openai.types import CompletionUsage from metagpt.configs.llm_config import LLMConfig from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.cost_manager import CostManager +from metagpt.utils.exceptions import handle_exception class BaseLLM(ABC): @@ -29,6 +33,7 @@ class BaseLLM(ABC): aclient: Optional[Union[AsyncOpenAI]] = None cost_manager: Optional[CostManager] = None model: Optional[str] = None + pricing_plan: Optional[str] = None @abstractmethod def __init__(self, config: LLMConfig): @@ -149,3 +154,30 @@ class BaseLLM(ABC): {'language': 'python', 'code': "print('Hello, World!')"} """ return json.loads(self.get_choice_function(rsp)["arguments"]) + + @handle_exception + def _update_costs(self, usage: CompletionUsage | Dict): + """ + Updates the costs based on the provided usage information. + + Args: + usage (Union[CompletionUsage, Dict]): The usage information used to calculate and update costs. + It can be either an instance of CompletionUsage or a dictionary. + + Returns: + None: This method does not return any value. + + Raises: + ValueError: If the provided usage is not a valid format. + + Example: + Usage example goes here, demonstrating how to call and utilize this method. + """ + if self.config.calc_usage and usage and self.cost_manager: + if isinstance(usage, Dict): + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + else: + prompt_tokens = usage.prompt_tokens + completion_tokens = usage.completion_tokens + self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index f33a4136b..70847deec 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -9,7 +9,7 @@ from __future__ import annotations import json -from typing import AsyncIterator, Dict, Optional, Union +from typing import AsyncIterator, Optional, Union from openai import APIConnectionError, AsyncOpenAI, AsyncStream from openai._base_client import AsyncHttpxClientWrapper @@ -218,17 +218,6 @@ class OpenAILLM(BaseLLM): return usage - @handle_exception - def _update_costs(self, usage: CompletionUsage | Dict): - if self.config.calc_usage and usage and self.cost_manager: - if isinstance(usage, Dict): - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - else: - prompt_tokens = usage.prompt_tokens - completion_tokens = usage.completion_tokens - self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.pricing_plan) - def get_costs(self) -> Costs: if not self.cost_manager: return Costs(0, 0, 0, 0) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 3bc130090..d0359d710 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -41,11 +41,11 @@ class ZhiPuAILLM(BaseLLM): def __init__(self, config: LLMConfig): self.__init_zhipuai(config) + self.config = config self.llm = ZhiPuModelAPI self.model = "chatglm_turbo" # so far only one model, just use it self.pricing_plan = self.config.pricing_plan or self.model self.use_system_prompt: bool = False # zhipuai has no system prompt when use api - self.config = config self.cost_manager: Optional[CostManager] = None def __init_zhipuai(self, config: LLMConfig):