update fireworks/open_llm api due to new openai sdk

This commit is contained in:
better629 2023-12-24 15:34:32 +08:00
parent a1f39d1269
commit 7e0c62a7a9
6 changed files with 257 additions and 52 deletions

View file

@ -2,25 +2,140 @@
# -*- coding: utf-8 -*-
# @Desc : fireworks.ai's api
import openai
import re
from metagpt.config import CONFIG, LLMProviderEnum
from openai import APIConnectionError, AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
class FireworksCostManager(CostManager):
def model_grade_token_costs(self, model: str) -> dict[str, float]:
def _get_model_size(model: str) -> float:
size = re.findall(".*-([0-9.]+)b", model)
size = float(size[0]) if len(size) > 0 else -1
return size
if "mixtral-8x7b" in model:
token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
else:
model_size = _get_model_size(model)
if 0 < model_size <= 16:
token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
elif 16 < model_size <= 80:
token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
else:
token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
return token_costs
def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
"""
Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
token_costs = self.model_grade_token_costs(model)
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
self.total_cost += cost
logger.info(
f"Total running cost: ${self.total_cost:.4f} | Max budget: ${CONFIG.max_budget:.3f} | "
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.FIREWORKS)
class FireWorksGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_fireworks(CONFIG)
self.llm = openai
self.model = CONFIG.fireworks_api_model
self.config: Config = CONFIG
self.__init_fireworks()
self.auto_max_tokens = False
self._cost_manager = FireworksCostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_fireworks(self, config: "Config"):
# TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you
# instantiate the client, e.g. 'OpenAI(api_base=config.fireworks_api_base)'
# openai.api_key = config.fireworks_api_key
# openai.api_base = config.fireworks_api_base
self.rpm = int(config.get("RPM", 10))
def __init_fireworks(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
async_kwargs = kwargs.copy()
return kwargs, async_kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)
collected_content = []
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
# iterate through the stream of events
async for chunk in response:
if chunk.choices:
choice = chunk.choices[0]
choice_delta = choice.delta
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
if choice_delta.content:
collected_content.append(choice_delta.content)
print(choice_delta.content, end="")
if finish_reason:
# fireworks api return usage when finish_reason is not None
usage = CompletionUsage(**chunk.usage)
full_content = "".join(collected_content)
self._update_costs(usage)
return full_content
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)

View file

@ -2,44 +2,78 @@
# -*- coding: utf-8 -*-
# @Desc : self-host open llm model with openai-compatible interface
from openai.types import CompletionUsage
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
# class OpenLLMCostManager(CostManager):
# """open llm model is self-host, it's free and without cost"""
#
# def update_cost(self, prompt_tokens, completion_tokens, model):
# """
# Update the total cost, prompt tokens, and completion tokens.
#
# Args:
# prompt_tokens (int): The number of tokens used in the prompt.
# completion_tokens (int): The number of tokens used in the completion.
# model (str): The model used for the API call.
# """
# self.total_prompt_tokens += prompt_tokens
# self.total_completion_tokens += completion_tokens
#
# logger.info(
# f"Max budget: ${CONFIG.max_budget:.3f} | "
# f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
# )
# CONFIG.total_cost = self.total_cost
class OpenLLMCostManager(CostManager):
"""open llm model is self-host, it's free and without cost"""
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Args:
prompt_tokens (int): The number of tokens used in the prompt.
completion_tokens (int): The number of tokens used in the completion.
model (str): The model used for the API call.
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
logger.info(
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.OPEN_LLM)
class OpenLLMGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_openllm(CONFIG)
self.model = CONFIG.open_llm_api_model
self.config: Config = CONFIG
self.__init_openllm()
self.auto_max_tokens = False
self._cost_manager = OpenLLMCostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openllm(self, config: "Config"):
# TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you
# instantiate the client, e.g. 'OpenAI(api_base=config.open_llm_api_base)'
# openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value
# openai.api_base = config.open_llm_api_base
self.rpm = int(config.get("RPM", 10))
def __init_openllm(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
async_kwargs = kwargs.copy()
return kwargs, async_kwargs
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
if not CONFIG.calc_usage:
return usage
try:
usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
except Exception as e:
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -15,7 +15,7 @@ import time
from typing import List, Union
import openai
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI, RateLimitError
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
@ -175,13 +175,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(RateLimitError),
reraise=True,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
"""when streaming, print each token in place."""
if stream:
@ -341,7 +334,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
try:
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error("updating costs failed!", e)
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return CONFIG.cost_manager.get_costs()

View file

@ -230,9 +230,11 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
elif retry_state.kwargs:
func_param_output = retry_state.kwargs.get("output", "")
exp_str = str(retry_state.outcome.exception())
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
logger.warning(
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
)
repaired_output = repair_invalid_json(func_param_output, exp_str)

View file

@ -84,6 +84,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
elif "gpt-4" == model:
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return count_message_tokens(messages, model="gpt-4-0613")
elif "open-llm-model" == model:
"""
For self-hosted open_llm api, they include lots of different models. The message tokens calculation is
inaccurate. It's a reference result.
"""
tokens_per_message = 0 # ignore conversation message template prefix
tokens_per_name = 0
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
@ -112,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
Returns:
int: The number of tokens in the text string.
"""
encoding = tiktoken.encoding_for_model(model_name)
try:
encoding = tiktoken.encoding_for_model(model_name)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(string))

View file

@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of fireworks api
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import FireWorksGPTAPI
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks"))
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
messages = [{"role": "user", "content": "who are you"}]
def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion:
return default_resp
def test_fireworks_completion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask)
resp = FireWorksGPTAPI().completion(messages)
assert "fireworks" in resp.choices[0].message.content
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> ChatCompletion:
return default_resp
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask)
resp = await FireWorksGPTAPI().acompletion(messages, stream=False)
assert "fireworks" in resp.choices[0].message.content