mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
feat: CostManager改pydantic结构,以备RPC传参
This commit is contained in:
parent
9600787d63
commit
19767496b1
7 changed files with 26 additions and 31 deletions
|
|
@ -4,6 +4,7 @@
|
|||
@Time : 2023/5/5 23:04
|
||||
@Author : alexanderwu
|
||||
@File : base_gpt_api.py
|
||||
@Desc : mashenquan, 2023/8/22. + try catch
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
|
@ -41,7 +42,11 @@ class BaseGPTAPI(BaseChatbot):
|
|||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
else:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
rsp = await self.acompletion_text(message, stream=True)
|
||||
try:
|
||||
rsp = await self.acompletion_text(message, stream=True)
|
||||
except Exception as e:
|
||||
logger.exception(f"{e}")
|
||||
raise e
|
||||
logger.debug(message)
|
||||
# logger.debug(rsp)
|
||||
return rsp
|
||||
|
|
|
|||
|
|
@ -8,10 +8,11 @@
|
|||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import NamedTuple
|
||||
from typing import NamedTuple, Dict
|
||||
|
||||
import openai
|
||||
from openai.error import APIConnectionError
|
||||
from pydantic import BaseModel
|
||||
from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -35,7 +36,7 @@ class RateLimiter:
|
|||
self.rpm = rpm
|
||||
|
||||
def split_batches(self, batch):
|
||||
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
return [batch[i: i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
|
||||
async def wait_if_needed(self, num_requests):
|
||||
current_time = time.time()
|
||||
|
|
@ -56,14 +57,14 @@ class Costs(NamedTuple):
|
|||
total_budget: float
|
||||
|
||||
|
||||
class CostManager:
|
||||
class CostManager(BaseModel):
|
||||
"""计算使用接口的开销"""
|
||||
|
||||
def __init__(self, options):
|
||||
self.total_prompt_tokens = 0
|
||||
self.total_completion_tokens = 0
|
||||
self.options = options
|
||||
self.total_budget = 0
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
total_budget: int = 0
|
||||
max_budget: int
|
||||
total_cost: int = 0
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
|
|
@ -76,7 +77,8 @@ class CostManager:
|
|||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
|
||||
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][
|
||||
"completion"]) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
|
||||
|
|
@ -114,18 +116,6 @@ class CostManager:
|
|||
"""获得所有开销"""
|
||||
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
|
||||
|
||||
@property
|
||||
def total_cost(self):
|
||||
return self.options.get("total_cost", 0)
|
||||
|
||||
@total_cost.setter
|
||||
def total_cost(self, v):
|
||||
self.options["total_cost"] = v
|
||||
|
||||
@property
|
||||
def max_budget(self):
|
||||
return self.options.get("max_budget", 0)
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from enum import Enum
|
||||
from typing import Type, TypedDict, Set, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
|||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class MessageTag(StrEnum):
|
||||
class MessageTag(Enum):
|
||||
Prerequisite = "prerequisite"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class SoftwareCompany(BaseModel):
|
|||
investment: float = Field(default=10.0)
|
||||
idea: str = Field(default="")
|
||||
options: Dict = Field(default=Config().runtime_options)
|
||||
cost_manager: CostManager = Field(default=CostManager(Config().runtime_options))
|
||||
cost_manager: CostManager = Field(default=CostManager(**Config().runtime_options))
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
|||
async def test_write_code():
|
||||
api_design = "设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。"
|
||||
conf = Config()
|
||||
cost_manager = CostManager(conf.runtime_options)
|
||||
cost_manager = CostManager(**conf.runtime_options)
|
||||
llm = LLM(options=conf.runtime_options, cost_manager=cost_manager)
|
||||
write_code = WriteCode(options=conf.runtime_options, name="write_code", llm=llm)
|
||||
|
||||
|
|
@ -35,6 +35,6 @@ async def test_write_code():
|
|||
async def test_write_code_directly():
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0]
|
||||
options = Config().runtime_options
|
||||
llm = LLM(options=options, cost_manager=CostManager(options=options))
|
||||
llm = LLM(options=options, cost_manager=CostManager(**options))
|
||||
rsp = await llm.aask(prompt)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ def env():
|
|||
|
||||
def test_add_role(env: Environment):
|
||||
conf = Config()
|
||||
cost_manager = CostManager(options=conf.runtime_options)
|
||||
cost_manager = CostManager(**conf.runtime_options)
|
||||
role = ProductManager(options=conf.runtime_options,
|
||||
cost_manager=cost_manager,
|
||||
name="Alice",
|
||||
|
|
@ -39,7 +39,7 @@ def test_add_role(env: Environment):
|
|||
|
||||
def test_get_roles(env: Environment):
|
||||
conf = Config()
|
||||
cost_manager = CostManager(options=conf.runtime_options)
|
||||
cost_manager = CostManager(**conf.runtime_options)
|
||||
role1 = Role(options=conf.runtime_options, cost_manager=cost_manager, name="Alice", profile="product manager",
|
||||
goal="create a new product", constraints="limited resources")
|
||||
role2 = Role(options=conf.runtime_options, cost_manager=cost_manager, name="Bob", profile="engineer",
|
||||
|
|
@ -53,7 +53,7 @@ def test_get_roles(env: Environment):
|
|||
@pytest.mark.asyncio
|
||||
async def test_publish_and_process_message(env: Environment):
|
||||
conf = Config()
|
||||
cost_manager = CostManager(options=conf.runtime_options)
|
||||
cost_manager = CostManager(**conf.runtime_options)
|
||||
product_manager = ProductManager(options=conf.runtime_options,
|
||||
cost_manager=cost_manager,
|
||||
name="Alice", profile="Product Manager",
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM, CostManager
|
|||
@pytest.fixture()
|
||||
def llm():
|
||||
options = Config().runtime_options
|
||||
return LLM(options=options, cost_manager=CostManager(options))
|
||||
return LLM(options=options, cost_manager=CostManager(**options))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue