From 15a9c5e94135992e9854a57d14d581040879386f Mon Sep 17 00:00:00 2001 From: better629 Date: Wed, 7 Feb 2024 16:16:14 +0800 Subject: [PATCH] simplify _update_costs and related code --- metagpt/provider/base_llm.py | 13 ++++++++++--- metagpt/provider/fireworks_api.py | 15 ++------------- metagpt/provider/google_gemini_api.py | 10 ---------- metagpt/provider/ollama_api.py | 10 ---------- metagpt/provider/open_llm_api.py | 13 +------------ metagpt/provider/openai_api.py | 17 ++--------------- metagpt/provider/qianfan_api.py | 8 ++++---- metagpt/provider/zhipuai_api.py | 10 ---------- 8 files changed, 19 insertions(+), 77 deletions(-) diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index d3d9c829b..2f57b15aa 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -11,11 +11,12 @@ from abc import ABC, abstractmethod from typing import Optional, Union from openai import AsyncOpenAI +from pydantic import BaseModel from metagpt.configs.llm_config import LLMConfig from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.cost_manager import CostManager +from metagpt.utils.cost_manager import CostManager, Costs class BaseLLM(ABC): @@ -67,14 +68,15 @@ class BaseLLM(ABC): def _default_system_msg(self): return self._system_msg(self.system_prompt) - def _update_costs(self, usage: dict, model: str = None, local_calc_usage: bool = True): + def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True): """update each request's token cost Args: model (str): model name or in some scenarios called endpoint - local_calc_usage (bool): some models don't calculate usage, it will overwrite calc_usage + local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage """ calc_usage = self.config.calc_usage and local_calc_usage model = model if model else self.model + usage = usage.model_dump() if isinstance(usage, BaseModel) else usage if calc_usage and self.cost_manager: try: prompt_tokens = int(usage.get("prompt_tokens", 0)) @@ -83,6 +85,11 @@ class BaseLLM(ABC): except Exception as e: logger.error(f"{self.__class__.__name__} updats costs failed! exp: {e}") + def get_costs(self) -> Costs: + if not self.cost_manager: + return Costs(0, 0, 0, 0) + return self.cost_manager.get_costs() + async def aask( self, msg: str, diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index d56453a85..e62a7066e 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM, log_and_reraise -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import CostManager MODEL_GRADE_TOKEN_COSTS = { "-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition @@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM): kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url) return kwargs - def _update_costs(self, usage: CompletionUsage): - if self.config.calc_usage and usage: - try: - # use FireworksCostManager not context.cost_manager - self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - except Exception as e: - logger.error(f"updating costs failed!, exp: {e}") - - def get_costs(self) -> Costs: - return self.cost_manager.get_costs() - async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages), stream=True @@ -113,7 +102,7 @@ class FireworksLLM(OpenAILLM): usage = CompletionUsage(**chunk.usage) full_content = "".join(collected_content) - self._update_costs(usage) + self._update_costs(usage.model_dump()) return full_content @retry( diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 2647ab16b..87ea81c80 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM): kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"google gemini updats costs failed! exp: {e}") - def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index c9103b018..52e8dbe36 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"ollama updats costs failed! exp: {e}") - def get_choice_text(self, resp: dict) -> str: """get the resp content from llm response""" assist_msg = resp.get("message", {}) diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index a29b263a4..69371e379 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM -from metagpt.utils.cost_manager import Costs, TokenCostManager +from metagpt.utils.cost_manager import TokenCostManager from metagpt.utils.token_counter import count_message_tokens, count_string_tokens @@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM): logger.error(f"usage calculation failed!: {e}") return usage - - def _update_costs(self, usage: CompletionUsage): - if self.config.calc_usage and usage: - try: - # use OpenLLMCostManager not CONFIG.cost_manager - self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - except Exception as e: - logger.error(f"updating costs failed!, exp: {e}") - - def get_costs(self) -> Costs: - return self._cost_manager.get_costs() diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 63e68c9bd..1e5770d74 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -29,7 +29,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message from metagpt.utils.common import CodeParser, decode_image -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import CostManager from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter import ( count_message_tokens, @@ -55,16 +55,13 @@ class OpenAILLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config - self._init_model() self._init_client() self.auto_max_tokens = False self.cost_manager: Optional[CostManager] = None - def _init_model(self): - self.model = self.config.model # Used in _calc_usage & _cons_kwargs - def _init_client(self): """https://github.com/openai/openai-python#async-usage""" + self.model = self.config.model # Used in _calc_usage & _cons_kwargs kwargs = self._make_client_kwargs() self.aclient = AsyncOpenAI(**kwargs) @@ -240,16 +237,6 @@ class OpenAILLM(BaseLLM): return usage - @handle_exception - def _update_costs(self, usage: CompletionUsage): - if self.config.calc_usage and usage and self.cost_manager: - self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - - def get_costs(self) -> Costs: - if not self.cost_manager: - return Costs(0, 0, 0, 0) - return self.cost_manager.get_costs() - def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return self.config.max_token diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index 180935e61..fbbff7085 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -78,7 +78,7 @@ class QianFanLLM(BaseLLM): # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee self.calc_usage = self.config.calc_usage and self.config.endpoint is None - self.client = qianfan.ChatCompletion() + self.aclient = qianfan.ChatCompletion() def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = { @@ -110,12 +110,12 @@ class QianFanLLM(BaseLLM): return resp.get("result", "") def completion(self, messages: list[dict]) -> JsonBody: - resp = self.client.do(**self._const_kwargs(messages=messages, stream=False)) + resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False)) self._update_costs(resp.body.get("usage", {})) return resp.body async def _achat_completion(self, messages: list[dict]) -> JsonBody: - resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=False)) + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False)) self._update_costs(resp.body.get("usage", {})) return resp.body @@ -123,7 +123,7 @@ class QianFanLLM(BaseLLM): return await self._achat_completion(messages) async def _achat_completion_stream(self, messages: list[dict]) -> str: - resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=True)) + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True)) collected_content = [] usage = {} async for chunk in resp: diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 9108a1fba..b7c160a41 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -57,16 +57,6 @@ class ZhiPuAILLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"zhipuai updats costs failed! exp: {e}") - def completion(self, messages: list[dict], timeout=3) -> dict: resp = self.llm.chat.completions.create(**self._const_kwargs(messages)) usage = resp.usage.model_dump()