From 5c9b477691ba94621df0ff384cf052b2010a76f4 Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 6 Jul 2023 16:14:53 +0800 Subject: [PATCH] update token usage for streaming api --- metagpt/provider/base_gpt_api.py | 2 +- metagpt/provider/openai_api.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index 4046c08f0..972982dc7 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -13,7 +13,7 @@ from metagpt.logs import logger class BaseGPTAPI(BaseChatbot): - """GPT API抽象类,要求所有继承者提供一系列标准能力""" + """GPT API abstract class, requiring all inheritors to provide a series of standard capabilities""" system_prompt = 'You are a helpful assistant.' def _user_msg(self, msg: str) -> dict[str, str]: diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index c4b82d149..9cd891e87 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -5,8 +5,7 @@ @Author : alexanderwu @File : openai.py """ -import json -from typing import Union, NamedTuple +from typing import NamedTuple from functools import wraps import asyncio import time @@ -169,6 +168,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): print() full_reply_content = ''.join([m.get('content', '') for m in collected_messages]) + usage = self._calc_usage(messages, full_reply_content) + self._update_costs(usage) return full_reply_content async def _achat_completion(self, messages: list[dict]) -> dict: @@ -180,7 +181,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): stop=None, temperature=0.5, ) - self._update_costs(rsp) + self._update_costs(rsp.get('usage')) return rsp def _chat_completion(self, messages: list[dict]) -> dict: @@ -213,6 +214,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): rsp = await self._achat_completion(messages) return self.get_choice_text(rsp) + def _calc_usage(self, messages: list[dict], rsp: str) -> dict: + usage = {} + prompt_tokens = count_message_tokens(messages, self.model) + completion_tokens = count_string_tokens(rsp, self.model) + usage['prompt_tokens'] = prompt_tokens + usage['completion_tokens'] = completion_tokens + return usage + async def acompletion_batch(self, batch: list[list[dict]]) -> list[dict]: """返回完整JSON""" split_batches = self.split_batches(batch) @@ -239,8 +248,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): logger.info(f"Result of task {idx}: {result}") return results - def _update_costs(self, response: dict): - usage = response.get('usage') + def _update_costs(self, usage: dict): prompt_tokens = int(usage['prompt_tokens']) completion_tokens = int(usage['completion_tokens']) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)