diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index 028c6f837..a9f46eb03 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -12,10 +12,4 @@ from metagpt.provider.open_llm_api import OpenLLMGPTAPI
 from metagpt.provider.openai_api import OpenAIGPTAPI
 from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
 
-__all__ = [
-    "FireWorksGPTAPI",
-    "GeminiGPTAPI",
-    "OpenLLMGPTAPI",
-    "OpenAIGPTAPI",
-    "ZhiPuAIGPTAPI"
-]
+__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]
diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py
index 631da1052..682f7b507 100644
--- a/metagpt/provider/google_gemini_api.py
+++ b/metagpt/provider/google_gemini_api.py
@@ -6,8 +6,11 @@ import google.generativeai as genai
 from google.ai import generativelanguage as glm
 from google.generativeai.generative_models import GenerativeModel
 from google.generativeai.types import content_types
-from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse
-from google.generativeai.types.generation_types import GenerationConfig
+from google.generativeai.types.generation_types import (
+    AsyncGenerateContentResponse,
+    GenerateContentResponse,
+    GenerationConfig,
+)
 from tenacity import (
     after_log,
     retry,
@@ -29,15 +32,11 @@ class GeminiGenerativeModel(GenerativeModel):
     Will use default GenerativeModel if it fixed.
     """
 
-    def count_tokens(
-        self, contents: content_types.ContentsType
-    ) -> glm.CountTokensResponse:
+    def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
         contents = content_types.to_contents(contents)
         return self._client.count_tokens(model=self.model_name, contents=contents)
 
-    async def count_tokens_async(
-        self, contents: content_types.ContentsType
-    ) -> glm.CountTokensResponse:
+    async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
         contents = content_types.to_contents(contents)
         return await self._async_client.count_tokens(model=self.model_name, contents=contents)
 
@@ -68,17 +67,11 @@ class GeminiGPTAPI(BaseGPTAPI):
         return {"role": "model", "parts": [msg]}
 
     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
-        kwargs = {
-            "contents": messages,
-            "generation_config": GenerationConfig(
-                temperature=0.3
-            ),
-            "stream": stream
-        }
+        kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
         return kwargs
 
     def _update_costs(self, usage: dict):
-        """ update each request's token cost """
+        """update each request's token cost"""
         if CONFIG.calc_usage:
             try:
                 prompt_tokens = int(usage.get("prompt_tokens", 0))
@@ -94,20 +87,14 @@ class GeminiGPTAPI(BaseGPTAPI):
         req_text = messages[-1]["parts"][0] if messages else ""
         prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]})
         completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]})
-        usage = {
-            "prompt_tokens": prompt_resp.total_tokens,
-            "completion_tokens": completion_resp.total_tokens
-        }
+        usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens}
         return usage
 
     async def aget_usage(self, messages: list[dict], resp_text: str) -> dict:
         req_text = messages[-1]["parts"][0] if messages else ""
         prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]})
         completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]})
-        usage = {
-            "prompt_tokens": prompt_resp.total_tokens,
-            "completion_tokens": completion_resp.total_tokens
-        }
+        usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens}
         return usage
 
     def completion(self, messages: list[dict]) -> "GenerateContentResponse":
@@ -126,8 +113,9 @@ class GeminiGPTAPI(BaseGPTAPI):
         return await self._achat_completion(messages)
 
     async def _achat_completion_stream(self, messages: list[dict]) -> str:
-        resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages,
-                                                                                                        stream=True))
+        resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
+            **self._const_kwargs(messages, stream=True)
+        )
         collected_content = []
         async for chunk in resp:
             content = chunk.text
@@ -144,10 +132,10 @@ class GeminiGPTAPI(BaseGPTAPI):
         wait=wait_random_exponential(min=1, max=60),
         after=after_log(logger, logger.level("WARNING").name),
         retry=retry_if_exception_type(ConnectionError),
-        retry_error_callback=log_and_reraise
+        retry_error_callback=log_and_reraise,
     )
     async def acompletion_text(self, messages: list[dict], stream=False) -> str:
-        """ response in async with stream or non-stream mode """
+        """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)
         resp = await self._achat_completion(messages)
diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py
index 52c55f0ca..e894d1a57 100644
--- a/metagpt/roles/researcher.py
+++ b/metagpt/roles/researcher.py
@@ -70,7 +70,7 @@ class Researcher(Role):
         return ret
 
     def research_system_text(self, topic, current_task: Action) -> str:
-        """ BACKWARD compatible
+        """BACKWARD compatible
         This allows sub-class able to define its own system prompt based on topic.
         return the previous implementation to have backward compatible
         Args:
diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py
index 074943892..decab2b7d 100644
--- a/metagpt/tools/web_browser_engine_selenium.py
+++ b/metagpt/tools/web_browser_engine_selenium.py
@@ -106,8 +106,8 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
         options.add_argument("--headless")
         options.add_argument("--enable-javascript")
         if browser_type == "chrome":
-            options.add_argument("--disable-gpu") # This flag can help avoid renderer issue
-            options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems
+            options.add_argument("--disable-gpu")  # This flag can help avoid renderer issue
+            options.add_argument("--disable-dev-shm-usage")  # Overcome limited resource problems
             options.add_argument("--no-sandbox")
         for i in args:
             options.add_argument(i)
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index eec4176df..8db7a80a1 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -226,7 +226,7 @@ class OutputParser:
 
         if start_index != -1 and end_index != -1:
             # Extract the structure part
-            structure_text = text[start_index: end_index + 1]
+            structure_text = text[start_index : end_index + 1]
 
             try:
                 # Attempt to convert the text to a Python data type using ast.literal_eval
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index 512ff784c..c29fa7d43 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -26,7 +26,7 @@ TOKEN_COSTS = {
     "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
     "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
     "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069},  # 32k version, prompt + completion tokens=0.005¥/k-tokens
-    "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}
+    "gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
 }
 
 
@@ -45,7 +45,7 @@ TOKEN_MAX = {
     "gpt-4-1106-preview": 128000,
     "text-embedding-ada-002": 8192,
     "chatglm_turbo": 32768,
-    "gemini-pro": 32768
+    "gemini-pro": 32768,
 }
 
 
diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py
index 32ed11ba5..229d9b9a7 100644
--- a/tests/metagpt/provider/test_google_gemini_api.py
+++ b/tests/metagpt/provider/test_google_gemini_api.py
@@ -2,16 +2,14 @@
 # -*- coding: utf-8 -*-
 # @Desc   : the unittest of google gemini api
 
-import pytest
 from abc import ABC
 from dataclasses import dataclass
 
+import pytest
+
 from metagpt.provider.google_gemini_api import GeminiGPTAPI
 
-
-messages = [
-    {"role": "user", "content": "who are you"}
-]
+messages = [{"role": "user", "content": "who are you"}]
 
 
 @dataclass
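
Reviewer note (not part of the patch): every hunk above is a behavior-preserving reformat (black/isort style), so a quick smoke test of the Gemini provider should behave identically before and after. Below is a minimal sketch of such a check. It assumes a Google API key is already configured for MetaGPT and that `GeminiGPTAPI()` takes no constructor arguments; neither detail appears in this diff. The message layout mirrors the Gemini-style `{"role", "parts"}` dicts that `_const_kwargs` and `aget_usage` handle above.

```python
# Minimal smoke-test sketch, not part of the patch.
# Assumptions (not shown in the diff): a configured Google API key and a
# no-argument GeminiGPTAPI constructor.
import asyncio

from metagpt.provider.google_gemini_api import GeminiGPTAPI


async def main() -> None:
    llm = GeminiGPTAPI()
    # Gemini-style message, matching the {"role", "parts"} shape that
    # _const_kwargs forwards and aget_usage indexes in the diff above.
    messages = [{"role": "user", "parts": ["who are you"]}]

    # Non-stream path: one awaited AsyncGenerateContentResponse, full text returned.
    print(await llm.acompletion_text(messages, stream=False))

    # Stream path: _achat_completion_stream joins the chunk.text pieces.
    print(await llm.acompletion_text(messages, stream=True))


if __name__ == "__main__":
    asyncio.run(main())
```

With `CONFIG.calc_usage` enabled, either call also exercises the `count_tokens`/`count_tokens_async` overrides, so cost logging is a useful second signal that the refactor changed nothing.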
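For the new `gemini-pro` pricing row, the cost arithmetic follows the per-thousand-token convention the `chatglm_turbo` comment in the diff already uses ("¥/k-tokens"). A hypothetical worked example, with a standalone helper standing in for MetaGPT's cost bookkeeping:

```python
# Hypothetical cost check for the new "gemini-pro" row in TOKEN_COSTS.
# Prices are USD per 1k tokens, so a 1200-token prompt plus a 300-token
# completion costs 1.2 * 0.00025 + 0.3 * 0.0005 = 0.00045 USD.
TOKEN_COSTS = {"gemini-pro": {"prompt": 0.00025, "completion": 0.0005}}


def request_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    price = TOKEN_COSTS[model]
    return prompt_tokens / 1000 * price["prompt"] + completion_tokens / 1000 * price["completion"]


assert abs(request_cost("gemini-pro", 1200, 300) - 0.00045) < 1e-12
```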