From bef8d64193c4ce783432e6f958fd5c0858ea7e00 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:45:40 +0800 Subject: [PATCH 01/12] add google gemini --- config/config.yaml | 4 + metagpt/config.py | 8 +- metagpt/llm.py | 3 + metagpt/provider/google_gemini_api.py | 130 ++++++++++++++++++ metagpt/utils/token_counter.py | 7 +- requirements.txt | 1 + .../provider/test_google_gemini_api.py | 43 ++++++ 7 files changed, 192 insertions(+), 4 deletions(-) create mode 100644 metagpt/provider/google_gemini_api.py create mode 100644 tests/metagpt/provider/test_google_gemini_api.py diff --git a/config/config.yaml b/config/config.yaml index 080de4000..596a31341 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -34,6 +34,10 @@ RPM: 10 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. +#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY" +# GEMINI_API_KEY: "YOUR_API_KEY" + #### if use self-host open llm model with openai-compatible interface #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" #OPEN_LLM_API_MODEL: "llama2-13b" diff --git a/metagpt/config.py b/metagpt/config.py index 2ce75b013..3b46c8504 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -51,13 +51,17 @@ class Config(metaclass=Singleton): self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") + + self.gemini_api_key = self._get("GEMINI_API_KEY") + if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \ (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \ (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \ (not self.open_llm_api_base) and \ - (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key): + (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key) and \ + (not self.gemini_api_key or "YOUR_API_KEY" in self.gemini_api_key): raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first " - "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE") + "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE or GEMINI_API_KEY") self.openai_api_base = self._get("OPENAI_API_BASE") openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy if openai_proxy: diff --git a/metagpt/llm.py b/metagpt/llm.py index 7b490ec4a..b13fc723a 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -14,6 +14,7 @@ from metagpt.provider.spark_api import SparkAPI from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.fireworks_api import FireWorksGPTAPI from metagpt.provider.human_provider import HumanProvider +from metagpt.provider.google_gemini_api import GeminiGPTAPI def LLM() -> "BaseGPTAPI": @@ -29,6 +30,8 @@ def LLM() -> "BaseGPTAPI": llm = OpenLLMGPTAPI() elif CONFIG.fireworks_api_key: llm = FireWorksGPTAPI() + elif CONFIG.gemini_api_key: + llm = GeminiGPTAPI() else: raise RuntimeError("You should config a LLM configuration first") diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py new file mode 100644 index 000000000..1c866ebad --- /dev/null +++ b/metagpt/provider/google_gemini_api.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart + +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) +import google.generativeai as genai +from google.generativeai import client +from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse +from google.generativeai.types.generation_types import GenerationConfig + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.openai_api import log_and_reraise + + +class GeminiGPTAPI(BaseGPTAPI): + """ + Refs to `https://ai.google.dev/tutorials/python_quickstart` + """ + + use_system_prompt: bool = False # google gemini has no system prompt when use api + + def __init__(self): + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = genai.GenerativeModel(model_name=self.model) + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "contents": messages, + "generation_config": GenerationConfig( + temperature=0.3 + ), + "stream": stream + } + return kwargs + + def _update_costs(self, usage: dict): + """ update each request's token cost """ + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error("google gemini updats costs failed!", e) + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + prompt_resp = self.llm.count_tokens(contents=messages) + completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + # fix google-generativeai sdk + if self.llm._client is None: + self.llm._client = client.get_default_generative_client() + # TODO exception to fix + prompt_resp = await self.llm.count_tokens_async(contents=messages) + completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + # usage = self.get_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + # usage = await self.aget_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, + stream=True)) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + # usage = await self.aget_usage(messages, full_content) + # self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """ response in async with stream or non-stream mode """ + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index ba63e90a9..6d9cbd137 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -7,6 +7,7 @@ ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py +ref4: https://ai.google.dev/models/gemini """ import tiktoken @@ -24,7 +25,8 @@ TOKEN_COSTS = { "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, - "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069} # 32k version, prompt + completion tokens=0.005¥/k-tokens + "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005} } @@ -42,7 +44,8 @@ TOKEN_MAX = { "gpt-4-0613": 8192, "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, - "chatglm_turbo": 32768 + "chatglm_turbo": 32768, + "gemini-pro": 32768 } diff --git a/requirements.txt b/requirements.txt index 14a9f485d..a2aaff48b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -45,3 +45,4 @@ semantic-kernel==0.3.13.dev0 wrapt==1.15.0 websocket-client==0.58.0 zhipuai==1.0.7 +google-generativeai==0.3.1 \ No newline at end of file diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py new file mode 100644 index 000000000..32ed11ba5 --- /dev/null +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of google gemini api + +import pytest +from abc import ABC +from dataclasses import dataclass + +from metagpt.provider.google_gemini_api import GeminiGPTAPI + + +messages = [ + {"role": "user", "content": "who are you"} +] + + +@dataclass +class MockGeminiResponse(ABC): + text: str + + +default_resp = MockGeminiResponse(text="I'm gemini from google") + + +def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: + return default_resp + + +def test_gemini_completion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) + resp = GeminiGPTAPI().completion(messages) + assert resp.text == default_resp.text + + +async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: + return default_resp + + +@pytest.mark.asyncio +async def test_gemini_acompletion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) + resp = await GeminiGPTAPI().acompletion(messages) + assert resp.text == default_resp.text From 9fb6e7c459a24489028ebe55a4ed2032d689eac1 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:54:56 +0800 Subject: [PATCH 02/12] update gemini user_msg doc --- metagpt/provider/google_gemini_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 1c866ebad..a69ffdc28 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -36,6 +36,8 @@ class GeminiGPTAPI(BaseGPTAPI): genai.configure(api_key=config.gemini_api_key) def _user_msg(self, msg: str) -> dict[str, str]: + # Not to change BaseGPTAPI default functions but update with Gemini's conversation format. + # You should follow the format. return {"role": "user", "parts": [msg]} def _assistant_msg(self, msg: str) -> dict[str, str]: From 4127ef85704a7771b484c8c73912e1919ef0be09 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:06:59 +0800 Subject: [PATCH 03/12] update gemini count_tokens --- metagpt/provider/google_gemini_api.py | 56 ++++++++++++++++++--------- metagpt/provider/zhipuai_api.py | 2 +- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index a69ffdc28..0ba1e86c1 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -10,14 +10,35 @@ from tenacity import ( wait_fixed, ) import google.generativeai as genai -from google.generativeai import client +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from google.generativeai.generative_models import GenerativeModel from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse from google.generativeai.types.generation_types import GenerationConfig from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.openai_api import log_and_reraise +from metagpt.provider.openai_api import CostManager, log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. + """ + + def count_tokens( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) class GeminiGPTAPI(BaseGPTAPI): @@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI): def __init__(self): self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model - self.llm = genai.GenerativeModel(model_name=self.model) + self.llm = GeminiGenerativeModel(model_name=self.model) + self._cost_manager = CostManager() def __init_gemini(self, config: CONFIG): genai.configure(api_key=config.gemini_api_key) @@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("google gemini updats costs failed!", e) + logger.error(f"google gemini updats costs failed! exp: {e}") def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text def get_usage(self, messages: list[dict], resp_text: str) -> dict: - prompt_resp = self.llm.count_tokens(contents=messages) - completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI): return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: - # fix google-generativeai sdk - if self.llm._client is None: - self.llm._client = client.get_default_generative_client() - # TODO exception to fix - prompt_resp = await self.llm.count_tokens_async(contents=messages) - completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI): def completion(self, messages: list[dict]) -> "GenerateContentResponse": resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) - # usage = self.get_usage(messages, resp.text) - # self._update_costs(usage) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage) return resp async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) - # usage = await self.aget_usage(messages, resp.text) - # self._update_costs(usage) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) return resp async def acompletion(self, messages: list[dict]) -> dict: @@ -113,8 +133,8 @@ class GeminiGPTAPI(BaseGPTAPI): collected_content.append(content) full_content = "".join(collected_content) - # usage = await self.aget_usage(messages, full_content) - # self._update_costs(usage) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) return full_content @retry( diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 3161c0e88..3b24ca98f 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -65,7 +65,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updats costs failed! exp: {e}") def get_choice_text(self, resp: dict) -> str: """ get the first text of choice from llm response """ From 70cbfb1e480367bb9586a62b7b723c80c57aa4f0 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:30:25 +0800 Subject: [PATCH 04/12] retry use wait_random_exponential --- metagpt/provider/google_gemini_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 0ba1e86c1..b68e013a0 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -7,7 +7,7 @@ from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, - wait_fixed, + wait_random_exponential, ) import google.generativeai as genai from google.ai import generativelanguage as glm @@ -139,7 +139,7 @@ class GeminiGPTAPI(BaseGPTAPI): @retry( stop=stop_after_attempt(3), - wait=wait_fixed(1), + wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise From c4fbc478d22ee0a1794e619866164f37d322ee73 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:45:40 +0800 Subject: [PATCH 05/12] add google gemini --- config/config.yaml | 4 + metagpt/config.py | 6 +- metagpt/provider/google_gemini_api.py | 130 ++++++++++++++++++ metagpt/utils/token_counter.py | 3 + requirements.txt | 1 + .../provider/test_google_gemini_api.py | 43 ++++++ 6 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 metagpt/provider/google_gemini_api.py create mode 100644 tests/metagpt/provider/test_google_gemini_api.py diff --git a/config/config.yaml b/config/config.yaml index f547462ba..fc113370d 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -37,6 +37,10 @@ RPM: 10 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. +#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY" +# GEMINI_API_KEY: "YOUR_API_KEY" + #### if use self-host open llm model with openai-compatible interface #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" #OPEN_LLM_API_MODEL: "llama2-13b" diff --git a/metagpt/config.py b/metagpt/config.py index 131854a56..6ab537296 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -39,6 +39,7 @@ class LLMProviderEnum(Enum): ZHIPUAI = "zhipuai" FIREWORKS = "fireworks" OPEN_LLM = "open_llm" + GEMINI = "gemini" class Config(metaclass=Singleton): @@ -74,7 +75,8 @@ class Config(metaclass=Singleton): (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC), (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI), (self.fireworks_api_key, LLMProviderEnum.FIREWORKS), - (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), # reuse logic. but not a key + (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), + (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): if self.openai_api_model: @@ -96,6 +98,8 @@ class Config(metaclass=Singleton): self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") + self.gemini_api_key = self._get("GEMINI_API_KEY") + _ = self.get_default_llm_provider_enum() self.openai_api_base = self._get("OPENAI_API_BASE") diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py new file mode 100644 index 000000000..1c866ebad --- /dev/null +++ b/metagpt/provider/google_gemini_api.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart + +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) +import google.generativeai as genai +from google.generativeai import client +from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse +from google.generativeai.types.generation_types import GenerationConfig + +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.openai_api import log_and_reraise + + +class GeminiGPTAPI(BaseGPTAPI): + """ + Refs to `https://ai.google.dev/tutorials/python_quickstart` + """ + + use_system_prompt: bool = False # google gemini has no system prompt when use api + + def __init__(self): + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = genai.GenerativeModel(model_name=self.model) + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "contents": messages, + "generation_config": GenerationConfig( + temperature=0.3 + ), + "stream": stream + } + return kwargs + + def _update_costs(self, usage: dict): + """ update each request's token cost """ + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error("google gemini updats costs failed!", e) + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + prompt_resp = self.llm.count_tokens(contents=messages) + completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + # fix google-generativeai sdk + if self.llm._client is None: + self.llm._client = client.get_default_generative_client() + # TODO exception to fix + prompt_resp = await self.llm.count_tokens_async(contents=messages) + completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + # usage = self.get_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + # usage = await self.aget_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, + stream=True)) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + # usage = await self.aget_usage(messages, full_content) + # self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """ response in async with stream or non-stream mode """ + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index ebfb85de7..512ff784c 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -7,6 +7,7 @@ ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py +ref4: https://ai.google.dev/models/gemini """ import tiktoken @@ -25,6 +26,7 @@ TOKEN_COSTS = { "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005} } @@ -43,6 +45,7 @@ TOKEN_MAX = { "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, + "gemini-pro": 32768 } diff --git a/requirements.txt b/requirements.txt index f5ef63c58..2b4e064ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,3 +49,4 @@ aiofiles==23.2.1 gitpython==3.1.40 zhipuai==1.0.7 gitignore-parser==0.1.9 +google-generativeai==0.3.1 diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py new file mode 100644 index 000000000..32ed11ba5 --- /dev/null +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of google gemini api + +import pytest +from abc import ABC +from dataclasses import dataclass + +from metagpt.provider.google_gemini_api import GeminiGPTAPI + + +messages = [ + {"role": "user", "content": "who are you"} +] + + +@dataclass +class MockGeminiResponse(ABC): + text: str + + +default_resp = MockGeminiResponse(text="I'm gemini from google") + + +def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: + return default_resp + + +def test_gemini_completion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) + resp = GeminiGPTAPI().completion(messages) + assert resp.text == default_resp.text + + +async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: + return default_resp + + +@pytest.mark.asyncio +async def test_gemini_acompletion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) + resp = await GeminiGPTAPI().acompletion(messages) + assert resp.text == default_resp.text From 91d1ab20cc21eccf5966cf507b08087af4cadda6 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 14 Dec 2023 16:54:56 +0800 Subject: [PATCH 06/12] update gemini user_msg doc --- metagpt/provider/google_gemini_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 1c866ebad..a69ffdc28 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -36,6 +36,8 @@ class GeminiGPTAPI(BaseGPTAPI): genai.configure(api_key=config.gemini_api_key) def _user_msg(self, msg: str) -> dict[str, str]: + # Not to change BaseGPTAPI default functions but update with Gemini's conversation format. + # You should follow the format. return {"role": "user", "parts": [msg]} def _assistant_msg(self, msg: str) -> dict[str, str]: From 02090af7cb1a315b2b59ea843fa7aa8bb816cf4e Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:06:59 +0800 Subject: [PATCH 07/12] update gemini count_tokens --- metagpt/provider/google_gemini_api.py | 56 ++++++++++++++++++--------- metagpt/provider/zhipuai_api.py | 2 +- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index a69ffdc28..0ba1e86c1 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -10,14 +10,35 @@ from tenacity import ( wait_fixed, ) import google.generativeai as genai -from google.generativeai import client +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from google.generativeai.generative_models import GenerativeModel from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse from google.generativeai.types.generation_types import GenerationConfig from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.provider.openai_api import log_and_reraise +from metagpt.provider.openai_api import CostManager, log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. + """ + + def count_tokens( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async( + self, contents: content_types.ContentsType + ) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) class GeminiGPTAPI(BaseGPTAPI): @@ -30,7 +51,8 @@ class GeminiGPTAPI(BaseGPTAPI): def __init__(self): self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model - self.llm = genai.GenerativeModel(model_name=self.model) + self.llm = GeminiGenerativeModel(model_name=self.model) + self._cost_manager = CostManager() def __init_gemini(self, config: CONFIG): genai.configure(api_key=config.gemini_api_key) @@ -61,14 +83,15 @@ class GeminiGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("google gemini updats costs failed!", e) + logger.error(f"google gemini updats costs failed! exp: {e}") def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text def get_usage(self, messages: list[dict], resp_text: str) -> dict: - prompt_resp = self.llm.count_tokens(contents=messages) - completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -76,12 +99,9 @@ class GeminiGPTAPI(BaseGPTAPI): return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: - # fix google-generativeai sdk - if self.llm._client is None: - self.llm._client = client.get_default_generative_client() - # TODO exception to fix - prompt_resp = await self.llm.count_tokens_async(contents=messages) - completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) usage = { "prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens @@ -90,14 +110,14 @@ class GeminiGPTAPI(BaseGPTAPI): def completion(self, messages: list[dict]) -> "GenerateContentResponse": resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) - # usage = self.get_usage(messages, resp.text) - # self._update_costs(usage) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage) return resp async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) - # usage = await self.aget_usage(messages, resp.text) - # self._update_costs(usage) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) return resp async def acompletion(self, messages: list[dict]) -> dict: @@ -113,8 +133,8 @@ class GeminiGPTAPI(BaseGPTAPI): collected_content.append(content) full_content = "".join(collected_content) - # usage = await self.aget_usage(messages, full_content) - # self._update_costs(usage) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) return full_content @retry( diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index eef0e51e1..60d9a0777 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updats costs failed! exp: {e}") def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" From e5a7fdfe3b7168341a7b5b1903288fdbe99a7dd1 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 15 Dec 2023 17:30:25 +0800 Subject: [PATCH 08/12] retry use wait_random_exponential --- metagpt/provider/google_gemini_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 0ba1e86c1..b68e013a0 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -7,7 +7,7 @@ from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, - wait_fixed, + wait_random_exponential, ) import google.generativeai as genai from google.ai import generativelanguage as glm @@ -139,7 +139,7 @@ class GeminiGPTAPI(BaseGPTAPI): @retry( stop=stop_after_attempt(3), - wait=wait_fixed(1), + wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise From f3eb9f638efd3bdd08023a996985098959116dfd Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 12:55:45 +0800 Subject: [PATCH 09/12] add other llm for LLMProviderRegistry --- metagpt/config.py | 2 +- metagpt/provider/__init__.py | 13 +++++++++++-- metagpt/provider/google_gemini_api.py | 20 +++++++++++--------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 6ab537296..27d4488e0 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -79,7 +79,7 @@ class Config(metaclass=Singleton): (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): - if self.openai_api_model: + if self.openai_api_key and self.openai_api_model: logger.info(f"OpenAI API Model: {self.openai_api_model}") return v raise NotConfiguredException("You should config a LLM configuration first") diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 56dc19b4b..028c6f837 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -6,7 +6,16 @@ @File : __init__.py """ +from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI - -__all__ = ["OpenAIGPTAPI"] +__all__ = [ + "FireWorksGPTAPI", + "GeminiGPTAPI", + "OpenLLMGPTAPI", + "OpenAIGPTAPI", + "ZhiPuAIGPTAPI" +] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index b68e013a0..213b53263 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -2,6 +2,12 @@ # -*- coding: utf-8 -*- # @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart +import google.generativeai as genai +from google.ai import generativelanguage as glm +from google.generativeai.generative_models import GenerativeModel +from google.generativeai.types import content_types +from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse +from google.generativeai.types.generation_types import GenerationConfig from tenacity import ( after_log, retry, @@ -9,16 +15,11 @@ from tenacity import ( stop_after_attempt, wait_random_exponential, ) -import google.generativeai as genai -from google.ai import generativelanguage as glm -from google.generativeai.types import content_types -from google.generativeai.generative_models import GenerativeModel -from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse -from google.generativeai.types.generation_types import GenerationConfig -from metagpt.config import CONFIG +from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import CostManager, log_and_reraise @@ -29,18 +30,19 @@ class GeminiGenerativeModel(GenerativeModel): """ def count_tokens( - self, contents: content_types.ContentsType + self, contents: content_types.ContentsType ) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return self._client.count_tokens(model=self.model_name, contents=contents) async def count_tokens_async( - self, contents: content_types.ContentsType + self, contents: content_types.ContentsType ) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return await self._async_client.count_tokens(model=self.model_name, contents=contents) +@register_provider(LLMProviderEnum.GEMINI) class GeminiGPTAPI(BaseGPTAPI): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` From bdb427d5b785222701ef2e49c09bb0a2a1b40654 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 14:18:50 +0800 Subject: [PATCH 10/12] add gemini minimal python version warning --- metagpt/config.py | 5 +++++ metagpt/provider/google_gemini_api.py | 3 +-- metagpt/utils/common.py | 9 ++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index 27d4488e0..727b37b9c 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -7,6 +7,7 @@ Provide configuration, singleton 2. Add the parameter `src_workspace` for the old version project path. """ import os +import warnings from copy import deepcopy from enum import Enum from pathlib import Path @@ -17,6 +18,7 @@ import yaml from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS from metagpt.logs import logger from metagpt.tools import SearchEngineType, WebBrowserEngineType +from metagpt.utils.common import require_python_version from metagpt.utils.singleton import Singleton @@ -79,6 +81,9 @@ class Config(metaclass=Singleton): (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): + logger.info(f"Use LLMProvider: {v.value}") + if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): + warnings.warn("Use Gemini requires Python >= 3.10") if self.openai_api_key and self.openai_api_model: logger.info(f"OpenAI API Model: {self.openai_api_model}") return v diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 213b53263..10215e2d9 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -48,9 +48,8 @@ class GeminiGPTAPI(BaseGPTAPI): Refs to `https://ai.google.dev/tutorials/python_quickstart` """ - use_system_prompt: bool = False # google gemini has no system prompt when use api - def __init__(self): + self.use_system_prompt = False # google gemini has no system prompt when use api self.__init_gemini(CONFIG) self.model = "gemini-pro" # so far only one model self.llm = GeminiGenerativeModel(model_name=self.model) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e5d4573e8..eec4176df 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -19,6 +19,7 @@ import json import os import platform import re +import sys import traceback import typing from pathlib import Path @@ -47,6 +48,12 @@ def check_cmd_exists(command) -> int: return result +def require_python_version(req_version: tuple[int]) -> bool: + if not (2 <= len(req_version) <= 3): + raise ValueError("req_version should be (3, 9) or (3, 10, 13)") + return True if sys.version_info > req_version else False + + class OutputParser: @classmethod def parse_blocks(cls, text: str): @@ -219,7 +226,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index : end_index + 1] + structure_text = text[start_index: end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval From 18a195a3678dd5c23c9666a57742eeb5bdec943a Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 14:46:33 +0800 Subject: [PATCH 11/12] update config --- metagpt/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/metagpt/config.py b/metagpt/config.py index be0d6ec41..963fe3b05 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -81,7 +81,7 @@ class Config(metaclass=Singleton): (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): - logger.info(f"Use LLMProvider: {v.value}") + # logger.debug(f"Use LLMProvider: {v.value}") if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): warnings.warn("Use Gemini requires Python >= 3.10") if self.openai_api_key and self.openai_api_model: @@ -94,7 +94,6 @@ class Config(metaclass=Singleton): return k and k != "YOUR_API_KEY" def _update(self): - # logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") self.openai_api_key = self._get("OPENAI_API_KEY") From 6af9fecf65cb80f35d8fb1d56d6a6a01fe3504a5 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 21 Dec 2023 15:06:59 +0800 Subject: [PATCH 12/12] fix format --- metagpt/provider/__init__.py | 8 +--- metagpt/provider/google_gemini_api.py | 44 +++++++------------ metagpt/roles/researcher.py | 2 +- metagpt/tools/web_browser_engine_selenium.py | 4 +- metagpt/utils/common.py | 2 +- metagpt/utils/token_counter.py | 4 +- .../provider/test_google_gemini_api.py | 8 ++-- 7 files changed, 26 insertions(+), 46 deletions(-) diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 028c6f837..a9f46eb03 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -12,10 +12,4 @@ from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.openai_api import OpenAIGPTAPI from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI -__all__ = [ - "FireWorksGPTAPI", - "GeminiGPTAPI", - "OpenLLMGPTAPI", - "OpenAIGPTAPI", - "ZhiPuAIGPTAPI" -] +__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 631da1052..682f7b507 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -6,8 +6,11 @@ import google.generativeai as genai from google.ai import generativelanguage as glm from google.generativeai.generative_models import GenerativeModel from google.generativeai.types import content_types -from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse -from google.generativeai.types.generation_types import GenerationConfig +from google.generativeai.types.generation_types import ( + AsyncGenerateContentResponse, + GenerateContentResponse, + GenerationConfig, +) from tenacity import ( after_log, retry, @@ -29,15 +32,11 @@ class GeminiGenerativeModel(GenerativeModel): Will use default GenerativeModel if it fixed. """ - def count_tokens( - self, contents: content_types.ContentsType - ) -> glm.CountTokensResponse: + def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return self._client.count_tokens(model=self.model_name, contents=contents) - async def count_tokens_async( - self, contents: content_types.ContentsType - ) -> glm.CountTokensResponse: + async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: contents = content_types.to_contents(contents) return await self._async_client.count_tokens(model=self.model_name, contents=contents) @@ -68,17 +67,11 @@ class GeminiGPTAPI(BaseGPTAPI): return {"role": "model", "parts": [msg]} def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: - kwargs = { - "contents": messages, - "generation_config": GenerationConfig( - temperature=0.3 - ), - "stream": stream - } + kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs def _update_costs(self, usage: dict): - """ update each request's token cost """ + """update each request's token cost""" if CONFIG.calc_usage: try: prompt_tokens = int(usage.get("prompt_tokens", 0)) @@ -94,20 +87,14 @@ class GeminiGPTAPI(BaseGPTAPI): req_text = messages[-1]["parts"][0] if messages else "" prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) - usage = { - "prompt_tokens": prompt_resp.total_tokens, - "completion_tokens": completion_resp.total_tokens - } + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} return usage async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: req_text = messages[-1]["parts"][0] if messages else "" prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) - usage = { - "prompt_tokens": prompt_resp.total_tokens, - "completion_tokens": completion_resp.total_tokens - } + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} return usage def completion(self, messages: list[dict]) -> "GenerateContentResponse": @@ -126,8 +113,9 @@ class GeminiGPTAPI(BaseGPTAPI): return await self._achat_completion(messages) async def _achat_completion_stream(self, messages: list[dict]) -> str: - resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, - stream=True)) + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async( + **self._const_kwargs(messages, stream=True) + ) collected_content = [] async for chunk in resp: content = chunk.text @@ -144,10 +132,10 @@ class GeminiGPTAPI(BaseGPTAPI): wait=wait_random_exponential(min=1, max=60), after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(ConnectionError), - retry_error_callback=log_and_reraise + retry_error_callback=log_and_reraise, ) async def acompletion_text(self, messages: list[dict], stream=False) -> str: - """ response in async with stream or non-stream mode """ + """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) resp = await self._achat_completion(messages) diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index 52c55f0ca..e894d1a57 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -70,7 +70,7 @@ class Researcher(Role): return ret def research_system_text(self, topic, current_task: Action) -> str: - """ BACKWARD compatible + """BACKWARD compatible This allows sub-class able to define its own system prompt based on topic. return the previous implementation to have backward compatible Args: diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 074943892..decab2b7d 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -106,8 +106,8 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): options.add_argument("--headless") options.add_argument("--enable-javascript") if browser_type == "chrome": - options.add_argument("--disable-gpu") # This flag can help avoid renderer issue - options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems + options.add_argument("--disable-gpu") # This flag can help avoid renderer issue + options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems options.add_argument("--no-sandbox") for i in args: options.add_argument(i) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index eec4176df..8db7a80a1 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -226,7 +226,7 @@ class OutputParser: if start_index != -1 and end_index != -1: # Extract the structure part - structure_text = text[start_index: end_index + 1] + structure_text = text[start_index : end_index + 1] try: # Attempt to convert the text to a Python data type using ast.literal_eval diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 512ff784c..c29fa7d43 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -26,7 +26,7 @@ TOKEN_COSTS = { "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens - "gemini-pro": {"prompt": 0.00025, "completion": 0.0005} + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, } @@ -45,7 +45,7 @@ TOKEN_MAX = { "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, - "gemini-pro": 32768 + "gemini-pro": 32768, } diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 32ed11ba5..229d9b9a7 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -2,16 +2,14 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of google gemini api -import pytest from abc import ABC from dataclasses import dataclass +import pytest + from metagpt.provider.google_gemini_api import GeminiGPTAPI - -messages = [ - {"role": "user", "content": "who are you"} -] +messages = [{"role": "user", "content": "who are you"}] @dataclass