feat: CostManager改pydantic结构,以备RPC传参

This commit is contained in:
莘权 马 2023-08-22 19:47:35 +08:00
parent 9600787d63
commit 19767496b1
7 changed files with 26 additions and 31 deletions

View file

@ -4,6 +4,7 @@
@Time : 2023/5/5 23:04
@Author : alexanderwu
@File : base_gpt_api.py
@Desc : mashenquan, 2023/8/22. + try catch
"""
from abc import abstractmethod
from typing import Optional
@ -41,7 +42,11 @@ class BaseGPTAPI(BaseChatbot):
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
message = [self._default_system_msg(), self._user_msg(msg)]
rsp = await self.acompletion_text(message, stream=True)
try:
rsp = await self.acompletion_text(message, stream=True)
except Exception as e:
logger.exception(f"{e}")
raise e
logger.debug(message)
# logger.debug(rsp)
return rsp

View file

@ -8,10 +8,11 @@
"""
import asyncio
import time
from typing import NamedTuple
from typing import NamedTuple, Dict
import openai
from openai.error import APIConnectionError
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
from metagpt.logs import logger
@ -35,7 +36,7 @@ class RateLimiter:
self.rpm = rpm
def split_batches(self, batch):
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
return [batch[i: i + self.rpm] for i in range(0, len(batch), self.rpm)]
async def wait_if_needed(self, num_requests):
current_time = time.time()
@ -56,14 +57,14 @@ class Costs(NamedTuple):
total_budget: float
class CostManager:
class CostManager(BaseModel):
"""计算使用接口的开销"""
def __init__(self, options):
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
self.options = options
self.total_budget = 0
total_prompt_tokens: int = 0
total_completion_tokens: int = 0
total_budget: int = 0
max_budget: int
total_cost: int = 0
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
@ -76,7 +77,8 @@ class CostManager:
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][
"completion"]) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | "
@ -114,18 +116,6 @@ class CostManager:
"""获得所有开销"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
@property
def total_cost(self):
return self.options.get("total_cost", 0)
@total_cost.setter
def total_cost(self, v):
self.options["total_cost"] = v
@property
def max_budget(self):
return self.options.get("max_budget", 0)
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")

View file

@ -9,7 +9,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
from enum import Enum
from typing import Type, TypedDict, Set, Optional
from pydantic import BaseModel
@ -17,7 +17,7 @@ from pydantic import BaseModel
from metagpt.logs import logger
class MessageTag(StrEnum):
class MessageTag(Enum):
Prerequisite = "prerequisite"

View file

@ -30,7 +30,7 @@ class SoftwareCompany(BaseModel):
investment: float = Field(default=10.0)
idea: str = Field(default="")
options: Dict = Field(default=Config().runtime_options)
cost_manager: CostManager = Field(default=CostManager(Config().runtime_options))
cost_manager: CostManager = Field(default=CostManager(**Config().runtime_options))
class Config:
arbitrary_types_allowed = True

View file

@ -19,7 +19,7 @@ from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
async def test_write_code():
api_design = "设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。"
conf = Config()
cost_manager = CostManager(conf.runtime_options)
cost_manager = CostManager(**conf.runtime_options)
llm = LLM(options=conf.runtime_options, cost_manager=cost_manager)
write_code = WriteCode(options=conf.runtime_options, name="write_code", llm=llm)
@ -35,6 +35,6 @@ async def test_write_code():
async def test_write_code_directly():
prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0]
options = Config().runtime_options
llm = LLM(options=options, cost_manager=CostManager(options=options))
llm = LLM(options=options, cost_manager=CostManager(**options))
rsp = await llm.aask(prompt)
logger.info(rsp)

View file

@ -26,7 +26,7 @@ def env():
def test_add_role(env: Environment):
conf = Config()
cost_manager = CostManager(options=conf.runtime_options)
cost_manager = CostManager(**conf.runtime_options)
role = ProductManager(options=conf.runtime_options,
cost_manager=cost_manager,
name="Alice",
@ -39,7 +39,7 @@ def test_add_role(env: Environment):
def test_get_roles(env: Environment):
conf = Config()
cost_manager = CostManager(options=conf.runtime_options)
cost_manager = CostManager(**conf.runtime_options)
role1 = Role(options=conf.runtime_options, cost_manager=cost_manager, name="Alice", profile="product manager",
goal="create a new product", constraints="limited resources")
role2 = Role(options=conf.runtime_options, cost_manager=cost_manager, name="Bob", profile="engineer",
@ -53,7 +53,7 @@ def test_get_roles(env: Environment):
@pytest.mark.asyncio
async def test_publish_and_process_message(env: Environment):
conf = Config()
cost_manager = CostManager(options=conf.runtime_options)
cost_manager = CostManager(**conf.runtime_options)
product_manager = ProductManager(options=conf.runtime_options,
cost_manager=cost_manager,
name="Alice", profile="Product Manager",

View file

@ -16,7 +16,7 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM, CostManager
@pytest.fixture()
def llm():
options = Config().runtime_options
return LLM(options=options, cost_manager=CostManager(options))
return LLM(options=options, cost_manager=CostManager(**options))
@pytest.mark.asyncio