Merge pull request #946 from YangQianli92/main

Feat: Wrap OpenLLM, Fireworks and other services into the OpenAILLM class.

commit dacc8aa06d
12 changed files with 165 additions and 319 deletions
metagpt/configs/llm_config.py
@@ -26,6 +26,8 @@ class LLMType(Enum):
     OLLAMA = "ollama"
     QIANFAN = "qianfan"  # Baidu BCE
     DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
+    MOONSHOT = "moonshot"
+    MISTRAL = 'mistral'
 
     def __missing__(self, key):
         return self.OPENAI
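For orientation: these enum values are exactly what an llm.api_type string in the config resolves to. A minimal standalone sketch (a stand-in enum, not the full metagpt one):

from enum import Enum


class LLMType(Enum):
    OPENAI = "openai"
    MOONSHOT = "moonshot"
    MISTRAL = "mistral"


# A config string such as api_type: "moonshot" resolves through the normal
# Enum value lookup.
assert LLMType("moonshot") is LLMType.MOONSHOT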
metagpt/context.py
@@ -12,10 +12,14 @@ from typing import Any, Optional
 from pydantic import BaseModel, ConfigDict
 
 from metagpt.config2 import Config
-from metagpt.configs.llm_config import LLMConfig
+from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import create_llm_instance
-from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.cost_manager import (
+    CostManager,
+    FireworksCostManager,
+    TokenCostManager,
+)
 from metagpt.utils.git_repository import GitRepository
 from metagpt.utils.project_repo import ProjectRepo
@@ -80,12 +84,21 @@ class Context(BaseModel):
     #     self._llm = None
     #     return self._llm
 
+    def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
+        """Return a CostManager instance"""
+        if llm_config.api_type == LLMType.FIREWORKS:
+            return FireworksCostManager()
+        elif llm_config.api_type == LLMType.OPEN_LLM:
+            return TokenCostManager()
+        else:
+            return self.cost_manager
+
     def llm(self) -> BaseLLM:
         """Return a LLM instance, fixme: support cache"""
         # if self._llm is None:
         self._llm = create_llm_instance(self.config.llm)
         if self._llm.cost_manager is None:
-            self._llm.cost_manager = self.cost_manager
+            self._llm.cost_manager = self._select_costmanager(self.config.llm)
         return self._llm
 
     def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
@@ -93,5 +106,5 @@ class Context(BaseModel):
         # if self._llm is None:
         llm = create_llm_instance(llm_config)
         if llm.cost_manager is None:
-            llm.cost_manager = self.cost_manager
+            llm.cost_manager = self._select_costmanager(llm_config)
         return llm
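The net effect of the hunks above: Context now picks a cost manager per provider type instead of always handing out the shared one. A reduced, runnable sketch with stand-in classes (not the metagpt ones):

from enum import Enum


class LLMType(Enum):
    OPENAI = "openai"
    FIREWORKS = "fireworks"
    OPEN_LLM = "open_llm"


class CostManager:  # stand-in: shared, dollar-based accounting
    pass


class FireworksCostManager(CostManager):  # stand-in: graded $/M-token pricing
    pass


class TokenCostManager(CostManager):  # stand-in: token counts only, no dollars
    pass


def select_costmanager(api_type: LLMType, shared: CostManager) -> CostManager:
    # Fireworks and self-hosted models get dedicated managers; every other
    # provider keeps sharing the context-wide one.
    if api_type == LLMType.FIREWORKS:
        return FireworksCostManager()
    if api_type == LLMType.OPEN_LLM:
        return TokenCostManager()
    return shared


shared = CostManager()
assert select_costmanager(LLMType.OPENAI, shared) is shared
assert isinstance(select_costmanager(LLMType.FIREWORKS, shared), FireworksCostManager)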
metagpt/provider/__init__.py
@@ -6,10 +6,8 @@
 @File    : __init__.py
 """
 
-from metagpt.provider.fireworks_api import FireworksLLM
 from metagpt.provider.google_gemini_api import GeminiLLM
 from metagpt.provider.ollama_api import OllamaLLM
-from metagpt.provider.open_llm_api import OpenLLM
 from metagpt.provider.openai_api import OpenAILLM
 from metagpt.provider.zhipuai_api import ZhiPuAILLM
 from metagpt.provider.azure_openai_api import AzureOpenAILLM
@@ -20,9 +18,7 @@ from metagpt.provider.qianfan_api import QianFanLLM
 from metagpt.provider.dashscope_api import DashScopeLLM
 
 __all__ = [
-    "FireworksLLM",
     "GeminiLLM",
-    "OpenLLM",
     "OpenAILLM",
     "ZhiPuAILLM",
     "AzureOpenAILLM",
metagpt/provider/fireworks_api.py (deleted)
@@ -1,121 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc   : fireworks.ai's api
-
-import re
-
-from openai import APIConnectionError, AsyncStream
-from openai.types import CompletionUsage
-from openai.types.chat import ChatCompletionChunk
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-
-from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import log_llm_stream, logger
-from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
-from metagpt.utils.cost_manager import CostManager
-
-MODEL_GRADE_TOKEN_COSTS = {
-    "-1": {"prompt": 0.0, "completion": 0.0},  # abnormal condition
-    "16": {"prompt": 0.2, "completion": 0.8},  # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
-    "80": {"prompt": 0.7, "completion": 2.8},  # 80 means 16B < model size <= 80B
-    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
-}
-
-
-class FireworksCostManager(CostManager):
-    def model_grade_token_costs(self, model: str) -> dict[str, float]:
-        def _get_model_size(model: str) -> float:
-            size = re.findall(".*-([0-9.]+)b", model)
-            size = float(size[0]) if len(size) > 0 else -1
-            return size
-
-        if "mixtral-8x7b" in model:
-            token_costs = MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"]
-        else:
-            model_size = _get_model_size(model)
-            if 0 < model_size <= 16:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["16"]
-            elif 16 < model_size <= 80:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["80"]
-            else:
-                token_costs = MODEL_GRADE_TOKEN_COSTS["-1"]
-        return token_costs
-
-    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
-        """
-        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
-        Update the total cost, prompt tokens, and completion tokens.
-
-        Args:
-            prompt_tokens (int): The number of tokens used in the prompt.
-            completion_tokens (int): The number of tokens used in the completion.
-            model (str): The model used for the API call.
-        """
-        self.total_prompt_tokens += prompt_tokens
-        self.total_completion_tokens += completion_tokens
-
-        token_costs = self.model_grade_token_costs(model)
-        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
-        self.total_cost += cost
-        logger.info(
-            f"Total running cost: ${self.total_cost:.4f}"
-            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
-        )
-
-
-@register_provider(LLMType.FIREWORKS)
-class FireworksLLM(OpenAILLM):
-    def __init__(self, config: LLMConfig):
-        super().__init__(config=config)
-        self.auto_max_tokens = False
-        self.cost_manager = FireworksCostManager()
-
-    def _make_client_kwargs(self) -> dict:
-        kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
-        return kwargs
-
-    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
-        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
-            **self._cons_kwargs(messages), stream=True
-        )
-
-        collected_content = []
-        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        # iterate through the stream of events
-        async for chunk in response:
-            if chunk.choices:
-                choice = chunk.choices[0]
-                choice_delta = choice.delta
-                finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
-                if choice_delta.content:
-                    collected_content.append(choice_delta.content)
-                    log_llm_stream(choice_delta.content)
-                if finish_reason:
-                    # fireworks api return usage when finish_reason is not None
-                    usage = CompletionUsage(**chunk.usage)
-        log_llm_stream("\n")
-
-        full_content = "".join(collected_content)
-        self._update_costs(usage)
-        return full_content
-
-    @retry(
-        wait=wait_random_exponential(min=1, max=60),
-        stop=stop_after_attempt(6),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(APIConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        """when streaming, print each token in place."""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        rsp = await self._achat_completion(messages)
-        return self.get_choice_text(rsp)
metagpt/provider/llm_provider_registry.py
@@ -21,11 +21,15 @@ class LLMProviderRegistry:
         return self.providers[enum]
 
 
-def register_provider(key):
+def register_provider(keys):
     """register provider to registry"""
 
     def decorator(cls):
-        LLM_REGISTRY.register(key, cls)
+        if isinstance(keys, list):
+            for key in keys:
+                LLM_REGISTRY.register(key, cls)
+        else:
+            LLM_REGISTRY.register(keys, cls)
         return cls
 
     return decorator
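The decorator now takes a single key or a list of keys, which is what lets openai_api.py below register OpenAILLM once for four API types. A self-contained sketch of the same pattern, with a plain dict standing in for metagpt's LLM_REGISTRY:

from enum import Enum


class LLMType(Enum):
    OPENAI = "openai"
    FIREWORKS = "fireworks"
    OPEN_LLM = "open_llm"
    MOONSHOT = "moonshot"


LLM_REGISTRY = {}


def register_provider(keys):
    """Register a provider class under one LLMType or a list of them."""

    def decorator(cls):
        for key in keys if isinstance(keys, list) else [keys]:
            LLM_REGISTRY[key] = cls
        return cls

    return decorator


@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT])
class OpenAILLM:
    pass


assert LLM_REGISTRY[LLMType.MOONSHOT] is OpenAILLM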
metagpt/provider/open_llm_api.py (deleted)
@@ -1,36 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc   : self-host open llm model with openai-compatible interface
-
-from openai.types import CompletionUsage
-
-from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import logger
-from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAILLM
-from metagpt.utils.cost_manager import TokenCostManager
-from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
-
-
-@register_provider(LLMType.OPEN_LLM)
-class OpenLLM(OpenAILLM):
-    def __init__(self, config: LLMConfig):
-        super().__init__(config)
-        self._cost_manager = TokenCostManager()
-
-    def _make_client_kwargs(self) -> dict:
-        kwargs = dict(api_key="sk-xxx", base_url=self.config.base_url)
-        return kwargs
-
-    def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage:
-        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
-        if not self.config.calc_usage:
-            return usage
-
-        try:
-            usage.prompt_tokens = count_message_tokens(messages, "open-llm-model")
-            usage.completion_tokens = count_string_tokens(rsp, "open-llm-model")
-        except Exception as e:
-            logger.error(f"usage calculation failed!: {e}")
-
-        return usage
metagpt/provider/openai_api.py
@@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.schema import Message
 from metagpt.utils.common import CodeParser, decode_image
-from metagpt.utils.cost_manager import CostManager
+from metagpt.utils.cost_manager import CostManager, Costs, TokenCostManager
 from metagpt.utils.exceptions import handle_exception
 from metagpt.utils.token_counter import (
     count_message_tokens,
@@ -50,7 +50,7 @@ See FAQ 5.8
     raise retry_state.outcome.exception()
 
 
-@register_provider(LLMType.OPENAI)
+@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT])
 class OpenAILLM(BaseLLM):
     """Check https://platform.openai.com/examples for examples"""
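One registration for four services works because Fireworks, self-hosted OpenLLM deployments and Moonshot all expose an OpenAI-compatible chat-completions API; only the credentials and the endpoint differ. A hedged sketch assuming the openai v1 SDK (the base_url value is illustrative, not taken from this diff):

from typing import Optional

from openai import AsyncOpenAI  # assumes the openai v1 SDK is installed


def make_client(api_key: str, base_url: Optional[str] = None) -> AsyncOpenAI:
    # Fireworks, OpenLLM and Moonshot all speak the OpenAI chat-completions
    # schema, so the same async client serves them once base_url points at
    # the right endpoint.
    return AsyncOpenAI(api_key=api_key, base_url=base_url)


openai_client = make_client("sk-...")  # default api.openai.com endpoint
fireworks_client = make_client("fw-...", "https://api.fireworks.ai/inference/v1")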
@@ -84,20 +84,39 @@ class OpenAILLM(BaseLLM):
 
         return params
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=timeout), stream=True
         )
-
+        usage = None
+        collected_messages = []
         async for chunk in response:
             chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
-            yield chunk_message
+            finish_reason = chunk.choices[0].finish_reason if hasattr(chunk.choices[0], "finish_reason") else None
+            log_llm_stream(chunk_message)
+            collected_messages.append(chunk_message)
+            if finish_reason:
+                if hasattr(chunk, "usage"):
+                    # Some services have usage as an attribute of the chunk, such as Fireworks
+                    usage = CompletionUsage(**chunk.usage)
+                elif hasattr(chunk.choices[0], "usage"):
+                    # The usage of some services is an attribute of chunk.choices[0], such as Moonshot
+                    usage = CompletionUsage(**chunk.choices[0].usage)
+
+        log_llm_stream("\n")
+        full_reply_content = "".join(collected_messages)
+        if not usage:
+            # Some services do not provide the usage attribute, such as OpenAI or OpenLLM
+            usage = self._calc_usage(messages, full_reply_content)
+
+        self._update_costs(usage)
+        return full_reply_content
+
     def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
         kwargs = {
             "messages": messages,
             "max_tokens": self._get_max_tokens(messages),
-            "n": 1,
+            # "n": 1,  # Some services do not provide this parameter, such as mistral
             # "stop": None,  # default it's None and gpt4-v can't have this one
             "temperature": self.config.temperature,
             "model": self.model,
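The rewritten method folds three provider quirks into one code path: Fireworks attaches usage to the final chunk, Moonshot attaches it to the final choice, and OpenAI-style or self-hosted endpoints report nothing in the stream, so tokens get counted locally. A runnable sketch of that fallback order (SimpleNamespace stands in for the SDK chunk objects; the diff tests hasattr, collapsed here to getattr):

from types import SimpleNamespace


def resolve_usage(final_chunk, count_locally):
    """Mirror the fallback order of the new _achat_completion_stream."""
    if getattr(final_chunk, "usage", None):
        return final_chunk.usage  # e.g. Fireworks: usage sits on the chunk
    if getattr(final_chunk.choices[0], "usage", None):
        return final_chunk.choices[0].usage  # e.g. Moonshot: usage on the choice
    return count_locally()  # e.g. OpenAI/OpenLLM: count tokens ourselves


fireworks = SimpleNamespace(usage={"prompt_tokens": 9}, choices=[SimpleNamespace(usage=None)])
moonshot = SimpleNamespace(usage=None, choices=[SimpleNamespace(usage={"prompt_tokens": 9})])
plain = SimpleNamespace(usage=None, choices=[SimpleNamespace(usage=None)])

assert resolve_usage(fireworks, dict) == {"prompt_tokens": 9}
assert resolve_usage(moonshot, dict) == {"prompt_tokens": 9}
assert resolve_usage(plain, dict) == {}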
@@ -126,18 +145,7 @@ class OpenAILLM(BaseLLM):
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """when streaming, print each token in place."""
         if stream:
-            resp = self._achat_completion_stream(messages, timeout=timeout)
-
-            collected_messages = []
-            async for i in resp:
-                log_llm_stream(i)
-                collected_messages.append(i)
-            log_llm_stream("\n")
-
-            full_reply_content = "".join(collected_messages)
-            usage = self._calc_usage(messages, full_reply_content)
-            self._update_costs(usage)
-            return full_reply_content
+            return await self._achat_completion_stream(messages, timeout=timeout)
 
         rsp = await self._achat_completion(messages, timeout=timeout)
         return self.get_choice_text(rsp)
@@ -261,9 +269,10 @@ class OpenAILLM(BaseLLM):
         if not self.config.calc_usage:
             return usage
 
+        model = self.model if not isinstance(self.cost_manager, TokenCostManager) else "open-llm-model"
         try:
-            usage.prompt_tokens = count_message_tokens(messages, self.model)
-            usage.completion_tokens = count_string_tokens(rsp, self.model)
+            usage.prompt_tokens = count_message_tokens(messages, model)
+            usage.completion_tokens = count_string_tokens(rsp, model)
         except Exception as e:
             logger.warning(f"usage calculation failed: {e}")
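The model-name swap above exists because the token counters are table-driven by hosted model names, which a self-hosted deployment will not match; an attached TokenCostManager signals that counting should fall back to the generic "open-llm-model" entry. A reduced sketch:

class TokenCostManager:  # stand-in for metagpt.utils.cost_manager.TokenCostManager
    pass


def counting_model(model: str, cost_manager: object) -> str:
    # Self-hosted model names are unknown to the token tables; fall back to
    # the generic "open-llm-model" entry when a TokenCostManager is attached.
    return "open-llm-model" if isinstance(cost_manager, TokenCostManager) else model


assert counting_model("gpt-4", object()) == "gpt-4"
assert counting_model("llama-v2-13b-chat", TokenCostManager()) == "open-llm-model"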
metagpt/utils/cost_manager.py
@@ -6,12 +6,13 @@
 @Desc    : mashenquan, 2023/8/28. Separate the `CostManager` class to support user-level cost accounting.
 """
 
+import re
 from typing import NamedTuple
 
 from pydantic import BaseModel
 
 from metagpt.logs import logger
-from metagpt.utils.token_counter import TOKEN_COSTS
+from metagpt.utils.token_counter import FIREWORKS_GRADE_TOKEN_COSTS, TOKEN_COSTS
 
 
 class Costs(NamedTuple):
@@ -103,3 +104,44 @@ class TokenCostManager(CostManager):
         self.total_prompt_tokens += prompt_tokens
         self.total_completion_tokens += completion_tokens
         logger.info(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
+
+
+class FireworksCostManager(CostManager):
+    def model_grade_token_costs(self, model: str) -> dict[str, float]:
+        def _get_model_size(model: str) -> float:
+            size = re.findall(".*-([0-9.]+)b", model)
+            size = float(size[0]) if len(size) > 0 else -1
+            return size
+
+        if "mixtral-8x7b" in model:
+            token_costs = FIREWORKS_GRADE_TOKEN_COSTS["mixtral-8x7b"]
+        else:
+            model_size = _get_model_size(model)
+            if 0 < model_size <= 16:
+                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["16"]
+            elif 16 < model_size <= 80:
+                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["80"]
+            else:
+                token_costs = FIREWORKS_GRADE_TOKEN_COSTS["-1"]
+        return token_costs
+
+    def update_cost(self, prompt_tokens: int, completion_tokens: int, model: str):
+        """
+        Refs to `https://app.fireworks.ai/pricing` **Developer pricing**
+        Update the total cost, prompt tokens, and completion tokens.
+
+        Args:
+            prompt_tokens (int): The number of tokens used in the prompt.
+            completion_tokens (int): The number of tokens used in the completion.
+            model (str): The model used for the API call.
+        """
+        self.total_prompt_tokens += prompt_tokens
+        self.total_completion_tokens += completion_tokens
+
+        token_costs = self.model_grade_token_costs(model)
+        cost = (prompt_tokens * token_costs["prompt"] + completion_tokens * token_costs["completion"]) / 1000000
+        self.total_cost += cost
+        logger.info(
+            f"Total running cost: ${self.total_cost:.4f}"
+            f"Current cost: ${cost:.4f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
+        )
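A worked check of the grade pricing, mirroring the unit test deleted further below: "llama-v2-13b-chat" parses to a 13B model, which lands in the "16" grade, so half a million tokens each way costs $0.50:

import re

FIREWORKS_GRADE_TOKEN_COSTS = {
    "16": {"prompt": 0.2, "completion": 0.8},  # $/1M tokens, models <= 16B
}


def model_size(model: str) -> float:
    # "llama-v2-13b-chat" -> 13.0; "xxx-15.5b-chat" -> 15.5; no match -> -1
    found = re.findall(r".*-([0-9.]+)b", model)
    return float(found[0]) if found else -1


assert model_size("llama-v2-13b-chat") == 13.0
grade = FIREWORKS_GRADE_TOKEN_COSTS["16"]
cost = (500_000 * grade["prompt"] + 500_000 * grade["completion"]) / 1_000_000
assert cost == 0.5  # matches the deleted test_fireworks_costmanager expectation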
metagpt/utils/token_counter.py
@@ -35,6 +35,14 @@ TOKEN_COSTS = {
     "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007},  # 128k version, prompt + completion tokens=0.005¥/k-tokens
     "glm-4": {"prompt": 0.014, "completion": 0.014},  # 128k version, prompt + completion tokens=0.1¥/k-tokens
     "gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
+    "moonshot-v1-8k": {"prompt": 0.012, "completion": 0.012},  # prompt + completion tokens=0.012¥/k-tokens
+    "moonshot-v1-32k": {"prompt": 0.024, "completion": 0.024},
+    "moonshot-v1-128k": {"prompt": 0.06, "completion": 0.06},
+    "open-mistral-7b": {"prompt": 0.00025, "completion": 0.00025},
+    "open-mixtral-8x7b": {"prompt": 0.0007, "completion": 0.0007},
+    "mistral-small-latest": {"prompt": 0.002, "completion": 0.006},
+    "mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081},
+    "mistral-large-latest": {"prompt": 0.008, "completion": 0.024},
 }
@@ -120,6 +128,14 @@ DASHSCOPE_TOKEN_COSTS = {
 }
 
 
+FIREWORKS_GRADE_TOKEN_COSTS = {
+    "-1": {"prompt": 0.0, "completion": 0.0},  # abnormal condition
+    "16": {"prompt": 0.2, "completion": 0.8},  # 16 means model size <= 16B; 0.2 means $0.2/1M tokens
+    "80": {"prompt": 0.7, "completion": 2.8},  # 80 means 16B < model size <= 80B
+    "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
+}
+
+
 TOKEN_MAX = {
     "gpt-3.5-turbo": 4096,
     "gpt-3.5-turbo-0301": 4096,
@@ -143,6 +159,14 @@ TOKEN_MAX = {
     "glm-3-turbo": 128000,
     "glm-4": 128000,
     "gemini-pro": 32768,
+    "moonshot-v1-8k": 8192,
+    "moonshot-v1-32k": 32768,
+    "moonshot-v1-128k": 128000,
+    "open-mistral-7b": 8192,
+    "open-mixtral-8x7b": 32768,
+    "mistral-small-latest": 32768,
+    "mistral-medium-latest": 32768,
+    "mistral-large-latest": 32768,
 }
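TOKEN_MAX presumably feeds the max_tokens sizing done in _cons_kwargs via _get_max_tokens. A rough illustration of that kind of context-window budgeting (the helper is hypothetical, not metagpt's):

TOKEN_MAX = {"moonshot-v1-8k": 8192, "mistral-large-latest": 32768}


def completion_budget(model: str, prompt_tokens: int, default_window: int = 4096) -> int:
    # Hypothetical helper: whatever the context window leaves after the
    # prompt is the most the completion can use.
    return max(TOKEN_MAX.get(model, default_window) - prompt_tokens, 0)


assert completion_budget("moonshot-v1-8k", 2000) == 6192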
tests/metagpt/provider/test_fireworks_api.py (deleted)
@@ -1,74 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc   : the unittest of fireworks api
-
-import pytest
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
-from openai.types.completion_usage import CompletionUsage
-
-from metagpt.provider.fireworks_api import (
-    MODEL_GRADE_TOKEN_COSTS,
-    FireworksCostManager,
-    FireworksLLM,
-)
-from metagpt.utils.cost_manager import Costs
-from tests.metagpt.provider.mock_llm_config import mock_llm_config
-from tests.metagpt.provider.req_resp_const import (
-    get_openai_chat_completion,
-    get_openai_chat_completion_chunk,
-    llm_general_chat_funcs_test,
-    messages,
-    prompt,
-    resp_cont_tmpl,
-)
-
-name = "fireworks"
-resp_cont = resp_cont_tmpl.format(name=name)
-default_resp = get_openai_chat_completion(name)
-default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
-
-
-def test_fireworks_costmanager():
-    cost_manager = FireworksCostManager()
-    assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
-    assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
-    assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
-    assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
-    assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
-    assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
-    assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
-
-    cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat")
-    assert cost_manager.total_cost == 0.5
-
-
-async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
-    if stream:
-
-        class Iterator(object):
-            async def __aiter__(self):
-                yield default_resp_chunk
-
-        return Iterator()
-    else:
-        return default_resp
-
-
-@pytest.mark.asyncio
-async def test_fireworks_acompletion(mocker):
-    mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
-
-    fireworks_llm = FireworksLLM(mock_llm_config)
-    fireworks_llm.model = "llama-v2-13b-chat"
-
-    fireworks_llm._update_costs(
-        usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
-    )
-    assert fireworks_llm.get_costs() == Costs(
-        total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
-    )
-
-    resp = await fireworks_llm.acompletion(messages)
-    assert resp.choices[0].message.content in resp_cont
-
-    await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont)
tests/metagpt/provider/test_open_llm_api.py (deleted)
@@ -1,56 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# @Desc   :
-
-import pytest
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
-from openai.types.completion_usage import CompletionUsage
-
-from metagpt.provider.open_llm_api import OpenLLM
-from metagpt.utils.cost_manager import CostManager, Costs
-from tests.metagpt.provider.mock_llm_config import mock_llm_config
-from tests.metagpt.provider.req_resp_const import (
-    get_openai_chat_completion,
-    get_openai_chat_completion_chunk,
-    llm_general_chat_funcs_test,
-    messages,
-    prompt,
-    resp_cont_tmpl,
-)
-
-name = "llama2-7b"
-resp_cont = resp_cont_tmpl.format(name=name)
-default_resp = get_openai_chat_completion(name)
-
-default_resp_chunk = get_openai_chat_completion_chunk(name)
-
-
-async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
-    if stream:
-
-        class Iterator(object):
-            async def __aiter__(self):
-                yield default_resp_chunk
-
-        return Iterator()
-    else:
-        return default_resp
-
-
-@pytest.mark.asyncio
-async def test_openllm_acompletion(mocker):
-    mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
-
-    openllm_llm = OpenLLM(mock_llm_config)
-    openllm_llm.model = "llama-v2-13b-chat"
-
-    openllm_llm.cost_manager = CostManager()
-    openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
-    assert openllm_llm.get_costs() == Costs(
-        total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
-    )
-
-    resp = await openllm_llm.acompletion(messages)
-    assert resp.choices[0].message.content in resp_cont
-
-    await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont)
tests/metagpt/provider/test_openai_api.py
@@ -1,10 +1,11 @@
 import pytest
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionChunk,
     ChatCompletionMessage,
     ChatCompletionMessageToolCall,
 )
-from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion import Choice, CompletionUsage
 from openai.types.chat.chat_completion_message_tool_call import Function
 from PIL import Image
@@ -16,6 +17,22 @@ from tests.metagpt.provider.mock_llm_config import (
     mock_llm_config,
     mock_llm_config_proxy,
 )
+from tests.metagpt.provider.req_resp_const import (
+    get_openai_chat_completion,
+    get_openai_chat_completion_chunk,
+    llm_general_chat_funcs_test,
+    messages,
+    prompt,
+    resp_cont_tmpl,
+)
 
+name = "AI assistant"
+resp_cont = resp_cont_tmpl.format(name=name)
+default_resp = get_openai_chat_completion(name)
+
+default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
+
+usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
+
 
 @pytest.mark.asyncio
@@ -121,3 +138,29 @@ async def test_gen_image():
 
     images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
     assert images[0].size == (1024, 1024)
+
+
+async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
+    if stream:
+
+        class Iterator(object):
+            async def __aiter__(self):
+                yield default_resp_chunk
+
+        return Iterator()
+    else:
+        return default_resp
+
+
+@pytest.mark.asyncio
+async def test_openai_acompletion(mocker):
+    mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
+
+    llm = OpenAILLM(mock_llm_config)
+
+    resp = await llm.acompletion(messages)
+    assert resp.choices[0].finish_reason == "stop"
+    assert resp.choices[0].message.content == resp_cont
+    assert resp.usage == usage
+
+    await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)
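The Iterator trick used in all three test files: async for only requires that __aiter__ return an async iterator, and an async-generator method qualifies, so a two-line class can stand in for the SDK's AsyncStream. A standalone demonstration:

import asyncio


class OneChunkStream:
    # An async-generator __aiter__ is all `async for` needs, so this class
    # can replace an AsyncStream in tests (the same trick as Iterator above).
    async def __aiter__(self):
        yield {"content": "hello"}


async def main():
    async for chunk in OneChunkStream():
        assert chunk["content"] == "hello"


asyncio.run(main())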