mirror of https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-06 06:12:39 +02:00
add google gemini
This commit is contained in:
parent 183d939ee8
commit bef8d64193
7 changed files with 192 additions and 4 deletions
metagpt/config.py

@@ -51,13 +51,17 @@ class Config(metaclass=Singleton):
         self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
         self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
+        self.gemini_api_key = self._get("GEMINI_API_KEY")

         if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \
                 (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \
                 (not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \
                 (not self.open_llm_api_base) and \
-                (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key):
+                (not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key) and \
+                (not self.gemini_api_key or "YOUR_API_KEY" in self.gemini_api_key):
             raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first "
-                                         "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE")
+                                         "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE or GEMINI_API_KEY")
         self.openai_api_base = self._get("OPENAI_API_BASE")
         openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
         if openai_proxy:
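With GEMINI_API_KEY added to Config, the key can be supplied the same way as the other provider keys. A minimal sketch, assuming the environment-variable path that self._get(...) consults (MetaGPT's Config also reads its YAML key files, so this is one option among several):

import os

# Hypothetical setup: export the key before metagpt.config is imported, so that
# self._get("GEMINI_API_KEY") can resolve it and NotConfiguredException is avoided.
os.environ["GEMINI_API_KEY"] = "<your-gemini-key>"  # placeholder value, not a real key

from metagpt.config import CONFIG

assert CONFIG.gemini_api_key == "<your-gemini-key>"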
metagpt/llm.py

@@ -14,6 +14,7 @@ from metagpt.provider.spark_api import SparkAPI
 from metagpt.provider.open_llm_api import OpenLLMGPTAPI
 from metagpt.provider.fireworks_api import FireWorksGPTAPI
 from metagpt.provider.human_provider import HumanProvider
+from metagpt.provider.google_gemini_api import GeminiGPTAPI


 def LLM() -> "BaseGPTAPI":
@@ -29,6 +30,8 @@ def LLM() -> "BaseGPTAPI":
         llm = OpenLLMGPTAPI()
     elif CONFIG.fireworks_api_key:
         llm = FireWorksGPTAPI()
+    elif CONFIG.gemini_api_key:
+        llm = GeminiGPTAPI()
     else:
         raise RuntimeError("You should config a LLM configuration first")
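The new branch sits near the end of the elif chain, so providers checked earlier (the branches above this hunk) take precedence when several keys are configured. A minimal usage sketch; the metagpt.llm module path for LLM() is an assumption:

from metagpt.llm import LLM  # assumed location of the LLM() factory

llm = LLM()  # with only GEMINI_API_KEY configured, this selects the Gemini provider
print(type(llm).__name__)  # "GeminiGPTAPI"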
metagpt/provider/google_gemini_api.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
+
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_fixed,
+)
+import google.generativeai as genai
+from google.generativeai import client
+from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse
+from google.generativeai.types.generation_types import GenerationConfig
+
+from metagpt.config import CONFIG
+from metagpt.logs import logger
+from metagpt.provider.base_gpt_api import BaseGPTAPI
+from metagpt.provider.openai_api import log_and_reraise
+
+
+class GeminiGPTAPI(BaseGPTAPI):
+    """
+    Refs to `https://ai.google.dev/tutorials/python_quickstart`
+    """
+
+    use_system_prompt: bool = False  # the Gemini API has no system-prompt role
+
+    def __init__(self):
+        self.__init_gemini(CONFIG)
+        self.model = "gemini-pro"  # so far the only generally available model
+        self.llm = genai.GenerativeModel(model_name=self.model)
+
+    def __init_gemini(self, config):
+        genai.configure(api_key=config.gemini_api_key)
+
+    def _user_msg(self, msg: str) -> dict[str, str]:
+        return {"role": "user", "parts": [msg]}
+
+    def _assistant_msg(self, msg: str) -> dict[str, str]:
+        return {"role": "model", "parts": [msg]}
+
+    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
+        kwargs = {
+            "contents": messages,
+            "generation_config": GenerationConfig(temperature=0.3),
+            "stream": stream,
+        }
+        return kwargs
+
+    def _update_costs(self, usage: dict):
+        """Update each request's token cost."""
+        if CONFIG.calc_usage:
+            try:
+                prompt_tokens = int(usage.get("prompt_tokens", 0))
+                completion_tokens = int(usage.get("completion_tokens", 0))
+                self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
+            except Exception as e:
+                logger.error(f"google gemini update costs failed: {e}")
+
+    def get_choice_text(self, resp: GenerateContentResponse) -> str:
+        return resp.text
+
+    def get_usage(self, messages: list[dict], resp_text: str) -> dict:
+        prompt_resp = self.llm.count_tokens(contents=messages)
+        completion_resp = self.llm.count_tokens(contents={"parts": [resp_text]})
+        usage = {
+            "prompt_tokens": prompt_resp.total_tokens,
+            "completion_tokens": completion_resp.total_tokens,
+        }
+        return usage
+
+    async def aget_usage(self, messages: list[dict], resp_text: str) -> dict:
+        # workaround for the google-generativeai sdk: the async client may not be
+        # initialized yet, so create the default client before the first async call
+        if self.llm._client is None:
+            self.llm._client = client.get_default_generative_client()
+        # TODO: fix the exception handling here
+        prompt_resp = await self.llm.count_tokens_async(contents=messages)
+        completion_resp = await self.llm.count_tokens_async(contents={"parts": [resp_text]})
+        usage = {
+            "prompt_tokens": prompt_resp.total_tokens,
+            "completion_tokens": completion_resp.total_tokens,
+        }
+        return usage
+
+    def completion(self, messages: list[dict]) -> "GenerateContentResponse":
+        resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages))
+        # usage = self.get_usage(messages, resp.text)
+        # self._update_costs(usage)
+        return resp
+
+    async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
+        resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
+        # usage = await self.aget_usage(messages, resp.text)
+        # self._update_costs(usage)
+        return resp
+
+    async def acompletion(self, messages: list[dict]) -> dict:
+        return await self._achat_completion(messages)
+
+    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+        resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
+            **self._const_kwargs(messages, stream=True)
+        )
+        collected_content = []
+        async for chunk in resp:
+            content = chunk.text
+            print(content, end="")
+            collected_content.append(content)
+
+        full_content = "".join(collected_content)
+        # usage = await self.aget_usage(messages, full_content)
+        # self._update_costs(usage)
+        return full_content
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_fixed(1),
+        after=after_log(logger, logger.level("WARNING").name),
+        retry=retry_if_exception_type(ConnectionError),
+        retry_error_callback=log_and_reraise,
+    )
+    async def acompletion_text(self, messages: list[dict], stream=False) -> str:
+        """Respond asynchronously, in either stream or non-stream mode."""
+        if stream:
+            return await self._achat_completion_stream(messages)
+        resp = await self._achat_completion(messages)
+        return self.get_choice_text(resp)
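A quick smoke test of the new provider, using only methods defined above (the prompt text is illustrative, and GEMINI_API_KEY must already be configured):

import asyncio

from metagpt.provider.google_gemini_api import GeminiGPTAPI


async def main():
    llm = GeminiGPTAPI()  # __init__ calls genai.configure with CONFIG.gemini_api_key
    # Gemini chat turns use roles "user" / "model" with a "parts" list, per _user_msg above
    messages = [llm._user_msg("Say hello in one short sentence.")]
    text = await llm.acompletion_text(messages, stream=False)
    print(text)


asyncio.run(main())

Since use_system_prompt is False, MetaGPT presumably folds or drops the system message for this provider rather than sending a dedicated system role, which the Gemini API does not accept.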
metagpt/utils/token_counter.py

@@ -7,6 +7,7 @@
 ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
 ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py
 ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
+ref4: https://ai.google.dev/models/gemini
 """
 import tiktoken
@@ -24,7 +25,8 @@ TOKEN_COSTS = {
     "gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
     "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
     "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
-    "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}  # 32k version, prompt + completion tokens=0.005¥/k-tokens
+    "chatglm_turbo": {"prompt": 0.0, "completion": 0.00069},  # 32k version, prompt + completion tokens=0.005¥/k-tokens
+    "gemini-pro": {"prompt": 0.00025, "completion": 0.0005}
 }
@@ -42,7 +44,8 @@ TOKEN_MAX = {
     "gpt-4-0613": 8192,
     "gpt-4-1106-preview": 128000,
     "text-embedding-ada-002": 8192,
-    "chatglm_turbo": 32768
+    "chatglm_turbo": 32768,
+    "gemini-pro": 32768
 }
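These two tables feed MetaGPT's accounting: TOKEN_COSTS is priced in USD per 1K tokens, and TOKEN_MAX caps the tokens allowed per request. A worked sketch of the cost arithmetic (the helper name is hypothetical; MetaGPT's cost manager does the equivalent):

# Hypothetical helper mirroring the update_cost arithmetic; prices are USD per 1K tokens.
def request_cost(prompt_tokens: int, completion_tokens: int, model: str = "gemini-pro") -> float:
    price = {"gemini-pro": {"prompt": 0.00025, "completion": 0.0005}}[model]
    return (prompt_tokens * price["prompt"] + completion_tokens * price["completion"]) / 1000


# 1200 prompt tokens + 400 completion tokens on gemini-pro:
print(f"{request_cost(1200, 400):.6f}")  # 0.000500, i.e. $0.0005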