Merge pull request #324 from Justin-ZL/main

add: openai moderation
This commit is contained in:
stellaHSR 2023-09-15 16:37:40 +08:00 committed by GitHub
commit 9260021c3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,11 +6,17 @@
"""
import asyncio
import time
from typing import NamedTuple
from typing import NamedTuple, Union
import openai
from openai.error import APIConnectionError
from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_fixed,
)
from metagpt.config import CONFIG
from metagpt.logs import logger
@ -48,12 +54,14 @@ class RateLimiter:
self.last_call_time = time.time()
class Costs(NamedTuple):
total_prompt_tokens: int
total_completion_tokens: int
total_cost: float
total_budget: float
class CostManager(metaclass=Singleton):
"""计算使用接口的开销"""
@ -74,7 +82,9 @@ class CostManager(metaclass=Singleton):
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
@ -100,6 +110,7 @@ class CostManager(metaclass=Singleton):
"""
return self.total_completion_tokens
def get_total_cost(self):
"""
Get the total cost of API calls.
@ -109,25 +120,20 @@ def get_total_cost(self):
"""
return self.total_cost
def get_costs(self) -> Costs:
"""Get all costs"""
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning("""
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
""")
raise retry_state.outcome.exception()
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning("""
logger.warning(
"""
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
""")
"""
)
raise retry_state.outcome.exception()
@ -182,15 +188,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"n": 1,
"stop": None,
"temperature": 0.3,
"timeout": 3
"timeout": 3,
}
if CONFIG.openai_api_type == "azure":
if CONFIG.deployment_name and CONFIG.deployment_id:
raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model")
elif not CONFIG.deployment_name and not CONFIG.deployment_id:
raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter")
kwargs_mode = {"engine": CONFIG.deployment_name} if CONFIG.deployment_name \
kwargs_mode = (
{"engine": CONFIG.deployment_name}
if CONFIG.deployment_name
else {"deployment_id": CONFIG.deployment_id}
)
else:
kwargs_mode = {"model": self.model}
kwargs.update(kwargs_mode)
@ -219,7 +228,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(1),
after=after_log(logger, logger.level('WARNING').name),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
@ -236,8 +245,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
try:
prompt_tokens = count_message_tokens(messages, self.model)
completion_tokens = count_string_tokens(rsp, self.model)
usage['prompt_tokens'] = prompt_tokens
usage['completion_tokens'] = completion_tokens
usage["prompt_tokens"] = prompt_tokens
usage["completion_tokens"] = completion_tokens
return usage
except Exception as e:
logger.error("usage calculation failed!", e)
@ -273,8 +282,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def _update_costs(self, usage: dict):
if CONFIG.calc_usage:
try:
prompt_tokens = int(usage['prompt_tokens'])
completion_tokens = int(usage['completion_tokens'])
prompt_tokens = int(usage["prompt_tokens"])
completion_tokens = int(usage["completion_tokens"])
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error("updating costs failed!", e)
@ -286,3 +295,31 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
if not self.auto_max_tokens:
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
def moderation(self, content: Union[str, list[str]]):
try:
if not content:
logger.error("content cannot be empty!")
else:
rsp = self._moderation(content=content)
return rsp
except Exception as e:
logger.error(f"moderating failed:{e}")
def _moderation(self, content: Union[str, list[str]]):
rsp = self.llm.Moderation.create(input=content)
return rsp
async def amoderation(self, content: Union[str, list[str]]):
try:
if not content:
logger.error("content cannot be empty!")
else:
rsp = await self._amoderation(content=content)
return rsp
except Exception as e:
logger.error(f"moderating failed:{e}")
async def _amoderation(self, content: Union[str, list[str]]):
rsp = await self.llm.Moderation.acreate(input=content)
return rsp