mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
update fireworks/open_llm api due to new openai sdk
This commit is contained in:
parent
a1f39d1269
commit
7e0c62a7a9
6 changed files with 257 additions and 52 deletions
|
|
@ -2,25 +2,140 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : fireworks.ai's api
|
||||
|
||||
import openai
|
||||
import re
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from openai import APIConnectionError, AsyncStream
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
MODEL_GRADE_TOKEN_COSTS = {
|
||||
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
|
||||
"16": {"prompt": 0.2, "completion": 0.8}, # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
|
||||
"80": {"prompt": 0.7, "completion": 2.8}, # 80 means 16B < model size <= 80B
|
||||
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
|
||||
}
|
||||
|
||||
|
||||
class FireworksCostManager(CostManager):
|
||||
def model_grade_token_costs(self, model: str) -> dict[str, float]:
|
||||
def _get_model_size(model: str) -> float:
|
||||
size = re.findall(".*-([0-9.]+)b", model)
|
||||
size = float(size[0]) if len(size) > 0 else -1
|
||||
return size
|
||||
|
||||
if "mixtral-8x7b" in model:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
|
||||
else:
|
||||
model_size = _get_model_size(model)
|
||||
if 0 < model_size <= 16:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
|
||||
elif 16 < model_size <= 80:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
|
||||
else:
|
||||
token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
|
||||
return token_costs
|
||||
|
||||
def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
|
||||
"""
|
||||
Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
|
||||
token_costs = self.model_grade_token_costs(model)
|
||||
cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
|
||||
self.total_cost += cost
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.4f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.FIREWORKS)
|
||||
class FireWorksGPTAPI(OpenAIGPTAPI):
|
||||
def __init__(self):
|
||||
self.__init_fireworks(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.fireworks_api_model
|
||||
self.config: Config = CONFIG
|
||||
self.__init_fireworks()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = FireworksCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_fireworks(self, config: "Config"):
|
||||
# TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you
|
||||
# instantiate the client, e.g. 'OpenAI(api_base=config.fireworks_api_base)'
|
||||
# openai.api_key = config.fireworks_api_key
|
||||
# openai.api_base = config.fireworks_api_base
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
def __init_fireworks(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._make_client()
|
||||
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
kwargs = dict(api_key=self.config.fireworks_api_key, base_url=self.config.fireworks_api_base)
|
||||
async_kwargs = kwargs.copy()
|
||||
return kwargs, async_kwargs
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use FireworksCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
)
|
||||
|
||||
collected_content = []
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
choice_delta = choice.delta
|
||||
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
|
||||
if choice_delta.content:
|
||||
collected_content.append(choice_delta.content)
|
||||
print(choice_delta.content, end="")
|
||||
if finish_reason:
|
||||
# fireworks api return usage when finish_reason is not None
|
||||
usage = CompletionUsage(**chunk.usage)
|
||||
|
||||
full_content = "".join(collected_content)
|
||||
self._update_costs(usage)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
rsp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(rsp)
|
||||
|
|
|
|||
|
|
@ -2,44 +2,78 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : self-host open llm model with openai-compatible interface
|
||||
|
||||
from openai.types import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
# class OpenLLMCostManager(CostManager):
|
||||
# """open llm model is self-host, it's free and without cost"""
|
||||
#
|
||||
# def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
# """
|
||||
# Update the total cost, prompt tokens, and completion tokens.
|
||||
#
|
||||
# Args:
|
||||
# prompt_tokens (int): The number of tokens used in the prompt.
|
||||
# completion_tokens (int): The number of tokens used in the completion.
|
||||
# model (str): The model used for the API call.
|
||||
# """
|
||||
# self.total_prompt_tokens += prompt_tokens
|
||||
# self.total_completion_tokens += completion_tokens
|
||||
#
|
||||
# logger.info(
|
||||
# f"Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
# f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
# )
|
||||
# CONFIG.total_cost = self.total_cost
|
||||
|
||||
class OpenLLMCostManager(CostManager):
|
||||
"""open llm model is self-host, it's free and without cost"""
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
|
||||
logger.info(
|
||||
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
class OpenLLMGPTAPI(OpenAIGPTAPI):
|
||||
def __init__(self):
|
||||
self.__init_openllm(CONFIG)
|
||||
self.model = CONFIG.open_llm_api_model
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openllm()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = OpenLLMCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openllm(self, config: "Config"):
|
||||
# TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you
|
||||
# instantiate the client, e.g. 'OpenAI(api_base=config.open_llm_api_base)'
|
||||
# openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value
|
||||
# openai.api_base = config.open_llm_api_base
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
def __init_openllm(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._make_client()
|
||||
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
kwargs = dict(api_key="sk-xxx", base_url=self.config.open_llm_api_base)
|
||||
async_kwargs = kwargs.copy()
|
||||
return kwargs, async_kwargs
|
||||
|
||||
def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
|
||||
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
|
||||
if not CONFIG.calc_usage:
|
||||
return usage
|
||||
|
||||
try:
|
||||
usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
|
||||
usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
|
||||
except Exception as e:
|
||||
logger.error(f"usage calculation failed!: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if self.config.calc_usage and usage:
|
||||
try:
|
||||
# use OpenLLMCostManager not CONFIG.cost_manager
|
||||
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import time
|
|||
from typing import List, Union
|
||||
|
||||
import openai
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI, RateLimitError
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
|
@ -175,13 +175,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
@retry(
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(RateLimitError),
|
||||
reraise=True,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
|
|
@ -341,7 +334,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
try:
|
||||
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error("updating costs failed!", e)
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return CONFIG.cost_manager.get_costs()
|
||||
|
|
|
|||
|
|
@ -230,9 +230,11 @@ def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["R
|
|||
elif retry_state.kwargs:
|
||||
func_param_output = retry_state.kwargs.get("output", "")
|
||||
exp_str = str(retry_state.outcome.exception())
|
||||
|
||||
fix_str = "try to fix it, " if CONFIG.repair_llm_output else ""
|
||||
logger.warning(
|
||||
f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
|
||||
f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}"
|
||||
f"{retry_state.attempt_number}, {fix_str}exp: {exp_str}"
|
||||
)
|
||||
|
||||
repaired_output = repair_invalid_json(func_param_output, exp_str)
|
||||
|
|
|
|||
|
|
@ -84,6 +84,13 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
elif "gpt-4" == model:
|
||||
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return count_message_tokens(messages, model="gpt-4-0613")
|
||||
elif "open-llm-model" == model:
|
||||
"""
|
||||
For self-hosted open_llm api, they include lots of different models. The message tokens calculation is
|
||||
inaccurate. It's a reference result.
|
||||
"""
|
||||
tokens_per_message = 0 # ignore conversation message template prefix
|
||||
tokens_per_name = 0
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"num_tokens_from_messages() is not implemented for model {model}. "
|
||||
|
|
@ -112,7 +119,11 @@ def count_string_tokens(string: str, model_name: str) -> int:
|
|||
Returns:
|
||||
int: The number of tokens in the text string.
|
||||
"""
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
return len(encoding.encode(string))
|
||||
|
||||
|
||||
|
|
|
|||
50
tests/metagpt/provider/test_fireworks_api.py
Normal file
50
tests/metagpt/provider/test_fireworks_api.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of fireworks api
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks"))
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
def test_fireworks_completion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask)
|
||||
|
||||
resp = FireWorksGPTAPI().completion(messages)
|
||||
assert "fireworks" in resp.choices[0].message.content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask)
|
||||
|
||||
resp = await FireWorksGPTAPI().acompletion(messages, stream=False)
|
||||
|
||||
assert "fireworks" in resp.choices[0].message.content
|
||||
Loading…
Add table
Add a link
Reference in a new issue