From e201bf71d912542a4b4541528881583cb28e128a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 28 Aug 2023 22:04:06 +0800 Subject: [PATCH] fixbug: CONFIG initialization --- metagpt/config.py | 17 +++++-- metagpt/provider/openai_api.py | 88 ++++------------------------------ metagpt/roles/role.py | 7 +-- metagpt/software_company.py | 7 +-- metagpt/utils/cost_manager.py | 79 ++++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 89 deletions(-) create mode 100644 metagpt/utils/cost_manager.py diff --git a/metagpt/config.py b/metagpt/config.py index 05949408d..4cae79b17 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -4,6 +4,7 @@ Provide configuration, singleton. @Modified BY: mashenquan, 2023/8/28. Replace the global variable `CONFIG` with `ContextVar`. """ +import json import os from copy import deepcopy from typing import Any @@ -14,6 +15,7 @@ import yaml from metagpt.const import PROJECT_ROOT, OPTIONS from metagpt.logs import logger from metagpt.tools import SearchEngineType, WebBrowserEngineType +from metagpt.utils.cost_manager import CostManager from metagpt.utils.singleton import Singleton @@ -43,12 +45,17 @@ class Config(metaclass=Singleton): def __init__(self, yaml_file=default_yaml_file): self._init_with_config_files_and_env(yaml_file) + self.cost_manager = CostManager(**json.loads(self.COST_MANAGER)) if self.COST_MANAGER else CostManager() + logger.info("Config loading done.") + self._update() + + def _update(self): self.global_proxy = self._get("GLOBAL_PROXY") self.openai_api_key = self._get("OPENAI_API_KEY") self.anthropic_api_key = self._get("Anthropic_API_KEY") if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and ( - not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key + not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key ): logger.warning("Set OPENAI_API_KEY or Anthropic_API_KEY first") self.openai_api_base = self._get("OPENAI_API_BASE") 
@@ -78,8 +85,7 @@ class Config(metaclass=Singleton): self.long_term_memory = self._get("LONG_TERM_MEMORY", False) if self.long_term_memory: logger.warning("LONG_TERM_MEMORY is True") - self.max_budget = self._get("MAX_BUDGET", 10.0) - self.total_cost = 0.0 + self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0) self.puppeteer_config = self._get("PUPPETEER_CONFIG", "") self.mmdc = self._get("MMDC", "mmdc") @@ -109,7 +115,8 @@ class Config(metaclass=Singleton): return m.get(*args, **kwargs) def get(self, key, *args, **kwargs): - """Retrieve values from config/key.yaml, config/config.yaml, and environment variables. Throw an error if not found.""" + """Retrieve values from config/key.yaml, config/config.yaml, and environment variables. + Throw an error if not found.""" value = self._get(key, *args, **kwargs) if value is None: raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file") @@ -127,10 +134,12 @@ class Config(metaclass=Singleton): opts = deepcopy(OPTIONS.get()) opts.update(options) OPTIONS.set(opts) + self._update() @property def options(self): """Return all key-values""" return OPTIONS.get() + CONFIG = Config() diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 7dba00530..e4dfade78 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -11,19 +11,18 @@ import re import time import random -from typing import NamedTuple, List +from typing import List import traceback import openai from openai.error import APIConnectionError -from pydantic import BaseModel from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type from metagpt.config import CONFIG from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.utils.cost_manager import Costs from metagpt.utils.token_counter import ( - TOKEN_COSTS, count_message_tokens, 
count_string_tokens, get_max_completion_tokens, @@ -55,73 +54,6 @@ class RateLimiter: self.last_call_time = time.time() -class Costs(NamedTuple): - total_prompt_tokens: int - total_completion_tokens: int - total_cost: float - total_budget: float - - -class CostManager(BaseModel): - """计算使用接口的开销""" - - total_prompt_tokens: int = 0 - total_completion_tokens: int = 0 - total_budget: float = 0 - max_budget: float = CONFIG.max_budget - total_cost: float = 0 - - def update_cost(self, prompt_tokens, completion_tokens, model): - """ - Update the total cost, prompt tokens, and completion tokens. - - Args: - prompt_tokens (int): The number of tokens used in the prompt. - completion_tokens (int): The number of tokens used in the completion. - model (str): The model used for the API call. - """ - self.total_prompt_tokens += prompt_tokens - self.total_completion_tokens += completion_tokens - cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][ - "completion"]) / 1000 - self.total_cost += cost - logger.info( - f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | " - f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" - ) - - def get_total_prompt_tokens(self): - """ - Get the total number of prompt tokens. - - Returns: - int: The total number of prompt tokens. - """ - return self.total_prompt_tokens - - def get_total_completion_tokens(self): - """ - Get the total number of completion tokens. - - Returns: - int: The total number of completion tokens. - """ - return self.total_completion_tokens - - def get_total_cost(self): - """ - Get the total cost of API calls. - - Returns: - float: The total cost of API calls. 
- """ - return self.total_cost - - def get_costs(self) -> Costs: - """获得所有开销""" - return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) - - def log_and_reraise(retry_state): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning(""" @@ -136,12 +68,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): Check https://platform.openai.com/examples for examples """ - def __init__(self, cost_manager=None): + def __init__(self): self.__init_openai(CONFIG) self.llm = openai self.model = CONFIG.openai_api_model self.auto_max_tokens = False - self._cost_manager = cost_manager or CostManager() RateLimiter.__init__(self, rpm=self.rpm) def __init_openai(self, config): @@ -155,9 +86,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): async def _achat_completion_stream(self, messages: list[dict]) -> str: response = await self.async_retry_call(openai.ChatCompletion.acreate, - **self._cons_kwargs(messages), - stream=True - ) + **self._cons_kwargs(messages), + stream=True + ) # create variables to collect the stream of chunks collected_chunks = [] collected_messages = [] @@ -276,12 +207,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): try: prompt_tokens = int(usage['prompt_tokens']) completion_tokens = int(usage['completion_tokens']) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: logger.error("updating costs failed!", e) def get_costs(self) -> Costs: - return self._cost_manager.get_costs() + return CONFIG.cost_manager.get_costs() def get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: @@ -366,7 +297,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return None, input_string @staticmethod - async def async_retry_call(func, *args, **kwargs): + async def async_retry_call(func, *args, **kwargs): for i in range(OpenAIGPTAPI.MAX_TRY): 
try: rsp = await func(*args, **kwargs) @@ -399,4 +330,3 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): raise openai.error.OpenAIError("Exceeds the maximum retries") MAX_TRY = 5 - diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index a1ac0d9e7..5d2cce802 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -9,13 +9,14 @@ """ from __future__ import annotations -from typing import Iterable, Type, Dict +from typing import Iterable, Type + from pydantic import BaseModel, Field -from metagpt.config import Config, CONFIG +from metagpt.config import CONFIG from metagpt.const import OPTIONS -from metagpt.provider.openai_api import OpenAIGPTAPI as LLM, CostManager +from metagpt.llm import LLM from metagpt.actions import Action, ActionOutput from metagpt.logs import logger from metagpt.memory import Memory, LongTermMemory diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 8d9c990ee..cfa3bd492 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -35,12 +35,13 @@ class SoftwareCompany(BaseModel): def invest(self, investment: float): """Invest company. 
raise NoMoneyException when exceed max_budget.""" self.investment = investment - CONFIG.max_budget = investment + CONFIG.cost_manager.max_budget = investment logger.info(f'Investment: ${investment}.') def _check_balance(self): - if CONFIG.total_cost > CONFIG.max_budget: - raise NoMoneyException(CONFIG.total_cost, f'Insufficient funds: {CONFIG.max_budget}') + if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget: + raise NoMoneyException(CONFIG.cost_manager.total_cost, + f'Insufficient funds: {CONFIG.cost_manager.max_budget}') def start_project(self, idea, role="BOSS", cause_by=BossRequirement, **kwargs): """Start a project from publishing boss requirement.""" diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py new file mode 100644 index 000000000..21b37d552 --- /dev/null +++ b/metagpt/utils/cost_manager.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/28 +@Author : mashenquan +@File : cost_manager.py +@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting. +""" + +from pydantic import BaseModel +from metagpt.logs import logger +from metagpt.utils.token_counter import TOKEN_COSTS +from typing import NamedTuple + + +class Costs(NamedTuple): + total_prompt_tokens: int + total_completion_tokens: int + total_cost: float + total_budget: float + + +class CostManager(BaseModel): + """Calculate the overhead of using the interface.""" + + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + total_budget: float = 0 + max_budget: float = 10.0 + total_cost: float = 0 + + def update_cost(self, prompt_tokens, completion_tokens, model): + """ + Update the total cost, prompt tokens, and completion tokens. + + Args: + prompt_tokens (int): The number of tokens used in the prompt. + completion_tokens (int): The number of tokens used in the completion. + model (str): The model used for the API call. 
+ """ + self.total_prompt_tokens += prompt_tokens + self.total_completion_tokens += completion_tokens + cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][ + "completion"]) / 1000 + self.total_cost += cost + logger.info( + f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | " + f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + ) + + def get_total_prompt_tokens(self): + """ + Get the total number of prompt tokens. + + Returns: + int: The total number of prompt tokens. + """ + return self.total_prompt_tokens + + def get_total_completion_tokens(self): + """ + Get the total number of completion tokens. + + Returns: + int: The total number of completion tokens. + """ + return self.total_completion_tokens + + def get_total_cost(self): + """ + Get the total cost of API calls. + + Returns: + float: The total cost of API calls. + """ + return self.total_cost + + def get_costs(self) -> Costs: + """获得所有开销""" + return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)