fix format

This commit is contained in:
better629 2023-12-21 15:06:59 +08:00
parent 18a195a367
commit 6af9fecf65
7 changed files with 26 additions and 46 deletions

View file

@ -12,10 +12,4 @@ from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
__all__ = [
"FireWorksGPTAPI",
"GeminiGPTAPI",
"OpenLLMGPTAPI",
"OpenAIGPTAPI",
"ZhiPuAIGPTAPI"
]
__all__ = ["FireWorksGPTAPI", "GeminiGPTAPI", "OpenLLMGPTAPI", "OpenAIGPTAPI", "ZhiPuAIGPTAPI"]

View file

@ -6,8 +6,11 @@ import google.generativeai as genai
from google.ai import generativelanguage as glm
from google.generativeai.generative_models import GenerativeModel
from google.generativeai.types import content_types
from google.generativeai.types.generation_types import GenerateContentResponse, AsyncGenerateContentResponse
from google.generativeai.types.generation_types import GenerationConfig
from google.generativeai.types.generation_types import (
AsyncGenerateContentResponse,
GenerateContentResponse,
GenerationConfig,
)
from tenacity import (
after_log,
retry,
@ -29,15 +32,11 @@ class GeminiGenerativeModel(GenerativeModel):
Will use default GenerativeModel if it fixed.
"""
def count_tokens(
self, contents: content_types.ContentsType
) -> glm.CountTokensResponse:
def count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
contents = content_types.to_contents(contents)
return self._client.count_tokens(model=self.model_name, contents=contents)
async def count_tokens_async(
self, contents: content_types.ContentsType
) -> glm.CountTokensResponse:
async def count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
contents = content_types.to_contents(contents)
return await self._async_client.count_tokens(model=self.model_name, contents=contents)
@ -68,17 +67,11 @@ class GeminiGPTAPI(BaseGPTAPI):
return {"role": "model", "parts": [msg]}
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
"contents": messages,
"generation_config": GenerationConfig(
temperature=0.3
),
"stream": stream
}
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
""" update each request's token cost """
"""update each request's token cost"""
if CONFIG.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
@ -94,20 +87,14 @@ class GeminiGPTAPI(BaseGPTAPI):
req_text = messages[-1]["parts"][0] if messages else ""
prompt_resp = self.llm.count_tokens(contents={"role": "user", "parts": [{"text": req_text}]})
completion_resp = self.llm.count_tokens(contents={"role": "model", "parts": [{"text": resp_text}]})
usage = {
"prompt_tokens": prompt_resp.total_tokens,
"completion_tokens": completion_resp.total_tokens
}
usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens}
return usage
async def aget_usage(self, messages: list[dict], resp_text: str) -> dict:
req_text = messages[-1]["parts"][0] if messages else ""
prompt_resp = await self.llm.count_tokens_async(contents={"role": "user", "parts": [{"text": req_text}]})
completion_resp = await self.llm.count_tokens_async(contents={"role": "model", "parts": [{"text": resp_text}]})
usage = {
"prompt_tokens": prompt_resp.total_tokens,
"completion_tokens": completion_resp.total_tokens
}
usage = {"prompt_tokens": prompt_resp.total_tokens, "completion_tokens": completion_resp.total_tokens}
return usage
def completion(self, messages: list[dict]) -> "GenerateContentResponse":
@ -126,8 +113,9 @@ class GeminiGPTAPI(BaseGPTAPI):
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages,
stream=True))
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
**self._const_kwargs(messages, stream=True)
)
collected_content = []
async for chunk in resp:
content = chunk.text
@ -144,10 +132,10 @@ class GeminiGPTAPI(BaseGPTAPI):
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
""" response in async with stream or non-stream mode """
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)

View file

@ -70,7 +70,7 @@ class Researcher(Role):
return ret
def research_system_text(self, topic, current_task: Action) -> str:
""" BACKWARD compatible
"""BACKWARD compatible
This allows sub-class able to define its own system prompt based on topic.
return the previous implementation to have backward compatible
Args:

View file

@ -106,8 +106,8 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
options.add_argument("--headless")
options.add_argument("--enable-javascript")
if browser_type == "chrome":
options.add_argument("--disable-gpu") # This flag can help avoid renderer issue
options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems
options.add_argument("--disable-gpu") # This flag can help avoid renderer issue
options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems
options.add_argument("--no-sandbox")
for i in args:
options.add_argument(i)

View file

@ -226,7 +226,7 @@ class OutputParser:
if start_index != -1 and end_index != -1:
# Extract the structure part
structure_text = text[start_index: end_index + 1]
structure_text = text[start_index : end_index + 1]
try:
# Attempt to convert the text to a Python data type using ast.literal_eval

View file

@ -26,7 +26,7 @@ TOKEN_COSTS = {
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005}
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
}
@ -45,7 +45,7 @@ TOKEN_MAX = {
"gpt-4-1106-preview": 128000,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768,
"gemini-pro": 32768
"gemini-pro": 32768,
}

View file

@ -2,16 +2,14 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of google gemini api
import pytest
from abc import ABC
from dataclasses import dataclass
import pytest
from metagpt.provider.google_gemini_api import GeminiGPTAPI
messages = [
{"role": "user", "content": "who are you"}
]
messages = [{"role": "user", "content": "who are you"}]
@dataclass