diff --git a/config/config.yaml b/config/config.yaml index 1f5b85c21..e724897ee 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -35,6 +35,10 @@ RPM: 10 #### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY" # ZHIPUAI_API_KEY: "YOUR_API_KEY" +#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`. +#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY" +# GEMINI_API_KEY: "YOUR_API_KEY" + #### if use self-host open llm model with openai-compatible interface #OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1" #OPEN_LLM_API_MODEL: "llama2-13b" diff --git a/metagpt/config.py b/metagpt/config.py index 279f929fd..5176a7677 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -7,6 +7,7 @@ Provide configuration, singleton 2. Add the parameter `src_workspace` for the old version project path. """ import os +import warnings from copy import deepcopy from enum import Enum from pathlib import Path @@ -17,6 +18,7 @@ import yaml from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS from metagpt.logs import logger from metagpt.tools import SearchEngineType, WebBrowserEngineType +from metagpt.utils.common import require_python_version from metagpt.utils.singleton import Singleton @@ -39,6 +41,7 @@ class LLMProviderEnum(Enum): ZHIPUAI = "zhipuai" FIREWORKS = "fireworks" OPEN_LLM = "open_llm" + GEMINI = "gemini" class Config(metaclass=Singleton): @@ -74,10 +77,14 @@ class Config(metaclass=Singleton): (self.anthropic_api_key, LLMProviderEnum.ANTHROPIC), (self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI), (self.fireworks_api_key, LLMProviderEnum.FIREWORKS), - (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), # reuse logic. but not a key + (self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), + (self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key ]: if self._is_valid_llm_key(k): - if self.openai_api_model: + # logger.debug(f"Use LLMProvider: {v.value}") + if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)): + warnings.warn("Use Gemini requires Python >= 3.10") + if self.openai_api_key and self.openai_api_model: logger.info(f"OpenAI API Model: {self.openai_api_model}") return v raise NotConfiguredException("You should config a LLM configuration first") @@ -87,7 +94,6 @@ class Config(metaclass=Singleton): return k and k != "YOUR_API_KEY" def _update(self): - # logger.info("Config loading done.") self.global_proxy = self._get("GLOBAL_PROXY") self.openai_api_key = self._get("OPENAI_API_KEY") @@ -96,6 +102,7 @@ class Config(metaclass=Singleton): self.open_llm_api_base = self._get("OPEN_LLM_API_BASE") self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL") self.fireworks_api_key = self._get("FIREWORKS_API_KEY") + self.gemini_api_key = self._get("GEMINI_API_KEY") _ = self.get_default_llm_provider_enum() self.openai_base_url = self._get("OPENAI_BASE_URL") diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 56dc19b4b..a9f46eb03 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -6,7 +6,10 @@ @File : __init__.py """ +from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.open_llm_api import OpenLLMGPTAPI from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI - -__all__ = ["OpenAIGPTAPI"] +__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py new file mode 100644 index 000000000..682f7b507 --- /dev/null +++ b/metagpt/provider/google_gemini_api.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart + +import google.generativeai as genai +from google.ai import generativelanguage as glm +from google.generativeai.generative_models import GenerativeModel +from google.generativeai.types import content_types +from google.generativeai.types.generation_types import ( + AsyncGenerateContentResponse, + GenerateContentResponse, + GenerationConfig, +) +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, +) + +from metagpt.config import CONFIG, LLMProviderEnum +from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.llm_provider_registry import register_provider +from metagpt.provider.openai_api import CostManager, log_and_reraise + + +class GeminiGenerativeModel(GenerativeModel): + """ + Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class. + Will use default GenerativeModel if it fixed. + """ + + def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return self._client.count_tokens(model=self.model_name, contents=contents) + + async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + contents = content_types.to_contents(contents) + return await self._async_client.count_tokens(model=self.model_name, contents=contents) + + +@register_provider(LLMProviderEnum.GEMINI) +class GeminiGPTAPI(BaseGPTAPI): + """ + Refs to `https://ai.google.dev/tutorials/python_quickstart` + """ + + def __init__(self): + self.use_system_prompt = False # google gemini has no system prompt when use api + + self.__init_gemini(CONFIG) + self.model = "gemini-pro" # so far only one model + self.llm = GeminiGenerativeModel(model_name=self.model) + self._cost_manager = CostManager() + + def __init_gemini(self, config: CONFIG): + genai.configure(api_key=config.gemini_api_key) + + def _user_msg(self, msg: str) -> dict[str, str]: + # Not to change BaseGPTAPI default functions but update with Gemini's conversation format. + # You should follow the format. + return {"role": "user", "parts": [msg]} + + def _assistant_msg(self, msg: str) -> dict[str, str]: + return {"role": "model", "parts": [msg]} + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} + return kwargs + + def _update_costs(self, usage: dict): + """update each request's token cost""" + if CONFIG.calc_usage: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) + except Exception as e: + logger.error(f"google gemini updats costs failed! exp: {e}") + + def get_choice_text(self, resp: GenerateContentResponse) -> str: + return resp.text + + def get_usage(self, messages: list[dict], resp_text: str) -> dict: + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]}) + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} + return usage + + async def aget_usage(self, messages: list[dict], resp_text: str) -> dict: + req_text = messages[-1]["parts"][0] if messages else "" + prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]}) + completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]}) + usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens} + return usage + + def completion(self, messages: list[dict]) -> "GenerateContentResponse": + resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages)) + usage = self.get_usage(messages, resp.text) + self._update_costs(usage) + return resp + + async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse": + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages)) + usage = await self.aget_usage(messages, resp.text) + self._update_costs(usage) + return resp + + async def acompletion(self, messages: list[dict]) -> dict: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp: AsyncGenerateContentResponse = await self.llm.generate_content_async( + **self._const_kwargs(messages, stream=True) + ) + collected_content = [] + async for chunk in resp: + content = chunk.text + print(content, end="") + collected_content.append(content) + + full_content = "".join(collected_content) + usage = await self.aget_usage(messages, full_content) + self._update_costs(usage) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(min=1, max=60), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise, + ) + async def acompletion_text(self, messages: list[dict], stream=False) -> str: + """response in async with stream or non-stream mode""" + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index eef0e51e1..60d9a0777 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): completion_tokens = int(usage.get("completion_tokens", 0)) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: - logger.error("zhipuai updats costs failed!", e) + logger.error(f"zhipuai updats costs failed! exp: {e}") def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e5d4573e8..8db7a80a1 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -19,6 +19,7 @@ import json import os import platform import re +import sys import traceback import typing from pathlib import Path @@ -47,6 +48,12 @@ def check_cmd_exists(command) -> int: return result +def require_python_version(req_version: tuple[int]) -> bool: + if not (2 <= len(req_version) <= 3): + raise ValueError("req_version should be (3, 9) or (3, 10, 13)") + return True if sys.version_info > req_version else False + + class OutputParser: @classmethod def parse_blocks(cls, text: str): diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index af49845be..94b8d76d2 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -7,6 +7,7 @@ ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py +ref4: https://ai.google.dev/models/gemini """ import tiktoken @@ -27,6 +28,7 @@ TOKEN_COSTS = { "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0}, "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens + "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}, } @@ -47,6 +49,7 @@ TOKEN_MAX = { "gpt-4-1106-preview": 128000, "text-embedding-ada-002": 8192, "chatglm_turbo": 32768, + "gemini-pro": 32768, } diff --git a/requirements.txt b/requirements.txt index be7c477bf..e14c6bd3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,3 +49,4 @@ aiofiles==23.2.1 gitpython==3.1.40 zhipuai==1.0.7 gitignore-parser==0.1.9 +google-generativeai==0.3.1 diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py new file mode 100644 index 000000000..229d9b9a7 --- /dev/null +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of google gemini api + +from abc import ABC +from dataclasses import dataclass + +import pytest + +from metagpt.provider.google_gemini_api import GeminiGPTAPI + +messages = [{"role": "user", "content": "who are you"}] + + +@dataclass +class MockGeminiResponse(ABC): + text: str + + +default_resp = MockGeminiResponse(text="I'm gemini from google") + + +def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: + return default_resp + + +def test_gemini_completion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) + resp = GeminiGPTAPI().completion(messages) + assert resp.text == default_resp.text + + +async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: + return default_resp + + +@pytest.mark.asyncio +async def test_gemini_acompletion(mocker): + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) + resp = await GeminiGPTAPI().acompletion(messages) + assert resp.text == default_resp.text