diff --git a/config/config.yaml b/config/config.yaml
index 080de4000..596a31341 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -34,6 +34,10 @@ RPM: 10
 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
 # ZHIPUAI_API_KEY: "YOUR_API_KEY"
 
+#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
+#### You can set here or export GEMINI_API_KEY="YOUR_API_KEY"
+# GEMINI_API_KEY: "YOUR_API_KEY"
+
 #### if use self-host open llm model with openai-compatible interface
 #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
 #OPEN_LLM_API_MODEL: "llama2-13b"
diff --git a/metagpt/config.py b/metagpt/config.py
index 2ce75b013..3b46c8504 100644
--- a/metagpt/config.py
+++ b/metagpt/config.py
@@ -51,13 +51,17 @@ class Config(metaclass=Singleton):
         self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
         self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
 
+        self.gemini_api_key = self._get("GEMINI_API_KEY")
+
         if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \
            (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \
            (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \
            (not self.open_llm_api_base) and \
-           (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key):
+           (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key) and \
+           (not self.gemini_api_key or "YOUR_API_KEY" == self.gemini_api_key):
             raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first "
-                                         "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE")
+                                         "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE or GEMINI_API_KEY")
         self.openai_api_base = self._get("OPENAI_API_BASE")
         openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
         if openai_proxy:
diff --git a/metagpt/llm.py b/metagpt/llm.py
index 7b490ec4a..b13fc723a 100644
--- a/metagpt/llm.py
+++ b/metagpt/llm.py
@@ -14,6 +14,7 @@ from metagpt.provider.spark_api import SparkAPI
 from metagpt.provider.open_llm_api import OpenLLMGPTAPI
 from metagpt.provider.fireworks_api import FireWorksGPTAPI
 from metagpt.provider.human_provider import HumanProvider
+from metagpt.provider.google_gemini_api import GeminiGPTAPI
 
 
 def LLM() -> "BaseGPTAPI":
@@ -29,6 +30,8 @@ def LLM() -> "BaseGPTAPI":
         llm = OpenLLMGPTAPI()
     elif CONFIG.fireworks_api_key:
         llm = FireWorksGPTAPI()
+    elif CONFIG.gemini_api_key:
+        llm = GeminiGPTAPI()
     else:
         raise RuntimeError("You should config a LLM configuration first")
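With the config and factory changes above, Gemini is selected simply by configuring only GEMINI_API_KEY. A minimal usage sketch, not part of this PR (it assumes the `aask` helper inherited from `BaseGPTAPI`):

import asyncio

from metagpt.llm import LLM


async def main():
    # With only GEMINI_API_KEY configured, the elif chain in LLM() above
    # falls through to GeminiGPTAPI.
    llm = LLM()
    print(await llm.aask("Say hello in one short sentence."))


asyncio.run(main())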
""" + + use_system_prompt: bool = False # google gemini has no system prompt when use api + + def __init__(self): + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = genai.GenerativeModel(model_name=self.model) + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "contents": messages, + "generation_config": GenerationConfig( + temperature=0.3 + ), + "stream": stream + } + return kwargs + + def _update_costs(self, usage: dict): + """ update each request's token cost """ + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error("google gemini updats costs failed!", e) + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + prompt_resp = self.llm.count_tokens(contents=messages) + completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + # fix google-generativeai sdk + if self.llm._client is None: + self.llm._client = client.get_default_generative_client() + # TODO exception to fix + prompt_resp = await self.llm.count_tokens_async(contents=messages) + completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]}) + usage = { + "prompt_tokens": prompt_resp.total_tokens, + "completion_tokens": completion_resp.total_tokens + } + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + # usage = self.get_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + # usage = await self.aget_usage(messages, resp.text) + # self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages, + stream=True)) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + # usage = await self.aget_usage(messages, full_content) + # self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """ 
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index ba63e90a9..6d9cbd137 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -7,6 +7,7 @@
 ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py
 ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
+ref4: https://ai.google.dev/models/gemini
 """
 import tiktoken
 
@@ -24,7 +25,8 @@ TOKEN_COSTS = {
     "gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
     "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
     "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
-    "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}  # 32k version, prompt + completion tokens=0.005¥/k-tokens
+    "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069},  # 32k version, prompt + completion tokens=0.005¥/k-tokens
+    "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}
 }
 
 
@@ -42,7 +44,8 @@ TOKEN_MAX = {
     "gpt-4-0613": 8192,
     "gpt-4-1106-preview": 128000,
     "text-embedding-ada-002": 8192,
-    "chatglm_turbo": 32768
+    "chatglm_turbo": 32768,
+    "gemini-pro": 32768
 }
diff --git a/requirements.txt b/requirements.txt
index 14a9f485d..a2aaff48b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -45,3 +45,4 @@ semantic-kernel==0.3.13.dev0
 wrapt==1.15.0
 websocket-client==0.58.0
 zhipuai==1.0.7
+google-generativeai==0.3.1
\ No newline at end of file
diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py
new file mode 100644
index 000000000..32ed11ba5
--- /dev/null
+++ b/tests/metagpt/provider/test_google_gemini_api.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : the unittest of google gemini api
+
+from abc import ABC
+from dataclasses import dataclass
+
+import pytest
+
+from metagpt.provider.google_gemini_api import GeminiGPTAPI
+
+
+messages = [
+    {"role": "user", "content": "who are you"}
+]
+
+
+@dataclass
+class MockGeminiResponse(ABC):
+    text: str
+
+
+default_resp = MockGeminiResponse(text="I'm gemini from google")
+
+
+def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
+    return default_resp
+
+
+def test_gemini_completion(mocker):
+    mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask)
+    resp = GeminiGPTAPI().completion(messages)
+    assert resp.text == default_resp.text
+
+
+async def mock_llm_aask(self, messages: list[dict]) -> MockGeminiResponse:
+    return default_resp
+
+
+@pytest.mark.asyncio
+async def test_gemini_acompletion(mocker):
+    mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask)
+    resp = await GeminiGPTAPI().acompletion(messages)
+    assert resp.text == default_resp.text
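The gemini-pro entries above follow the table's convention of USD per 1k tokens. A quick sanity check of the per-request arithmetic, with a hypothetical `request_cost` helper that mirrors what a cost manager would compute (not the PR's actual CostManager):

TOKEN_COSTS = {"gemini-pro": {"prompt": 0.00025, "completion": 0.0005}}  # USD per 1k tokens


def request_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    rate = TOKEN_COSTS[model]
    return (prompt_tokens * rate["prompt"] + completion_tokens * rate["completion"]) / 1000


# 10k prompt tokens + 2k completion tokens:
print(f"${request_cost('gemini-pro', 10_000, 2_000):.4f}")  # $0.0035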