mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
commit
a4843cd974
9 changed files with 214 additions and 6 deletions
|
|
@ -35,6 +35,10 @@ RPM: 10
|
|||
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
|
||||
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
|
||||
#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY"
|
||||
# GEMINI_API_KEY: "YOUR_API_KEY"
|
||||
|
||||
#### if use self-host open llm model with openai-compatible interface
|
||||
#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
|
||||
#OPEN_LLM_API_MODEL: "llama2-13b"
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ Provide configuration, singleton
|
|||
2. Add the parameter `src_workspace` for the old version project path.
|
||||
"""
|
||||
import os
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
|
@ -17,6 +18,7 @@ import yaml
|
|||
from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT, OPTIONS
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.common import require_python_version
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
|
|
@ -39,6 +41,7 @@ class LLMProviderEnum(Enum):
|
|||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
GEMINI = "gemini"
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
|
|
@ -74,10 +77,14 @@ class Config(metaclass=Singleton):
|
|||
(self.anthropic_api_key, LLMProviderEnum.ANTHROPIC),
|
||||
(self.zhipuai_api_key, LLMProviderEnum.ZHIPUAI),
|
||||
(self.fireworks_api_key, LLMProviderEnum.FIREWORKS),
|
||||
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM), # reuse logic. but not a key
|
||||
(self.open_llm_api_base, LLMProviderEnum.OPEN_LLM),
|
||||
(self.gemini_api_key, LLMProviderEnum.GEMINI), # reuse logic. but not a key
|
||||
]:
|
||||
if self._is_valid_llm_key(k):
|
||||
if self.openai_api_model:
|
||||
# logger.debug(f"Use LLMProvider: {v.value}")
|
||||
if v == LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
|
||||
warnings.warn("Use Gemini requires Python >= 3.10")
|
||||
if self.openai_api_key and self.openai_api_model:
|
||||
logger.info(f"OpenAI API Model: {self.openai_api_model}")
|
||||
return v
|
||||
raise NotConfiguredException("You should config a LLM configuration first")
|
||||
|
|
@ -87,7 +94,6 @@ class Config(metaclass=Singleton):
|
|||
return k and k != "YOUR_API_KEY"
|
||||
|
||||
def _update(self):
|
||||
# logger.info("Config loading done.")
|
||||
self.global_proxy = self._get("GLOBAL_PROXY")
|
||||
|
||||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
|
|
@ -96,6 +102,7 @@ class Config(metaclass=Singleton):
|
|||
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
|
||||
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
|
||||
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
|
||||
self.gemini_api_key = self._get("GEMINI_API_KEY")
|
||||
_ = self.get_default_llm_provider_enum()
|
||||
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@
|
|||
@File : __init__.py
|
||||
"""
|
||||
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
|
||||
|
||||
__all__ = ["OpenAIGPTAPI"]
|
||||
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]
|
||||
|
|
|
|||
142
metagpt/provider/google_gemini_api.py
Normal file
142
metagpt/provider/google_gemini_api.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
|
||||
|
||||
import google.generativeai as genai
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.generativeai.generative_models import GenerativeModel
|
||||
from google.generativeai.types import content_types
|
||||
from google.generativeai.types.generation_types import (
|
||||
AsyncGenerateContentResponse,
|
||||
GenerateContentResponse,
|
||||
GenerationConfig,
|
||||
)
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, log_and_reraise
|
||||
|
||||
|
||||
class GeminiGenerativeModel(GenerativeModel):
|
||||
"""
|
||||
Due to `https://github.com/google/generative-ai-python/pull/123`, inherit a new class.
|
||||
Will use default GenerativeModel if it fixed.
|
||||
"""
|
||||
|
||||
def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
contents = content_types.to_contents(contents)
|
||||
return self._client.count_tokens(model=self.model_name, contents=contents)
|
||||
|
||||
async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
contents = content_types.to_contents(contents)
|
||||
return await self._async_client.count_tokens(model=self.model_name, contents=contents)
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.GEMINI)
|
||||
class GeminiGPTAPI(BaseGPTAPI):
|
||||
"""
|
||||
Refs to `https://ai.google.dev/tutorials/python_quickstart`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.use_system_prompt = False # google gemini has no system prompt when use api
|
||||
|
||||
self.__init_gemini(CONFIG)
|
||||
self.model = "gemini-pro" # so far only one model
|
||||
self.llm = GeminiGenerativeModel(model_name=self.model)
|
||||
self._cost_manager = CostManager()
|
||||
|
||||
def __init_gemini(self, config: CONFIG):
|
||||
genai.configure(api_key=config.gemini_api_key)
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
# Not to change BaseGPTAPI default functions but update with Gemini's conversation format.
|
||||
# You should follow the format.
|
||||
return {"role": "user", "parts": [msg]}
|
||||
|
||||
def _assistant_msg(self, msg: str) -> dict[str, str]:
|
||||
return {"role": "model", "parts": [msg]}
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
|
||||
return kwargs
|
||||
|
||||
def _update_costs(self, usage: dict):
|
||||
"""update each request's token cost"""
|
||||
if CONFIG.calc_usage:
|
||||
try:
|
||||
prompt_tokens = int(usage.get("prompt_tokens", 0))
|
||||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
def get_usage(self, messages: list[dict], resp_text: str) -> dict:
|
||||
req_text = messages[-1]["parts"][0] if messages else ""
|
||||
prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]})
|
||||
completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]})
|
||||
usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens}
|
||||
return usage
|
||||
|
||||
async def aget_usage(self, messages: list[dict], resp_text: str) -> dict:
|
||||
req_text = messages[-1]["parts"][0] if messages else ""
|
||||
prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]})
|
||||
completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]})
|
||||
usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens}
|
||||
return usage
|
||||
|
||||
def completion(self, messages: list[dict]) -> "GenerateContentResponse":
|
||||
resp: GenerateContentResponse = self.llm.generate_content(**self._const_kwargs(messages))
|
||||
usage = self.get_usage(messages, resp.text)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
|
||||
usage = await self.aget_usage(messages, resp.text)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
|
||||
**self._const_kwargs(messages, stream=True)
|
||||
)
|
||||
collected_content = []
|
||||
async for chunk in resp:
|
||||
content = chunk.text
|
||||
print(content, end="")
|
||||
collected_content.append(content)
|
||||
|
||||
full_content = "".join(collected_content)
|
||||
usage = await self.aget_usage(messages, full_content)
|
||||
self._update_costs(usage)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
@ -63,7 +63,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
|
|||
completion_tokens = int(usage.get("completion_tokens", 0))
|
||||
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("zhipuai updats costs failed!", e)
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import json
|
|||
import os
|
||||
import platform
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from pathlib import Path
|
||||
|
|
@ -47,6 +48,12 @@ def check_cmd_exists(command) -> int:
|
|||
return result
|
||||
|
||||
|
||||
def require_python_version(req_version: tuple[int]) -> bool:
|
||||
if not (2 <= len(req_version) <= 3):
|
||||
raise ValueError("req_version should be (3, 9) or (3, 10, 13)")
|
||||
return True if sys.version_info > req_version else False
|
||||
|
||||
|
||||
class OutputParser:
|
||||
@classmethod
|
||||
def parse_blocks(cls, text: str):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
ref1: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
ref2: https://github.com/Significant-Gravitas/Auto-GPT/blob/master/autogpt/llm/token_counter.py
|
||||
ref3: https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/openai.py
|
||||
ref4: https://ai.google.dev/models/gemini
|
||||
"""
|
||||
import tiktoken
|
||||
|
||||
|
|
@ -27,6 +28,7 @@ TOKEN_COSTS = {
|
|||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
|
||||
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
|
||||
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -47,6 +49,7 @@ TOKEN_MAX = {
|
|||
"gpt-4-1106-preview": 128000,
|
||||
"text-embedding-ada-002": 8192,
|
||||
"chatglm_turbo": 32768,
|
||||
"gemini-pro": 32768,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -49,3 +49,4 @@ aiofiles==23.2.1
|
|||
gitpython==3.1.40
|
||||
zhipuai==1.0.7
|
||||
gitignore-parser==0.1.9
|
||||
google-generativeai==0.3.1
|
||||
|
|
|
|||
41
tests/metagpt/provider/test_google_gemini_api.py
Normal file
41
tests/metagpt/provider/test_google_gemini_api.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of google gemini api
|
||||
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockGeminiResponse(ABC):
|
||||
text: str
|
||||
|
||||
|
||||
default_resp = MockGeminiResponse(text="I'm gemini from google")
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask)
|
||||
resp = GeminiGPTAPI().completion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask)
|
||||
resp = await GeminiGPTAPI().acompletion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
Loading…
Add table
Add a link
Reference in a new issue