fixbug: CONFIG initialization

This commit is contained in:
莘权 马 2023-08-28 22:04:06 +08:00
parent 58369c4e3a
commit e201bf71d9
5 changed files with 109 additions and 89 deletions

View file

@ -4,6 +4,7 @@
Provide configuration, singleton.
@Modified BY: mashenquan, 2023/8/28. Replace the global variable `CONFIG` with `ContextVar`.
"""
import json
import os
from copy import deepcopy
from typing import Any
@ -14,6 +15,7 @@ import yaml
from metagpt.const import PROJECT_ROOT, OPTIONS
from metagpt.logs import logger
from metagpt.tools import SearchEngineType, WebBrowserEngineType
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.singleton import Singleton
@ -43,12 +45,17 @@ class Config(metaclass=Singleton):
def __init__(self, yaml_file=default_yaml_file):
self._init_with_config_files_and_env(yaml_file)
self.cost_manager = CostManager(**json.loads(self.COST_MANAGER)) if self.COST_MANAGER else CostManager()
logger.info("Config loading done.")
self._update()
def _update(self):
self.global_proxy = self._get("GLOBAL_PROXY")
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("Anthropic_API_KEY")
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and (
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
):
logger.warning("Set OPENAI_API_KEY or Anthropic_API_KEY first")
self.openai_api_base = self._get("OPENAI_API_BASE")
@ -78,8 +85,7 @@ class Config(metaclass=Singleton):
self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
if self.long_term_memory:
logger.warning("LONG_TERM_MEMORY is True")
self.max_budget = self._get("MAX_BUDGET", 10.0)
self.total_cost = 0.0
self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
self.mmdc = self._get("MMDC", "mmdc")
@ -109,7 +115,8 @@ class Config(metaclass=Singleton):
return m.get(*args, **kwargs)
def get(self, key, *args, **kwargs):
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables. Throw an error if not found."""
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
Throw an error if not found."""
value = self._get(key, *args, **kwargs)
if value is None:
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")
@ -127,10 +134,12 @@ class Config(metaclass=Singleton):
opts = deepcopy(OPTIONS.get())
opts.update(options)
OPTIONS.set(opts)
self._update()
@property
def options(self):
"""Return all key-values"""
return OPTIONS.get()
CONFIG = Config()

View file

@ -11,19 +11,18 @@ import re
import time
import random
from typing import NamedTuple, List
from typing import List
import traceback
import openai
from openai.error import APIConnectionError
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.utils.cost_manager import Costs
from metagpt.utils.token_counter import (
TOKEN_COSTS,
count_message_tokens,
count_string_tokens,
get_max_completion_tokens,
@ -55,73 +54,6 @@ class RateLimiter:
self.last_call_time = time.time()
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(BaseModel):
"""计算使用接口的开销"""
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
total_budget: float = 0
max_budget: float = CONFIG.max_budget
total_cost: float = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][
"completion"]) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
def get_total_prompt_tokens(self):
"""
Get the total number of prompt tokens.
Returns:
int: The total number of prompt tokens.
"""
return self.total_prompt_tokens
def get_total_completion_tokens(self):
"""
Get the total number of completion tokens.
Returns:
int: The total number of completion tokens.
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
Returns:
float: The total cost of API calls.
"""
return self.total_cost
def get_costs(self) -> Costs:
"""获得所有开销"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning("""
@ -136,12 +68,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
Check https://platform.openai.com/examples for examples
"""
def __init__(self, cost_manager=None):
def __init__(self):
self.__init_openai(CONFIG)
self.llm = openai
self.model = CONFIG.openai_api_model
self.auto_max_tokens = False
self._cost_manager = cost_manager or CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self, config):
@ -155,9 +86,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await self.async_retry_call(openai.ChatCompletion.acreate,
**self._cons_kwargs(messages),
stream=True
)
**self._cons_kwargs(messages),
stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
collected_messages = []
@ -276,12 +207,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
try:
prompt_tokens = int(usage['prompt_tokens'])
completion_tokens = int(usage['completion_tokens'])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error("updating costs failed!", e)
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
return CONFIG.cost_manager.get_costs()
def get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
@ -366,7 +297,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return None, input_string
@staticmethod
async def async_retry_call(func, *args, **kwargs):
async def async_retry_call(func, *args, **kwargs):
for i in range(OpenAIGPTAPI.MAX_TRY):
try:
rsp = await func(*args, **kwargs)
@ -399,4 +330,3 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
raise openai.error.OpenAIError("Exceeds the maximum retries")
MAX_TRY = 5

View file

@ -9,13 +9,14 @@
"""
from __future__ import annotations
from typing import Iterable, Type, Dict
from typing import Iterable, Type
from pydantic import BaseModel, Field
from metagpt.config import Config, CONFIG
from metagpt.config import CONFIG
from metagpt.const import OPTIONS
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM, CostManager
from metagpt.llm import LLM
from metagpt.actions import Action, ActionOutput
from metagpt.logs import logger
from metagpt.memory import Memory, LongTermMemory

View file

@ -35,12 +35,13 @@ class SoftwareCompany(BaseModel):
def invest(self, investment: float):
"""Invest company. raise NoMoneyException when exceed max_budget."""
self.investment = investment
CONFIG.max_budget = investment
CONFIG.cost_manager.max_budget = investment
logger.info(f'Investment: ${investment}.')
def _check_balance(self):
if CONFIG.total_cost > CONFIG.max_budget:
raise NoMoneyException(CONFIG.total_cost, f'Insufficient funds: {CONFIG.max_budget}')
if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
raise NoMoneyException(CONFIG.cost_manager.total_cost,
f'Insufficient funds: {CONFIG.cost_manager.max_budget}')
def start_project(self, idea, role="BOSS", cause_by=BossRequirement, **kwargs):
"""Start a project from publishing boss requirement."""

View file

@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
"""
@Time : 2023/8/28
@Author : mashenquan
@File : openai.py
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
"""
from pydantic import BaseModel
from metagpt.logs import logger
from metagpt.utils.token_counter import TOKEN_COSTS
from typing import NamedTuple
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(BaseModel):
"""Calculate the overhead of using the interface."""
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][
"completion"]) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
def get_total_prompt_tokens(self):
"""
Get the total number of prompt tokens.
Returns:
int: The total number of prompt tokens.
"""
return self.total_prompt_tokens
def get_total_completion_tokens(self):
"""
Get the total number of completion tokens.
Returns:
int: The total number of completion tokens.
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
Returns:
float: The total cost of API calls.
"""
return self.total_cost
def get_costs(self) -> Costs:
"""获得所有开销"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)