mirror of https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 11:26:23 +02:00

update gemini count_tokens

This commit is contained in:
parent 91d1ab20cc
commit 02090af7cb

2 changed files with 39 additions and 19 deletions
@@ -10,14 +10,35 @@ from tenacity import (
     wait_fixed,
 )
 import google.generativeai as genai
+from google.generativeai import client
+from google.ai import generativelanguage as glm
+from google.generativeai.types import content_types
+from google.generativeai.generative_models import GenerativeModel
 from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse
 from google.generativeai.types.generation_types import GenerationConfig
 
 from metagpt.config import CONFIG
 from metagpt.logs import logger
 from metagpt.provider.base_gpt_api import BaseGPTAPI
-from metagpt.provider.openai_api import log_and_reraise
+from metagpt.provider.openai_api import CostManager, log_and_reraise
 
 
+class GeminiGenerativeModel(GenerativeModel):
+    """
+    Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class.
+    Will use default GenerativeModel if it fixed.
+    """
+
+    def count_tokens(
+        self, contents: content_types.ContentsType
+    ) -> glm.CountTokensResponse:
+        contents = content_types.to_contents(contents)
+        return self._client.count_tokens(model=self.model_name, contents=contents)
+
+    async def count_tokens_async(
+        self, contents: content_types.ContentsType
+    ) -> glm.CountTokensResponse:
+        contents = content_types.to_contents(contents)
+        return await self._async_client.count_tokens(model=self.model_name, contents=contents)
+
+
 class GeminiGPTAPI(BaseGPTAPI):
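For context, a minimal sketch (not part of the commit) of exercising the patched count_tokens above. The model name, the role/parts message shape, and the manual low-level client fallback all mirror what this diff uses; the API key is a placeholder:

    import google.generativeai as genai
    from google.generativeai import client as low_level_client

    genai.configure(api_key="YOUR_API_KEY")  # placeholder key
    llm = GeminiGenerativeModel(model_name="gemini-pro")
    if llm._client is None:  # same fallback the pre-commit aget_usage used
        llm._client = low_level_client.get_default_generative_client()
    # to_contents() normalizes plain strings or role/parts dicts into glm.Content protos
    print(llm.count_tokens(contents="hello gemini").total_tokens)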
@@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI):
     def __init__(self):
         self.__init_gemini(CONFIG)
         self.model = "gemini-pro"  # so far only one model
-        self.llm = genai.GenerativeModel(model_name=self.model)
+        self.llm = GeminiGenerativeModel(model_name=self.model)
+        self._cost_manager = CostManager()
 
     def __init_gemini(self, config: CONFIG):
         genai.configure(api_key=config.gemini_api_key)
@@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI):
             completion_tokens = int(usage.get("completion_tokens", 0))
             self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
         except Exception as e:
-            logger.error("google gemini updats costs failed!", e)
+            logger.error(f"google gemini updats costs failed! exp: {e}")
 
     def get_choice_text(self, resp: GenerateContentResponse) -> str:
         return resp.text
 
     def get_usage(self, messages: list[dict], resp_text: str) -> dict:
-        prompt_resp = self.llm.count_tokens(contents=messages)
-        completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]})
+        req_text = messages[-1]["parts"][0] if messages else ""
+        prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]})
+        completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]})
         usage = {
             "prompt_tokens": prompt_resp.total_tokens,
             "completion_tokens": completion_resp.total_tokens
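A sketch of what the rewritten get_usage computes, reusing the llm instance from the sketch above. Note the design choice visible in the diff: only the last request message (messages[-1]) is counted, so earlier turns of a multi-turn prompt do not contribute to prompt_tokens:

    req = {"role": "user", "parts": [{"text": "What is 2 + 2?"}]}
    rsp = {"role": "model", "parts": [{"text": "4"}]}
    usage = {
        "prompt_tokens": llm.count_tokens(contents=req).total_tokens,
        "completion_tokens": llm.count_tokens(contents=rsp).total_tokens,
    }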
@@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI):
         return usage
 
     async def aget_usage(self, messages: list[dict], resp_text: str) -> dict:
-        # fix google-generativeai sdk
-        if self.llm._client is None:
-            self.llm._client = client.get_default_generative_client()
-        # TODO exception to fix
-        prompt_resp = await self.llm.count_tokens_async(contents=messages)
-        completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]})
+        req_text = messages[-1]["parts"][0] if messages else ""
+        prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]})
+        completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]})
         usage = {
             "prompt_tokens": prompt_resp.total_tokens,
             "completion_tokens": completion_resp.total_tokens
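The async variant has the same shape. A hedged sketch reusing llm from above; the async-client fallback is assumed to be the async analogue of the get_default_generative_client call that this commit removes:

    import asyncio

    async def count_completion_tokens(text: str) -> int:
        if llm._async_client is None:  # assumed async analogue of the sync fallback
            llm._async_client = low_level_client.get_default_generative_async_client()
        # mirrors aget_usage: wrap the reply in Gemini's role/parts message shape
        rsp = await llm.count_tokens_async(contents={"role": "model", "parts": [{"text": text}]})
        return rsp.total_tokens

    print(asyncio.run(count_completion_tokens("4")))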
@@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI):
 
     def completion(self, messages: list[dict]) -> "GenerateContentResponse":
         resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages))
-        # usage = self.get_usage(messages, resp.text)
-        # self._update_costs(usage)
+        usage = self.get_usage(messages, resp.text)
+        self._update_costs(usage)
         return resp
 
     async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
-        # usage = await self.aget_usage(messages, resp.text)
-        # self._update_costs(usage)
+        usage = await self.aget_usage(messages, resp.text)
+        self._update_costs(usage)
         return resp
 
     async def acompletion(self, messages: list[dict]) -> dict:
@@ -113,8 +133,8 @@ class GeminiGPTAPI(BaseGPTAPI):
                 collected_content.append(content)
 
         full_content = "".join(collected_content)
-        # usage = await self.aget_usage(messages, full_content)
-        # self._update_costs(usage)
+        usage = await self.aget_usage(messages, full_content)
+        self._update_costs(usage)
         return full_content
 
     @retry(
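Taken together, a hypothetical end-to-end run of the re-enabled cost tracking. gemini_api_key is the config attribute read by __init_gemini above, and the string-parts message shape matches what get_usage indexes:

    from metagpt.config import CONFIG

    CONFIG.gemini_api_key = "YOUR_API_KEY"  # placeholder; normally set via config file or env
    api = GeminiGPTAPI()
    resp = api.completion([{"role": "user", "parts": ["Say hi in one word."]}])
    # completion() now calls get_usage() and _update_costs(), so the internal
    # CostManager accumulates prompt/completion tokens on every call
    print(resp.text)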
@@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
             completion_tokens = int(usage.get("completion_tokens", 0))
             self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
         except Exception as e:
-            logger.error("zhipuai updats costs failed!", e)
+            logger.error(f"zhipuai updats costs failed! exp: {e}")
 
     def get_choice_text(self, resp: dict) -> str:
         """get the first text of choice from llm response"""