From 7e0c62a7a917fee7c6c5aee5900b5e48ae79dd37 Mon Sep 17 00:00:00 2001
From: better629
Date: Sun, 24 Dec 2023 15:34:32 +0800
Subject: [PATCH] update fireworks/open_llm api due to new openai sdk

---
 metagpt/provider/fireworks_api.py            | 139 +++++++++++++++++--
 metagpt/provider/open_llm_api.py             |  92 ++++++++----
 metagpt/provider/openai_api.py               |  11 +-
 metagpt/utils/repair_llm_raw_output.py       |   4 +-
 metagpt/utils/token_counter.py               |  13 +-
 tests/metagpt/provider/test_fireworks_api.py |  50 +++++++
 6 files changed, 257 insertions(+), 52 deletions(-)
 create mode 100644 tests/metagpt/provider/test_fireworks_api.py

diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py
index bfe85f490..96b7db453 100644
--- a/metagpt/provider/fireworks_api.py
+++ b/metagpt/provider/fireworks_api.py
@@ -2,25 +2,140 @@
 # -*- coding: utf-8 -*-
 # @Desc   : fireworks.ai's api
 
-import openai
+import re
 
-from metagpt.config import CONFIG, LLMProviderEnum
+from openai import APIConnectionError, AsyncStream
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletionChunk
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+from metagpt.config import CONFIG, Config, LLMProviderEnum
+from metagpt.logs import logger
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
+from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
+from metagpt.utils.cost_manager import CostManager, Costs
+
+MODEL_GRADE_TOKEN_COSTS = {
+    "-1": {"prompt": 0.0, "completion": 0.0},  # abnormal condition
+    "16": {"prompt": 0.2, "completion": 0.8},  # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
+    "80": {"prompt": 0.7, "completion": 2.8},  # 80 means 16B < model size <= 80B
+    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
+}
+
+
+class FireworksCostManager(CostManager):
+    def model_grade_token_costs(self, model: str) -> dict[str, float]:
+        def _get_model_size(model: str) -> float:
+            size = re.findall(".*-([0-9.]+)b", model)
+            size = float(size[0]) if len(size) > 0 else -1
+            return size
+
+        if "mixtral-8x7b" in model:
+            token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
+        else:
+            model_size = _get_model_size(model)
+            if 0 < model_size <= 16:
+                token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
+            elif 16 < model_size <= 80:
+                token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
+            else:
+                token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
+        return token_costs
+
+    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
+        """
+        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
+        Update the total cost, prompt tokens, and completion tokens.
+
+        Args:
+            prompt_tokens (int): The number of tokens used in the prompt.
+            completion_tokens (int): The number of tokens used in the completion.
+            model (str): The model used for the API call.
+ """ + self.total_prompt_tokens += prompt_tokens + self.total_completion_tokens += completion_tokens + + token_costs = self.model_grade_token_costs(model) + cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000 + self.total_cost += cost + logger.info( + f"Total running cost: ${self.total_cost:.4f} | Max budget: ${CONFIG.max_budget:.3f} | " + f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + ) + CONFIG.total_cost = self.total_cost @register_provider(LLMProviderEnum.FIREWORKS) class FireWorksGPTAPI(OpenAIGPTAPI): def __init__(self): - self.__init_fireworks(CONFIG) - self.llm = openai - self.model = CONFIG.fireworks_api_model + self.config: Config = CONFIG + self.__init_fireworks() self.auto_max_tokens = False + self._cost_manager = FireworksCostManager() RateLimiter.__init__(self, rpm=self.rpm) - def __init_fireworks(self, config: "Config"): - # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you - # instantiate the client, e.g. 'OpenAI(api_base=config.fireworks_api_base)' - # openai.api_key = config.fireworks_api_key - # openai.api_base = config.fireworks_api_base - self.rpm = int(config.get("RPM", 10)) + def __init_fireworks(self): + self.is_azure = False + self.rpm = int(self.config.get("RPM", 10)) + self._make_client() + self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it + + def _make_client_kwargs(self) -> (dict, dict): + kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base) + async_kwargs = kwargs.copy() + return kwargs, async_kwargs + + def _update_costs(self, usage: CompletionUsage): + if self.config.calc_usage and usage: + try: + # use FireworksCostManager not CONFIG.cost_manager + self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) + except Exception as e: + logger.error(f"updating costs failed!, exp: {e}") + + def get_costs(self) -> Costs: + return self._cost_manager.get_costs() + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( + **self._cons_kwargs(messages), stream=True + ) + + collected_content = [] + usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + # iterate through the stream of events + async for chunk in response: + if chunk.choices: + choice = chunk.choices[0] + choice_delta = choice.delta + finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None + if choice_delta.content: + collected_content.append(choice_delta.content) + print(choice_delta.content, end="") + if finish_reason: + # fireworks api return usage when finish_reason is not None + usage = CompletionUsage(**chunk.usage) + + full_content = "".join(collected_content) + self._update_costs(usage) + return full_content + + @retry( + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(6), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(APIConnectionError), + retry_error_callback=log_and_reraise, + ) + async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + """when streaming, print each token in place.""" + if stream: + return await self._achat_completion_stream(messages) + rsp = await self._achat_completion(messages) + return self.get_choice_text(rsp) 
diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py
index 2e8c03ba1..dd1491780 100644
--- a/metagpt/provider/open_llm_api.py
+++ b/metagpt/provider/open_llm_api.py
@@ -2,44 +2,78 @@
 # -*- coding: utf-8 -*-
 # @Desc   : self-host open llm model with openai-compatible interface
 
+from openai.types import CompletionUsage
 
-from metagpt.config import CONFIG, LLMProviderEnum
+from metagpt.config import CONFIG, Config, LLMProviderEnum
+from metagpt.logs import logger
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
+from metagpt.utils.cost_manager import CostManager, Costs
+from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
 
-# class OpenLLMCostManager(CostManager):
-#     """open llm model is self-host, it's free and without cost"""
-#
-#     def update_cost(self, prompt_tokens, completion_tokens, model):
-#         """
-#         Update the total cost, prompt tokens, and completion tokens.
-#
-#         Args:
-#             prompt_tokens (int): The number of tokens used in the prompt.
-#             completion_tokens (int): The number of tokens used in the completion.
-#             model (str): The model used for the API call.
-#         """
-#         self.total_prompt_tokens += prompt_tokens
-#         self.total_completion_tokens += completion_tokens
-#
-#         logger.info(
-#             f"Max budget: ${CONFIG.max_budget:.3f} | "
-#             f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
-#         )
-#         CONFIG.total_cost = self.total_cost
+
+class OpenLLMCostManager(CostManager):
+    """the open llm model is self-hosted, so it's free and has no cost"""
+
+    def update_cost(self, prompt_tokens, completion_tokens, model):
+        """
+        Update the total cost, prompt tokens, and completion tokens.
+
+        Args:
+            prompt_tokens (int): The number of tokens used in the prompt.
+            completion_tokens (int): The number of tokens used in the completion.
+            model (str): The model used for the API call.
+        """
+        self.total_prompt_tokens += prompt_tokens
+        self.total_completion_tokens += completion_tokens
+
+        logger.info(
+            f"Max budget: ${CONFIG.max_budget:.3f} | reference "
+            f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
+        )
+        CONFIG.total_cost = self.total_cost
 
 
 @register_provider(LLMProviderEnum.OPEN_LLM)
 class OpenLLMGPTAPI(OpenAIGPTAPI):
     def __init__(self):
-        self.__init_openllm(CONFIG)
-        self.model = CONFIG.open_llm_api_model
+        self.config: Config = CONFIG
+        self.__init_openllm()
         self.auto_max_tokens = False
+        self._cost_manager = OpenLLMCostManager()
         RateLimiter.__init__(self, rpm=self.rpm)
 
-    def __init_openllm(self, config: "Config"):
-        # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you
-        # instantiate the client, e.g. 'OpenAI(api_base=config.open_llm_api_base)'
-        # openai.api_key = "sk-xx"  # self-host api doesn't need api-key, use the default value
-        # openai.api_base = config.open_llm_api_base
-        self.rpm = int(config.get("RPM", 10))
+    def __init_openllm(self):
+        self.is_azure = False
+        self.rpm = int(self.config.get("RPM", 10))
+        self._make_client()
+        self.model = self.config.open_llm_api_model  # set `self.model` after `_make_client` to override the default
+
+    def _make_client_kwargs(self) -> (dict, dict):
+        kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
+        async_kwargs = kwargs.copy()
+        return kwargs, async_kwargs
+
+    def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
+        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+        if not CONFIG.calc_usage:
+            return usage
+
+        try:
+            usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
+            usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
+        except Exception as e:
+            logger.error(f"usage calculation failed!: {e}")
+
+        return usage
+
+    def _update_costs(self, usage: CompletionUsage):
+        if self.config.calc_usage and usage:
+            try:
+                # use OpenLLMCostManager not CONFIG.cost_manager
+                self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
+            except Exception as e:
+                logger.error(f"updating costs failed!, exp: {e}")
+
+    def get_costs(self) -> Costs:
+        return self._cost_manager.get_costs()
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 44f857ed9..a39e4ccdd 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -15,7 +15,7 @@ import time
 from typing import List, Union
 
 import openai
-from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI, RateLimitError
+from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
 from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
@@ -175,13 +175,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         retry=retry_if_exception_type(APIConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    @retry(
-        wait=wait_random_exponential(min=1, max=60),
-        stop=stop_after_attempt(6),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(RateLimitError),
-        reraise=True,
-    )
     async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
         """when streaming, print each token in place."""
         if stream:
@@ -341,7 +334,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         try:
             CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
         except Exception as e:
-            logger.error("updating costs failed!", e)
+            logger.error(f"updating costs failed!, exp: {e}")
 
     def get_costs(self) -> Costs:
         return CONFIG.cost_manager.get_costs()
diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py
index 87fd0efd0..a96c3dce0 100644
--- a/metagpt/utils/repair_llm_raw_output.py
+++ b/metagpt/utils/repair_llm_raw_output.py
@@ -230,9 +230,11 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
             elif retry_state.kwargs:
                 func_param_output = retry_state.kwargs.get("output", "")
             exp_str = str(retry_state.outcome.exception())
+
+            fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
             logger.warning(
                 f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
-                f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
+                f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
             )
 
             repaired_output = repair_invalid_json(func_param_output, exp_str)
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index 94b8d76d2..a1b74a074 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -84,6 +84,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
     elif "gpt-4" == model:
         print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
         return count_message_tokens(messages, model="gpt-4-0613")
+    elif "open-llm-model" == model:
+        """
+        The self-hosted open_llm api serves many different models, so the message token count here is only
+        approximate and should be treated as a reference value.
+        """
+        tokens_per_message = 0  # ignore conversation message template prefix
+        tokens_per_name = 0
     else:
         raise NotImplementedError(
             f"num_tokens_from_messages() is not implemented for model {model}. "
@@ -112,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
     Returns:
         int: The number of tokens in the text string.
     """
-    encoding = tiktoken.encoding_for_model(model_name)
+    try:
+        encoding = tiktoken.encoding_for_model(model_name)
+    except KeyError:
+        print("Warning: model not found. Using cl100k_base encoding.")
+        encoding = tiktoken.get_encoding("cl100k_base")
     return len(encoding.encode(string))
 
 
diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py
new file mode 100644
index 000000000..43e45adf3
--- /dev/null
+++ b/tests/metagpt/provider/test_fireworks_api.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : the unittest of fireworks api
+
+import pytest
+from openai.types.chat.chat_completion import (
+    ChatCompletion,
+    ChatCompletionMessage,
+    Choice,
+)
+from openai.types.completion_usage import CompletionUsage
+
+from metagpt.provider.fireworks_api import FireWorksGPTAPI
+
+default_resp = ChatCompletion(
+    id="cmpl-a6652c1bb181caae8dd19ad8",
+    model="accounts/fireworks/models/llama-v2-13b-chat",
+    object="chat.completion",
+    created=1703300855,
+    choices=[
+        Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks"))
+    ],
+    usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
+)
+
+messages = [{"role": "user", "content": "who are you"}]
+
+
+def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion:
+    return default_resp
+
+
+def test_fireworks_completion(mocker):
+    mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask)
+
+    resp = FireWorksGPTAPI().completion(messages)
+    assert "fireworks" in resp.choices[0].message.content
+
+
+async def mock_llm_aask(self, messages: list[dict], stream: bool = False) -> ChatCompletion:
+    return default_resp
+
+
+@pytest.mark.asyncio
+async def test_fireworks_acompletion(mocker):
+    mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask)
+
+    resp = await FireWorksGPTAPI().acompletion(messages, stream=False)
+
+    assert "fireworks" in resp.choices[0].message.content
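
A minimal sketch of how the per-model-grade pricing introduced in fireworks_api.py resolves for the 13B model used in the test fixture above. It reuses only the MODEL_GRADE_TOKEN_COSTS table and the size regex from this patch; the grade_costs helper and the printed total are illustrative, not part of the patch.

# sketch: resolve a Fireworks pricing grade and price one request (assumptions noted above)
import re

MODEL_GRADE_TOKEN_COSTS = {
    "-1": {"prompt": 0.0, "completion": 0.0},            # unknown size
    "16": {"prompt": 0.2, "completion": 0.8},             # $/1M tokens, model size <= 16B
    "80": {"prompt": 0.7, "completion": 2.8},             # 16B < model size <= 80B
    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}


def grade_costs(model: str) -> dict:
    # same lookup logic as FireworksCostManager.model_grade_token_costs in the patch
    if "mixtral-8x7b" in model:
        return MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
    size = re.findall(".*-([0-9.]+)b", model)  # pull the trailing "<n>b" size out of the model name
    size = float(size[0]) if size else -1
    if 0 < size <= 16:
        return MODEL_GRADE_TOKEN_COSTS["16"]
    if 16 < size <= 80:
        return MODEL_GRADE_TOKEN_COSTS["80"]
    return MODEL_GRADE_TOKEN_COSTS["-1"]


# the llama-v2-13b-chat model from the unit test falls into the "16" grade
costs = grade_costs("accounts/fireworks/models/llama-v2-13b-chat")
# price the usage from the test fixture: 92 prompt tokens + 110 completion tokens
total = (92 * costs["prompt"] + 110 * costs["completion"]) / 1000000
print(f"${total:.7f}")  # ~$0.0001064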