diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index ad9df0396..7e865f288 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -6,11 +6,17 @@ """ import asyncio import time -from typing import NamedTuple +from typing import NamedTuple, Union import openai from openai.error import APIConnectionError -from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, +) from metagpt.config import CONFIG from metagpt.logs import logger @@ -48,12 +54,14 @@ class RateLimiter: self.last_call_time = time.time() + class Costs(NamedTuple): total_prompt_tokens: int total_completion_tokens: int total_cost: float total_budget: float + class CostManager(metaclass=Singleton): """计算使用接口的开销""" @@ -74,7 +82,9 @@ class CostManager(metaclass=Singleton): """ self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens - cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000 + cost = ( + prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"] + ) / 1000 self.total_cost += cost logger.info( f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | " @@ -100,6 +110,7 @@ class CostManager(metaclass=Singleton): """ return self.total_completion_tokens + def get_total_cost(self): """ Get the total cost of API calls. @@ -109,25 +120,20 @@ def get_total_cost(self): """ return self.total_cost + def get_costs(self) -> Costs: """Get all costs""" return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) -def log_and_reraise(retry_state): - logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") - logger.warning(""" -Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ -See FAQ 5.8 -""") - raise retry_state.outcome.exception() - def log_and_reraise(retry_state): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") - logger.warning(""" + logger.warning( + """ Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ See FAQ 5.8 -""") +""" + ) raise retry_state.outcome.exception() @@ -182,15 +188,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): "n": 1, "stop": None, "temperature": 0.3, - "timeout": 3 + "timeout": 3, } if CONFIG.openai_api_type == "azure": if CONFIG.deployment_name and CONFIG.deployment_id: raise ValueError("You can only use one of the `deployment_id` or `deployment_name` model") elif not CONFIG.deployment_name and not CONFIG.deployment_id: raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter") - kwargs_mode = {"engine": CONFIG.deployment_name} if CONFIG.deployment_name \ + kwargs_mode = ( + {"engine": CONFIG.deployment_name} + if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id} + ) else: kwargs_mode = {"model": self.model} kwargs.update(kwargs_mode) @@ -219,7 +228,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): @retry( stop=stop_after_attempt(3), wait=wait_fixed(1), - after=after_log(logger, logger.level('WARNING').name), + after=after_log(logger, logger.level("WARNING").name), retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, ) @@ -236,8 +245,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): try: prompt_tokens = count_message_tokens(messages, self.model) completion_tokens = count_string_tokens(rsp, self.model) - usage['prompt_tokens'] = prompt_tokens - usage['completion_tokens'] = completion_tokens + usage["prompt_tokens"] = prompt_tokens + usage["completion_tokens"] = completion_tokens return usage except Exception as e: logger.error("usage calculation failed!", e) @@ -273,8 +282,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def _update_costs(self, usage: dict): if CONFIG.calc_usage: try: - prompt_tokens = int(usage['prompt_tokens']) - completion_tokens = int(usage['completion_tokens']) + prompt_tokens = int(usage["prompt_tokens"]) + completion_tokens = int(usage["completion_tokens"]) self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) except Exception as e: logger.error("updating costs failed!", e) @@ -286,3 +295,31 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): if not self.auto_max_tokens: return CONFIG.max_tokens_rsp return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) + + def moderation(self, content: Union[str, list[str]]): + try: + if not content: + logger.error("content cannot be empty!") + else: + rsp = self._moderation(content=content) + return rsp + except Exception as e: + logger.error(f"moderating failed:{e}") + + def _moderation(self, content: Union[str, list[str]]): + rsp = self.llm.Moderation.create(input=content) + return rsp + + async def amoderation(self, content: Union[str, list[str]]): + try: + if not content: + logger.error("content cannot be empty!") + else: + rsp = await self._amoderation(content=content) + return rsp + except Exception as e: + logger.error(f"moderating failed:{e}") + + async def _amoderation(self, content: Union[str, list[str]]): + rsp = await self.llm.Moderation.acreate(input=content) + return rsp