feat: +provider ark. #1429

This commit is contained in:
莘权 马 2024-08-05 20:04:50 +08:00
parent c036574507
commit 67842cf60d
6 changed files with 75 additions and 2 deletions

View file

@ -31,6 +31,7 @@ class LLMType(Enum):
MOONSHOT = "moonshot"
MISTRAL = "mistral"
YI = "yi" # lingyiwanwu
ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk
def __missing__(self, key):
return self.OPENAI

View file

@ -17,6 +17,7 @@ from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.ark_api import ArkLLM
__all__ = [
"GeminiLLM",
@ -30,4 +31,5 @@ __all__ = [
"QianFanLLM",
"DashScopeLLM",
"AnthropicLLM",
"ArkLLM",
]

View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Provider for volcengine.
See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model
config2.yaml example:
```yaml
llm:
base_url: "https://ark.cn-beijing.volces.com/api/v3"
api_type: "ark"
endpoint: "ep-2024080514****-d****"
api_key: "d47****b-****-****-****-d6e****0fd77"
pricing_plan: "doubao-lite"
```
"""
from typing import Optional, Union
from pydantic import BaseModel
from volcenginesdkarkruntime import AsyncArk
from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper
from metagpt.configs.llm_config import LLMType
from metagpt.provider import OpenAILLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS
@register_provider(LLMType.ARK)
class ArkLLM(OpenAILLM):
aclient: Optional[AsyncArk] = None
def _init_client(self):
"""SDK: https://github.com/openai/openai-python#async-usage"""
self.model = (
self.config.endpoint or self.config.model
) # endpoint name, See more: https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
self.pricing_plan = self.config.pricing_plan or self.model
kwargs = self._make_client_kwargs()
self.aclient = AsyncArk(**kwargs)
def _make_client_kwargs(self) -> dict:
kvs = {
"ak": self.config.access_key,
"sk": self.config.secret_key,
"api_key": self.config.api_key,
"base_url": self.config.base_url,
}
kwargs = {k: v for k, v in kvs.items() if v}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs:
self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS)
if self.pricing_plan in self.cost_manager.token_costs:
super()._update_costs(usage, self.pricing_plan, local_calc_usage)

View file

@ -89,7 +89,7 @@ class OpenAILLM(BaseLLM):
log_llm_stream(chunk_message)
collected_messages.append(chunk_message)
if finish_reason:
if hasattr(chunk, "usage"):
if hasattr(chunk, "usage") and chunk.usage:
# Some services have usage as an attribute of the chunk, such as Fireworks
usage = CompletionUsage(**chunk.usage)
elif hasattr(chunk.choices[0], "usage"):

View file

@ -143,6 +143,14 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
"mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
}
# https://console.volcengine.com/ark/region:ark+cn-beijing/model
DOUBAO_TOKEN_COSTS = {
"doubao-lite": {"prompt": 0.0003, "completion": 0.0006},
"doubao-lite-128k": {"prompt": 0.0008, "completion": 0.0010},
"doubao-pro": {"prompt": 0.0008, "completion": 0.0020},
"doubao-pro-128k": {"prompt": 0.0050, "completion": 0.0090},
}
# https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
TOKEN_MAX = {
"gpt-4-0125-preview": 128000,

View file

@ -69,4 +69,5 @@ imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
jieba==0.42.1 # for tool recommendation
jieba==0.42.1 # for tool recommendation
volcengine-python-sdk[ark]~=1.0.94