add google gemini

This commit is contained in:
better629 2023-12-14 16:45:40 +08:00
parent 820315b916
commit c4fbc478d2
6 changed files with 186 additions and 1 deletions

View file

@ -37,6 +37,10 @@ RPM: 10
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
#### You can set here or export GEMINI_API_KEY="YOUR_API_KEY"
# GEMINI_API_KEY: "YOUR_API_KEY"
#### if use self-host open llm model with openai-compatible interface
#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
#OPEN_LLM_API_MODEL: "llama2-13b"

View file

@ -39,6 +39,7 @@ class LLMProviderEnum(Enum):
ZHIPUAI = "zhipuai"
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
GEMINI = "gemini"
class Config(metaclass=Singleton):
@ -74,7 +75,8 @@ class Config(metaclass=Singleton):
(self.anthropic_api_key, LLMProviderEnum.ANTHROPIC),
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), # reuse logic. but not a key
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
(self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key
]:
if self._is_valid_llm_key(k):
if self.openai_api_model:
@ -96,6 +98,8 @@ class Config(metaclass=Singleton):
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
self.gemini_api_key = self._get("GEMINI_API_KEY")
_ = self.get_default_llm_provider_enum()
self.openai_api_base = self._get("OPENAI_API_BASE")

View file

@ -0,0 +1,130 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_fixed,
)
import google.generativeai as genai
from google.generativeai import client
from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse
from google.generativeai.types.generation_types import GenerationConfig
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.openai_api import log_and_reraise
class GeminiGPTAPI(BaseGPTAPI):
    """LLM provider that talks to Google Gemini through the ``google.generativeai`` SDK.

    Refs to `https://ai.google.dev/tutorials/python_quickstart`
    """

    use_system_prompt: bool = False  # google gemini has no system prompt when use api

    def __init__(self):
        """Configure the SDK from the global ``CONFIG`` and build the generative model."""
        self.__init_gemini(CONFIG)
        self.model = "gemini-pro"  # so far only one model
        self.llm = genai.GenerativeModel(model_name=self.model)

    def __init_gemini(self, config):
        """Hand the API key stored on ``config.gemini_api_key`` to the SDK."""
        genai.configure(api_key=config.gemini_api_key)

    def _user_msg(self, msg: str) -> dict:
        """Wrap ``msg`` as a message with role ``user`` and a one-element ``parts`` list."""
        return {"role": "user", "parts": [msg]}

    def _assistant_msg(self, msg: str) -> dict:
        """Wrap ``msg`` as a message with role ``model`` and a one-element ``parts`` list."""
        return {"role": "model", "parts": [msg]}

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        """Build the keyword arguments passed to ``generate_content`` / ``generate_content_async``.

        Args:
            messages: conversation passed through unchanged as ``contents``.
            stream: forwarded as the ``stream`` argument.

        Returns:
            A dict with ``contents``, ``generation_config`` (temperature fixed at 0.3) and ``stream``.
        """
        kwargs = {
            "contents": messages,
            "generation_config": GenerationConfig(temperature=0.3),
            "stream": stream,
        }
        return kwargs

    def _update_costs(self, usage: dict):
        """Update each request's token cost.

        Does nothing unless ``CONFIG.calc_usage`` is truthy. Any failure is logged
        and swallowed so that cost bookkeeping never breaks a completion call.

        Args:
            usage: dict read for ``prompt_tokens`` and ``completion_tokens`` (missing keys count as 0).
        """
        if not CONFIG.calc_usage:
            return
        try:
            prompt_tokens = int(usage.get("prompt_tokens", 0))
            completion_tokens = int(usage.get("completion_tokens", 0))
            # NOTE(review): `_cost_manager` is not assigned anywhere in this class;
            # presumably it is provided elsewhere - confirm before enabling cost tracking.
            self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
        except Exception as e:
            # Interpolate the exception into the message: passing it as an extra
            # positional argument without a placeholder would drop it from the log.
            logger.error(f"google gemini updates costs failed! exp: {e}")

    def get_choice_text(self, resp: GenerateContentResponse) -> str:
        """Return the ``text`` attribute of a response object."""
        return resp.text

    def get_usage(self, messages: list[dict], resp_text: str) -> dict:
        """Count prompt and completion tokens with the model's ``count_tokens``.

        Args:
            messages: the request messages, counted as the prompt.
            resp_text: the generated text, counted as the completion.

        Returns:
            A dict with ``prompt_tokens`` and ``completion_tokens``.
        """
        prompt_resp = self.llm.count_tokens(contents=messages)
        completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]})
        usage = {
            "prompt_tokens": prompt_resp.total_tokens,
            "completion_tokens": completion_resp.total_tokens,
        }
        return usage

    async def aget_usage(self, messages: list[dict], resp_text: str) -> dict:
        """Async counterpart of :meth:`get_usage`, using ``count_tokens_async``."""
        # fix google-generativeai sdk
        if self.llm._client is None:
            self.llm._client = client.get_default_generative_client()
        # TODO exception to fix
        prompt_resp = await self.llm.count_tokens_async(contents=messages)
        completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]})
        usage = {
            "prompt_tokens": prompt_resp.total_tokens,
            "completion_tokens": completion_resp.total_tokens,
        }
        return usage

    def completion(self, messages: list[dict]) -> "GenerateContentResponse":
        """Run a blocking, non-streaming generation and return the raw SDK response."""
        resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages))
        # TODO: token usage accounting (get_usage + _update_costs) is not wired in yet.
        return resp

    async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
        """Run an async, non-streaming generation and return the raw SDK response."""
        resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
        # TODO: token usage accounting (aget_usage + _update_costs) is not wired in yet.
        return resp

    async def acompletion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
        """Async completion; delegates to :meth:`_achat_completion`."""
        return await self._achat_completion(messages)

    async def _achat_completion_stream(self, messages: list[dict]) -> str:
        """Run an async streaming generation, echo each chunk to stdout and return the joined text."""
        resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
            **self._const_kwargs(messages, stream=True)
        )
        collected_content = []
        async for chunk in resp:
            content = chunk.text
            print(content, end="")
            collected_content.append(content)

        full_content = "".join(collected_content)
        # TODO: token usage accounting (aget_usage + _update_costs) is not wired in yet.
        return full_content

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_fixed(1),
        after=after_log(logger, logger.level("WARNING").name),
        retry=retry_if_exception_type(ConnectionError),
        retry_error_callback=log_and_reraise,
    )
    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
        """Response in async with stream or non-stream mode.

        Retried up to 3 attempts, 1 second apart, when a ``ConnectionError`` is raised.

        Args:
            messages: the request messages.
            stream: when true, stream the answer and return the concatenated chunks.

        Returns:
            The generated text.
        """
        if stream:
            return await self._achat_completion_stream(messages)
        resp = await self._achat_completion(messages)
        return self.get_choice_text(resp)

View file

@ -7,6 +7,7 @@
ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py
ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
ref4: https://ai.google.dev/models/gemini
"""
import tiktoken
@ -25,6 +26,7 @@ TOKEN_COSTS = {
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005}
}
@ -43,6 +45,7 @@ TOKEN_MAX = {
"gpt-4-1106-preview": 128000,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768,
"gemini-pro": 32768
}

View file

@ -49,3 +49,4 @@ aiofiles==23.2.1
gitpython==3.1.40
zhipuai==1.0.7
gitignore-parser==0.1.9
google-generativeai==0.3.1

View file

@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of google gemini api
import pytest
from abc import ABC
from dataclasses import dataclass
from metagpt.provider.google_gemini_api import GeminiGPTAPI
# Request fixture in the Gemini message schema (role + parts), i.e. the same
# shape GeminiGPTAPI._user_msg produces, rather than the OpenAI-style "content" key.
messages = [
    {"role": "user", "parts": ["who are you"]},
]
@dataclass
class MockGeminiResponse(ABC):
    """Stand-in for a Gemini SDK response; exposes only the ``text`` attribute the tests read."""

    text: str


# Canned reply shared by the sync and async mocks below.
default_resp = MockGeminiResponse("I'm gemini from google")
def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
    """Synchronous replacement for a completion call: ignores its input and returns ``default_resp``."""
    canned = default_resp
    return canned
def test_gemini_completion(mocker):
    """``completion`` must hand back whatever the SDK's ``generate_content`` returns.

    Only the SDK call is replaced, so ``GeminiGPTAPI.completion`` itself still runs;
    patching ``completion`` directly would make the assertion trivially true.
    """

    def fake_generate_content(self, *args, **kwargs) -> MockGeminiResponse:
        # The provider calls generate_content with keyword arguments only.
        return default_resp

    mocker.patch(
        "google.generativeai.generative_models.GenerativeModel.generate_content",
        fake_generate_content,
    )
    resp = GeminiGPTAPI().completion(messages)
    assert resp.text == default_resp.text
async def mock_llm_aask(self, messages: list[dict]) -> MockGeminiResponse:
    """Async replacement for a completion call: ignores its input and returns ``default_resp``.

    The parameter is named ``messages`` so the signature matches the method it
    stands in for, including when the caller passes it by keyword.
    """
    return default_resp
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
    """``acompletion`` must hand back whatever the SDK's ``generate_content_async`` returns.

    Only the SDK call is replaced, so ``GeminiGPTAPI.acompletion`` itself still runs;
    patching ``acompletion`` directly would make the assertion trivially true.
    """

    async def fake_generate_content_async(self, *args, **kwargs) -> MockGeminiResponse:
        # The provider awaits generate_content_async with keyword arguments only.
        return default_resp

    mocker.patch(
        "google.generativeai.generative_models.GenerativeModel.generate_content_async",
        fake_generate_content_async,
    )
    resp = await GeminiGPTAPI().acompletion(messages)
    assert resp.text == default_resp.text