Merge pull request #46 from send18/feature-wait-exponential-if-rate-limit

wait_exponential if RateLimitError
This commit is contained in:
Justin-ZL 2023-09-09 14:32:40 +08:00 committed by GitHub
commit 874bf98183
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 67 deletions

View file

@ -9,7 +9,6 @@
from abc import abstractmethod
from typing import Optional
from metagpt.logs import logger
from metagpt.provider.base_chatbot import BaseChatbot
@ -52,13 +51,7 @@ class BaseGPTAPI(BaseChatbot):
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
try:
rsp = await self.acompletion_text(message, stream=True, generator=generator)
except Exception as e:
logger.exception(f"{e}")
logger.info(f"ask:{msg}, error:{e}")
raise e
logger.info(f"ask:{msg}, anwser:{rsp}")
rsp = await self.acompletion_text(message, stream=True, generator=generator)
return rsp
def _extract_assistant_rsp(self, context):

View file

@ -7,17 +7,16 @@
Change cost control from global to company level.
"""
import asyncio
import random
import time
import traceback
import openai
from openai.error import APIConnectionError
from openai.error import APIConnectionError, RateLimitError
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_fixed,
)
@ -75,16 +74,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
def __init__(self):
self.llm = openai
self.model = CONFIG.openai_api_model
self.auto_max_tokens = False
self.rpm = int(CONFIG.get("RPM", 10))
RateLimiter.__init__(self, rpm=self.rpm)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await self.async_retry_call(
openai.ChatCompletion.acreate, **self._cons_kwargs(messages), stream=True
)
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
# iterate through the stream of events
async for chunk in response:
chunk_message = chunk["choices"][0]["delta"] # extract the message
@ -118,12 +114,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict:
rsp = await self.async_retry_call(self.llm.ChatCompletion.acreate, **self._cons_kwargs(messages))
rsp = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages))
self._update_costs(rsp.get("usage"))
return rsp
def _chat_completion(self, messages: list[dict]) -> dict:
rsp = self.retry_call(self.llm.ChatCompletion.create, **self._cons_kwargs(messages))
rsp = openai.ChatCompletion.create(**self._cons_kwargs(messages))
self._update_costs(rsp)
return rsp
@ -144,6 +140,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
@retry(
stop=stop_after_attempt(6),
wait=wait_exponential(1),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(RateLimitError),
reraise=True,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False) -> str:
"""when streaming, print each token in place."""
if stream:
@ -221,58 +224,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
@staticmethod
async def async_retry_call(func, *args, **kwargs):
for i in range(OpenAIGPTAPI.MAX_TRY):
try:
rsp = await func(*args, **kwargs)
return rsp
except openai.error.RateLimitError as e:
random_time = random.uniform(0, 3) # 生成0到5秒之间的随机时间
rounded_time = round(random_time, 1) # 保留一位小数以实现0.1秒的精度
logger.warning(f"Exception:{e}, sleeping for {rounded_time} seconds")
await asyncio.sleep(rounded_time)
continue
except Exception as e:
error_str = traceback.format_exc()
logger.error(f"Exception:{e}, stack:{error_str}")
raise e
raise openai.error.OpenAIError("Exceeds the maximum retries")
@staticmethod
def retry_call(func, *args, **kwargs):
for i in range(OpenAIGPTAPI.MAX_TRY):
try:
rsp = func(*args, **kwargs)
return rsp
except openai.error.RateLimitError as e:
logger.warning(f"Exception:{e}")
continue
except (
openai.error.AuthenticationError,
openai.error.PermissionError,
openai.error.InvalidAPIType,
openai.error.SignatureVerificationError,
) as e:
logger.warning(f"Exception:{e}")
raise e
except Exception as e:
error_str = traceback.format_exc()
logger.error(f"Exception:{e}, stack:{error_str}")
raise e
raise openai.error.OpenAIError("Exceeds the maximum retries")
async def get_summary(self, text: str, max_words=200, keep_language: bool = False, **kwargs) -> str:
from metagpt.memory.brain_memory import BrainMemory
memory = BrainMemory(llm_type=LLMType.OPENAI.value, historical_summary=text, cacheable=False)
return await memory.summarize(llm=self, max_length=max_words, keep_language=keep_language)
MAX_TRY = 5
if __name__ == "__main__":
txt = """
as dfas sad lkf sdkl sakdfsdk sjd jsk sdl sk dd sd asd fa sdf sad dd
- .gitlab-ci.yml & base_test.py
"""