From bef8d64193c4ce783432e6f958fd5c0858ea7e00 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:45:40 +0800 Subject: [PATCH 1/4] add google gemini --- config/config.yaml | 4 + metagpt/config.py | 8 +- metagpt/llm.py | 3 + metagpt/provider/google_gemini_api.py | 130 ++++++++++++++++++ metagpt/utils/token_counter.py | 7 +- requirements.txt | 1 + .../provider/test_google_gemini_api.py | 43 ++++++ 7 files changed, 192 insertions(+), 4 deletions(-) create mode 100644 metagpt/provider/google_gemini_api.py create mode 100644 tests/metagpt/provider/test_google_gemini_api.py diff --git a/config/config.yaml b/config/config.yaml index 080de4000..596a31341 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -34,6 +34,10 @@ RPM: 10 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. +#### You can set here or export GEMINI_API_KEY="YOUR_API_KEY" +# GEMINI_API_KEY: "YOUR_API_KEY" + #### if use self-host open llm model with openai-compatible interface #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" #OPEN_LLM_API_MODEL: "llama2-13b" diff --git a/metagpt/config.py b/metagpt/config.py index 2ce75b013..3b46c8504 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -51,13 +51,17 @@ class Config(metaclass=Singleton): self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") + + self.gemini_api_key = self._get("GEMINI_API_KEY") + if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \ (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \ (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \ (not self.open_llm_api_base) and \ - (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key): + (not self.fireworks_api_key or "YOUR_API_KEY" 
== self.fireworks_api_key) and \ + (not self.gemini_api_key or "YOUR_API_KEY" == self.gemini_api_key): raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE") + "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE or GEMINI_API_KEY") self.openai_api_base = self._get("OPENAI_API_BASE") openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy if openai_proxy: diff --git a/metagpt/llm.py b/metagpt/llm.py index 7b490ec4a..b13fc723a 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -14,6 +14,7 @@ from metagpt.provider.spark_api import SparkAPI from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider +from metagpt.provider.google_gemini_api import GeminiGPTAPI def LLM() -> "BaseGPTAPI": @@ -29,6 +30,8 @@ def LLM() -> "BaseGPTAPI": llm = OpenLLMGPTAPI() elif CONFIG.fireworks_api_key: llm = FireWorksGPTAPI() + elif CONFIG.gemini_api_key: + llm = GeminiGPTAPI() else: raise RuntimeError("You should config a LLM configuration first") diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py new file mode 100644 index 000000000..1c866ebad --- /dev/null +++ b/metagpt/provider/google_gemini_api.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart + +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) +import google.generativeai as genai +from google.generativeai import client +from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse +from google.generativeai.types.generation_types import GenerationConfig + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from 
metagpt.provider.openai_api import log_and_reraise + + +class GeminiGPTAPI(BaseGPTAPI): + """ + Refs to `https://ai.google.dev/tutorials/python_quickstart` + """ + + use_system_prompt: bool = False # google gemini has no system prompt when use api + + def __init__(self): + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = genai.GenerativeModel(model_name=self.model) + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "contents": messages, + "generation_config": GenerationConfig( + temperature=0.3 + ), + "stream": stream + } + return kwargs + + def _update_costs(self, usage: dict): + """ update each request's token cost """ + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error("google gemini updats costs failed!", e) + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + prompt_resp = self.llm.count_tokens(contents=messages) + completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + # fix google-generativeai sdk + if self.llm._client is None: + self.llm._client = client.get_default_generative_client() + # TODO exception to fix + prompt_resp = await 
self.llm.count_tokens_async(contents=messages) + completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + # usage = self.get_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + # usage = await self.aget_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, + stream=True)) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + # usage = await self.aget_usage(messages, full_content) + # self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """ response in async with stream or non-stream mode """ + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/utils/token_counter.py 
b/metagpt/utils/token_counter.py index ba63e90a9..6d9cbd137 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -7,6 +7,7 @@ ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py +ref4: https://ai.google.dev/models/gemini """ import tiktoken @@ -24,7 +25,8 @@ TOKEN_COSTS = { "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, - "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens + "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005} } @@ -42,7 +44,8 @@ TOKEN_MAX = { "gpt-4-0613": 8192, "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, - "chatglm_turbo": 32768 + "chatglm_turbo": 32768, + "gemini-pro": 32768 } diff --git a/requirements.txt b/requirements.txt index 14a9f485d..a2aaff48b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,3 +45,4 @@ semantic-kernel==0.3.13.dev0 wrapt==1.15.0 websocket-client==0.58.0 zhipuai==1.0.7 +google-generativeai==0.3.1 \ No newline at end of file diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py new file mode 100644 index 000000000..32ed11ba5 --- /dev/null +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of google gemini api + +import pytest +from abc import ABC +from dataclasses import dataclass + +from metagpt.provider.google_gemini_api import GeminiGPTAPI + + +messages = [ + {"role": 
"user", "content": "who are you"} +] + + +@dataclass +class MockGeminiResponse(ABC): + text: str + + +default_resp = MockGeminiResponse(text="I'm gemini from google") + + +def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: + return default_resp + + +def test_gemini_completion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) + resp = GeminiGPTAPI().completion(messages) + assert resp.text == default_resp.text + + +async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: + return default_resp + + +@pytest.mark.asyncio +async def test_gemini_acompletion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) + resp = await GeminiGPTAPI().acompletion(messages) + assert resp.text == default_resp.text From 9fb6e7c459a24489028ebe55a4ed2032d689eac1 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:54:56 +0800 Subject: [PATCH 2/4] update gemini user_msg doc --- metagpt/provider/google_gemini_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 1c866ebad..a69ffdc28 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -36,6 +36,8 @@ class GeminiGPTAPI(BaseGPTAPI): genai.configure(api_key=config.gemini_api_key) def _user_msg(self, msg: str) -> dict[str, str]: + # Not to change BaseGPTAPI default functions but update with Gemini's conversation format. + # You should follow the format. 
return {"role": "user", "parts": [msg]} def _assistant_msg(self, msg: str) -> dict[str, str]: From 4127ef85704a7771b484c8c73912e1919ef0be09 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:06:59 +0800 Subject: [PATCH 3/4] update gemini count_tokens --- metagpt/provider/google_gemini_api.py | 56 ++++++++++++++++++--------- metagpt/provider/zhipuai_api.py | 2 +- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index a69ffdc28..0ba1e86c1 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -10,14 +10,35 @@ from tenacity import ( wait_fixed, ) import google.generativeai as genai -from google.generativeai import client +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from google.generativeai.generative_models import GenerativeModel from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse from google.generativeai.types.generation_types import GenerationConfig from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.openai_api import log_and_reraise +from metagpt.provider.openai_api import CostManager, log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. 
+ """ + + def count_tokens( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) class GeminiGPTAPI(BaseGPTAPI): @@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI): def __init__(self): self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model - self.llm = genai.GenerativeModel(model_name=self.model) + self.llm = GeminiGenerativeModel(model_name=self.model) + self._cost_manager = CostManager() def __init_gemini(self, config: CONFIG): genai.configure(api_key=config.gemini_api_key) @@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("google gemini updats costs failed!", e) + logger.error(f"google gemini updats costs failed! 
exp: {e}") def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text def get_usage(self, messages: list[dict], resp_text: str) -> dict: - prompt_resp = self.llm.count_tokens(contents=messages) - completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI): return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: - # fix google-generativeai sdk - if self.llm._client is None: - self.llm._client = client.get_default_generative_client() - # TODO exception to fix - prompt_resp = await self.llm.count_tokens_async(contents=messages) - completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI): def completion(self, messages: list[dict]) -> "GenerateContentResponse": resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) - # usage = self.get_usage(messages, resp.text) - # self._update_costs(usage) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage) return resp async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": resp: AsyncGenerateContentResponse = await 
self.llm.generate_content_async(**self._const_kwargs(messages)) - # usage = await self.aget_usage(messages, resp.text) - # self._update_costs(usage) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) return resp async def acompletion(self, messages: list[dict]) -> dict: @@ -113,8 +133,8 @@ collected_content.append(content) full_content = "".join(collected_content) - # usage = await self.aget_usage(messages, full_content) - # self._update_costs(usage) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) return full_content @retry( diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 3161c0e88..3b24ca98f 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -65,7 +65,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updates costs failed! 
exp: {e}") def get_choice_text(self, resp: dict) -> str: """ get the first text of choice from llm response """ From 70cbfb1e480367bb9586a62b7b723c80c57aa4f0 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:30:25 +0800 Subject: [PATCH 4/4] retry use wait_random_exponential --- metagpt/provider/google_gemini_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 0ba1e86c1..b68e013a0 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -7,7 +7,7 @@ from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, - wait_fixed, + wait_random_exponential, ) import google.generativeai as genai from google.ai import generativelanguage as glm @@ -139,7 +139,7 @@ class GeminiGPTAPI(BaseGPTAPI): @retry( stop=stop_after_attempt(3), - wait=wait_fixed(1), + wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise