update gemini count_tokens

This commit is contained in:
better629 2023-12-15 17:06:59 +08:00
parent 91d1ab20cc
commit 02090af7cb
2 changed files with 39 additions and 19 deletions

View file

@ -10,14 +10,35 @@ from tenacity import (
wait_fixed,
)
import google.generativeai as genai
from google.generativeai import client
from google.ai import generativelanguage as glm
from google.generativeai.types import content_types
from google.generativeai.generative_models import GenerativeModel
from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse
from google.generativeai.types.generation_types import GenerationConfig
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.openai_api import log_and_reraise
from metagpt.provider.openai_api import CostManager, log_and_reraise
class GeminiGenerativeModel(GenerativeModel):
    """GenerativeModel subclass that overrides the token-counting methods.

    Exists because of the SDK issue tracked in
    `https://github.com/google/generative-ai-python/pull/123`.
    The stock GenerativeModel can be used again once that fix is released.
    """

    def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
        """Count the tokens in `contents` with the synchronous client.

        Args:
            contents: Anything `content_types.to_contents` accepts.

        Returns:
            The `CountTokensResponse` returned by the client.
        """
        if self._client is None:
            # The client may not have been created yet when no generate call
            # has run on this model; create the default one so counting
            # tokens first does not fail on None.
            from google.generativeai import client

            self._client = client.get_default_generative_client()
        contents = content_types.to_contents(contents)
        return self._client.count_tokens(model=self.model_name, contents=contents)

    async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
        """Count the tokens in `contents` with the asynchronous client.

        Args:
            contents: Anything `content_types.to_contents` accepts.

        Returns:
            The `CountTokensResponse` returned by the async client.
        """
        if self._async_client is None:
            # Same guard as the sync path, for the async client.
            # NOTE(review): assumes the SDK leaves `_async_client` as None
            # until first use, like `_client` - confirm against the SDK.
            from google.generativeai import client

            self._async_client = client.get_default_generative_async_client()
        contents = content_types.to_contents(contents)
        return await self._async_client.count_tokens(model=self.model_name, contents=contents)
class GeminiGPTAPI(BaseGPTAPI):
@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI):
def __init__(self):
    """Configure the Gemini SDK, then create the model wrapper and cost manager."""
    self.__init_gemini(CONFIG)
    self.model = "gemini-pro"  # so far only one model
    # Build the model exactly once, using the subclass that carries the
    # count_tokens overrides; a stock genai.GenerativeModel created first
    # would be discarded immediately.
    self.llm = GeminiGenerativeModel(model_name=self.model)
    self._cost_manager = CostManager()
def __init_gemini(self, config: CONFIG):
    """Hand the Gemini API key from `config` to the google.generativeai SDK."""
    api_key = config.gemini_api_key
    genai.configure(api_key=api_key)
@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI):
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error("google gemini updats costs failed!", e)
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
    """Return the text carried by the Gemini response `resp`."""
    text = resp.text
    return text
def get_usage(self, messages: list[dict], resp_text: str) -> dict:
prompt_resp = self.llm.count_tokens(contents=messages)
completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]})
req_text = messages[-1]["parts"][0] if messages else ""
prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]})
completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]})
usage = {
"prompt_tokens": prompt_resp.total_tokens,
"completion_tokens": completion_resp.total_tokens
@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI):
return usage
async def aget_usage(self, messages: list[dict], resp_text: str) -> dict:
# fix google-generativeai sdk
if self.llm._client is None:
self.llm._client = client.get_default_generative_client()
# TODO exception to fix
prompt_resp = await self.llm.count_tokens_async(contents=messages)
completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]})
req_text = messages[-1]["parts"][0] if messages else ""
prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]})
completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]})
usage = {
"prompt_tokens": prompt_resp.total_tokens,
"completion_tokens": completion_resp.total_tokens
@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI):
def completion(self, messages: list[dict]) -> "GenerateContentResponse":
    """Generate a response synchronously and record its token usage.

    Args:
        messages: Conversation messages, passed through `_const_kwargs`
            to build the generation arguments.

    Returns:
        The response object returned by `self.llm.generate_content`.
    """
    request_kwargs = self._const_kwargs(messages)
    response = self.llm.generate_content(**request_kwargs)
    # Count tokens for this exchange and feed them to the cost tracker.
    self._update_costs(self.get_usage(messages, response.text))
    return response
async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
# usage = await self.aget_usage(messages, resp.text)
# self._update_costs(usage)
usage = await self.aget_usage(messages, resp.text)
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict]) -> dict:
@ -113,8 +133,8 @@ class GeminiGPTAPI(BaseGPTAPI):
collected_content.append(content)
full_content = "".join(collected_content)
# usage = await self.aget_usage(messages, full_content)
# self._update_costs(usage)
usage = await self.aget_usage(messages, full_content)
self._update_costs(usage)
return full_content
@retry(

View file

@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error("zhipuai updats costs failed!", e)
logger.error(f"zhipuai updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""