From 02090af7cb1a315b2b59ea843fa7aa8bb816cf4e Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:06:59 +0800 Subject: [PATCH] update gemini count_tokens --- metagpt/provider/google_gemini_api.py | 56 ++++++++++++++++++--------- metagpt/provider/zhipuai_api.py | 2 +- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index a69ffdc28..0ba1e86c1 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -10,14 +10,35 @@ from tenacity import ( wait_fixed, ) import google.generativeai as genai -from google.generativeai import client +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from google.generativeai.generative_models import GenerativeModel from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse from google.generativeai.types.generation_types import GenerationConfig from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.openai_api import log_and_reraise +from metagpt.provider.openai_api import CostManager, log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. + """ + + def count_tokens( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) class GeminiGPTAPI(BaseGPTAPI): @@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI): def __init__(self): self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model - self.llm = genai.GenerativeModel(model_name=self.model) + self.llm = GeminiGenerativeModel(model_name=self.model) + self._cost_manager = CostManager() def __init_gemini(self, config: CONFIG): genai.configure(api_key=config.gemini_api_key) @@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("google gemini updats costs failed!", e) + logger.error(f"google gemini updats costs failed! exp: {e}") def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text def get_usage(self, messages: list[dict], resp_text: str) -> dict: - prompt_resp = self.llm.count_tokens(contents=messages) - completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI): return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: - # fix google-generativeai sdk - if self.llm._client is None: - self.llm._client = client.get_default_generative_client() - # TODO exception to fix - prompt_resp = await self.llm.count_tokens_async(contents=messages) - completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI): def completion(self, messages: list[dict]) -> "GenerateContentResponse": resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) - # usage = self.get_usage(messages, resp.text) - # self._update_costs(usage) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage) return resp async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) - # usage = await self.aget_usage(messages, resp.text) - # self._update_costs(usage) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) return resp async def acompletion(self, messages: list[dict]) -> dict: @@ -113,8 +133,8 @@ class GeminiGPTAPI(BaseGPTAPI): collected_content.append(content) full_content = "".join(collected_content) - # usage = await self.aget_usage(messages, full_content) - # self._update_costs(usage) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) return full_content @retry( diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index eef0e51e1..60d9a0777 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updats costs failed! exp: {e}") def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response"""