simplify _update_costs and related code

This commit is contained in:
better629 2024-02-07 16:16:14 +08:00
parent d180d3912e
commit 15a9c5e941
8 changed files with 19 additions and 77 deletions

View file

@ -11,11 +11,12 @@ from abc import ABC, abstractmethod
from typing import Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs
class BaseLLM(ABC):
@ -67,14 +68,15 @@ class BaseLLM(ABC):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def _update_costs(self, usage: dict, model: str = None, local_calc_usage: bool = True):
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
"""update each request's token cost
Args:
model (str): model name or in some scenarios called endpoint
local_calc_usage (bool): some models don't calculate usage, it will overwrite calc_usage
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
"""
calc_usage = self.config.calc_usage and local_calc_usage
model = model if model else self.model
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
if calc_usage and self.cost_manager:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
@ -83,6 +85,11 @@ class BaseLLM(ABC):
except Exception as e:
logger.error(f"{self.__class__.__name__} updats costs failed! exp: {e}")
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
async def aask(
self,
msg: str,

View file

@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM):
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
@ -113,7 +102,7 @@ class FireworksLLM(OpenAILLM):
usage = CompletionUsage(**chunk.usage)
full_content = "".join(collected_content)
self._update_costs(usage)
self._update_costs(usage.model_dump())
return full_content
@retry(

View file

@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM):
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text

View file

@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"ollama updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})

View file

@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.cost_manager import TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM):
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -29,7 +29,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -55,16 +55,13 @@ class OpenAILLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
@ -240,16 +237,6 @@ class OpenAILLM(BaseLLM):
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return self.config.max_token

View file

@ -78,7 +78,7 @@ class QianFanLLM(BaseLLM):
# self deployed model on the cloud not to calculate usage, it charges resource pool rental fee
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
self.client = qianfan.ChatCompletion()
self.aclient = qianfan.ChatCompletion()
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
@ -110,12 +110,12 @@ class QianFanLLM(BaseLLM):
return resp.get("result", "")
def completion(self, messages: list[dict]) -> JsonBody:
resp = self.client.do(**self._const_kwargs(messages=messages, stream=False))
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def _achat_completion(self, messages: list[dict]) -> JsonBody:
resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=False))
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
@ -123,7 +123,7 @@ class QianFanLLM(BaseLLM):
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
resp = await self.client.ado(**self._const_kwargs(messages=messages, stream=True))
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
collected_content = []
usage = {}
async for chunk in resp:

View file

@ -57,16 +57,6 @@ class ZhiPuAILLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()