update token usage for streaming api

This commit is contained in:
geekan 2023-07-06 16:14:53 +08:00
parent 9f15569f04
commit 5c9b477691
2 changed files with 14 additions and 6 deletions

View file

@ -13,7 +13,7 @@ from metagpt.logs import logger
class BaseGPTAPI(BaseChatbot):
"""GPT API抽象类,要求所有继承者提供一系列标准能力"""
"""GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
system_prompt = 'You are a helpful assistant.'
def _user_msg(self, msg: str) -> dict[str, str]:

View file

@ -5,8 +5,7 @@
@Author : alexanderwu
@File : openai.py
"""
import json
from typing import Union, NamedTuple
from typing import NamedTuple
from functools import wraps
import asyncio
import time
@ -169,6 +168,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
print()
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
async def _achat_completion(self, messages: list[dict]) -> dict:
@ -180,7 +181,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
stop=None,
temperature=0.5,
)
self._update_costs(rsp)
self._update_costs(rsp.get('usage'))
return rsp
def _chat_completion(self, messages: list[dict]) -> dict:
@ -213,6 +214,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)
def _calc_usage(self, messages: list[dict], rsp: str) -> dict:
usage = {}
prompt_tokens = count_message_tokens(messages, self.model)
completion_tokens = count_string_tokens(rsp, self.model)
usage['prompt_tokens'] = prompt_tokens
usage['completion_tokens'] = completion_tokens
return usage
async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]:
"""返回完整JSON"""
split_batches = self.split_batches(batch)
@ -239,8 +248,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
logger.info(f"Result of task {idx}: {result}")
return results
def _update_costs(self, response: dict):
usage = response.get('usage')
def _update_costs(self, usage: dict):
prompt_tokens = int(usage['prompt_tokens'])
completion_tokens = int(usage['completion_tokens'])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)