mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
fixbug: CONFIG initialization
This commit is contained in:
parent
58369c4e3a
commit
e201bf71d9
5 changed files with 109 additions and 89 deletions
|
|
@ -4,6 +4,7 @@
|
|||
Provide configuration, singleton.
|
||||
@Modified BY: mashenquan, 2023/8/28. Replace the global variable `CONFIG` with `ContextVar`.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
|
@ -14,6 +15,7 @@ import yaml
|
|||
from metagpt.const import PROJECT_ROOT, OPTIONS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
|
|
@ -43,12 +45,17 @@ class Config(metaclass=Singleton):
|
|||
|
||||
def __init__(self, yaml_file=default_yaml_file):
|
||||
self._init_with_config_files_and_env(yaml_file)
|
||||
self.cost_manager = CostManager(**json.loads(self.COST_MANAGER)) if self.COST_MANAGER else CostManager()
|
||||
|
||||
logger.info("Config loading done.")
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
self.global_proxy = self._get("GLOBAL_PROXY")
|
||||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
self.anthropic_api_key = self._get("Anthropic_API_KEY")
|
||||
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and (
|
||||
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
|
||||
not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key
|
||||
):
|
||||
logger.warning("Set OPENAI_API_KEY or Anthropic_API_KEY first")
|
||||
self.openai_api_base = self._get("OPENAI_API_BASE")
|
||||
|
|
@ -78,8 +85,7 @@ class Config(metaclass=Singleton):
|
|||
self.long_term_memory = self._get("LONG_TERM_MEMORY", False)
|
||||
if self.long_term_memory:
|
||||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.total_cost = 0.0
|
||||
self.cost_manager.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
|
||||
self.mmdc = self._get("MMDC", "mmdc")
|
||||
|
|
@ -109,7 +115,8 @@ class Config(metaclass=Singleton):
|
|||
return m.get(*args, **kwargs)
|
||||
|
||||
def get(self, key, *args, **kwargs):
|
||||
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables. Throw an error if not found."""
|
||||
"""Retrieve values from config/key.yaml, config/config.yaml, and environment variables.
|
||||
Throw an error if not found."""
|
||||
value = self._get(key, *args, **kwargs)
|
||||
if value is None:
|
||||
raise ValueError(f"Key '{key}' not found in environment variables or in the YAML file")
|
||||
|
|
@ -127,10 +134,12 @@ class Config(metaclass=Singleton):
|
|||
opts = deepcopy(OPTIONS.get())
|
||||
opts.update(options)
|
||||
OPTIONS.set(opts)
|
||||
self._update()
|
||||
|
||||
@property
|
||||
def options(self):
|
||||
"""Return all key-values"""
|
||||
return OPTIONS.get()
|
||||
|
||||
|
||||
CONFIG = Config()
|
||||
|
|
|
|||
|
|
@ -11,19 +11,18 @@ import re
|
|||
import time
|
||||
import random
|
||||
|
||||
from typing import NamedTuple, List
|
||||
from typing import List
|
||||
import traceback
|
||||
import openai
|
||||
from openai.error import APIConnectionError
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
from metagpt.utils.token_counter import (
|
||||
TOKEN_COSTS,
|
||||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
get_max_completion_tokens,
|
||||
|
|
@ -55,73 +54,6 @@ class RateLimiter:
|
|||
self.last_call_time = time.time()
|
||||
|
||||
|
||||
class Costs(NamedTuple):
|
||||
total_prompt_tokens: int
|
||||
total_completion_tokens: int
|
||||
total_cost: float
|
||||
total_budget: float
|
||||
|
||||
|
||||
class CostManager(BaseModel):
|
||||
"""计算使用接口的开销"""
|
||||
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
total_budget: float = 0
|
||||
max_budget: float = CONFIG.max_budget
|
||||
total_cost: float = 0
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][
|
||||
"completion"]) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
|
||||
def get_total_prompt_tokens(self):
|
||||
"""
|
||||
Get the total number of prompt tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of prompt tokens.
|
||||
"""
|
||||
return self.total_prompt_tokens
|
||||
|
||||
def get_total_completion_tokens(self):
|
||||
"""
|
||||
Get the total number of completion tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of completion tokens.
|
||||
"""
|
||||
return self.total_completion_tokens
|
||||
|
||||
def get_total_cost(self):
|
||||
"""
|
||||
Get the total cost of API calls.
|
||||
|
||||
Returns:
|
||||
float: The total cost of API calls.
|
||||
"""
|
||||
return self.total_cost
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
"""获得所有开销"""
|
||||
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning("""
|
||||
|
|
@ -136,12 +68,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
Check https://platform.openai.com/examples for examples
|
||||
"""
|
||||
|
||||
def __init__(self, cost_manager=None):
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = cost_manager or CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openai(self, config):
|
||||
|
|
@ -155,9 +86,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await self.async_retry_call(openai.ChatCompletion.acreate,
|
||||
**self._cons_kwargs(messages),
|
||||
stream=True
|
||||
)
|
||||
**self._cons_kwargs(messages),
|
||||
stream=True
|
||||
)
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
|
|
@ -276,12 +207,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
try:
|
||||
prompt_tokens = int(usage['prompt_tokens'])
|
||||
completion_tokens = int(usage['completion_tokens'])
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
CONFIG.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("updating costs failed!", e)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
return CONFIG.cost_manager.get_costs()
|
||||
|
||||
def get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
|
|
@ -366,7 +297,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
return None, input_string
|
||||
|
||||
@staticmethod
|
||||
async def async_retry_call(func, *args, **kwargs):
|
||||
async def async_retry_call(func, *args, **kwargs):
|
||||
for i in range(OpenAIGPTAPI.MAX_TRY):
|
||||
try:
|
||||
rsp = await func(*args, **kwargs)
|
||||
|
|
@ -399,4 +330,3 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
raise openai.error.OpenAIError("Exceeds the maximum retries")
|
||||
|
||||
MAX_TRY = 5
|
||||
|
||||
|
|
|
|||
|
|
@ -9,13 +9,14 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, Type, Dict
|
||||
from typing import Iterable, Type
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import Config, CONFIG
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import OPTIONS
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM, CostManager
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory, LongTermMemory
|
||||
|
|
|
|||
|
|
@ -35,12 +35,13 @@ class SoftwareCompany(BaseModel):
|
|||
def invest(self, investment: float):
|
||||
"""Invest company. raise NoMoneyException when exceed max_budget."""
|
||||
self.investment = investment
|
||||
CONFIG.max_budget = investment
|
||||
CONFIG.cost_manager.max_budget = investment
|
||||
logger.info(f'Investment: ${investment}.')
|
||||
|
||||
def _check_balance(self):
|
||||
if CONFIG.total_cost > CONFIG.max_budget:
|
||||
raise NoMoneyException(CONFIG.total_cost, f'Insufficient funds: {CONFIG.max_budget}')
|
||||
if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
|
||||
raise NoMoneyException(CONFIG.cost_manager.total_cost,
|
||||
f'Insufficient funds: {CONFIG.cost_manager.max_budget}')
|
||||
|
||||
def start_project(self, idea, role="BOSS", cause_by=BossRequirement, **kwargs):
|
||||
"""Start a project from publishing boss requirement."""
|
||||
|
|
|
|||
79
metagpt/utils/cost_manager.py
Normal file
79
metagpt/utils/cost_manager.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/28
|
||||
@Author : mashenquan
|
||||
@File : openai.py
|
||||
@Desc : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.token_counter import TOKEN_COSTS
|
||||
from typing import NamedTuple
|
||||
|
||||
|
||||
class Costs(NamedTuple):
|
||||
total_prompt_tokens: int
|
||||
total_completion_tokens: int
|
||||
total_cost: float
|
||||
total_budget: float
|
||||
|
||||
|
||||
class CostManager(BaseModel):
|
||||
"""Calculate the overhead of using the interface."""
|
||||
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
total_budget: float = 0
|
||||
max_budget: float = 10.0
|
||||
total_cost: float = 0
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][
|
||||
"completion"]) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
|
||||
def get_total_prompt_tokens(self):
|
||||
"""
|
||||
Get the total number of prompt tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of prompt tokens.
|
||||
"""
|
||||
return self.total_prompt_tokens
|
||||
|
||||
def get_total_completion_tokens(self):
|
||||
"""
|
||||
Get the total number of completion tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of completion tokens.
|
||||
"""
|
||||
return self.total_completion_tokens
|
||||
|
||||
def get_total_cost(self):
|
||||
"""
|
||||
Get the total cost of API calls.
|
||||
|
||||
Returns:
|
||||
float: The total cost of API calls.
|
||||
"""
|
||||
return self.total_cost
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
"""获得所有开销"""
|
||||
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
|
||||
Loading…
Add table
Add a link
Reference in a new issue