diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index af8f56372..851041e15 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -31,6 +31,7 @@ class LLMType(Enum): MOONSHOT = "moonshot" MISTRAL = "mistral" YI = "yi" # lingyiwanwu + ARK = "ark" # https://www.volcengine.com/docs/82379/1263482#python-sdk def __missing__(self, key): return self.OPENAI diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 14d5e7682..0ed390397 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -17,6 +17,7 @@ from metagpt.provider.spark_api import SparkLLM from metagpt.provider.qianfan_api import QianFanLLM from metagpt.provider.dashscope_api import DashScopeLLM from metagpt.provider.anthropic_api import AnthropicLLM +from metagpt.provider.ark_api import ArkLLM __all__ = [ "GeminiLLM", @@ -30,4 +31,5 @@ __all__ = [ "QianFanLLM", "DashScopeLLM", "AnthropicLLM", + "ArkLLM", ] diff --git a/metagpt/provider/ark_api.py b/metagpt/provider/ark_api.py new file mode 100644 index 000000000..6506be87e --- /dev/null +++ b/metagpt/provider/ark_api.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Provider for volcengine. 
"""Provider for Volcengine Ark (Doubao models).

See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model

config2.yaml example:
```yaml
llm:
  base_url: "https://ark.cn-beijing.volces.com/api/v3"
  api_type: "ark"
  endpoint: "ep-2024080514****-d****"
  api_key: "d47****b-****-****-****-d6e****0fd77"
  pricing_plan: "doubao-lite"
```
"""
from typing import Optional, Union

from pydantic import BaseModel
from volcenginesdkarkruntime import AsyncArk
from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper

from metagpt.configs.llm_config import LLMType
from metagpt.provider import OpenAILLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS


@register_provider(LLMType.ARK)
class ArkLLM(OpenAILLM):
    """OpenAI-compatible LLM provider for Volcengine Ark.

    Reuses OpenAILLM's request/streaming logic, swapping in the Ark async
    client and Doubao token pricing.
    """

    # Async Ark client; built by _init_client (replaces the OpenAI client).
    aclient: Optional[AsyncArk] = None

    def _init_client(self):
        """Create the AsyncArk client from config.

        SDK: https://github.com/openai/openai-python#async-usage
        """
        # Ark addresses a deployment by endpoint id rather than a model name.
        # See: https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
        self.model = self.config.endpoint or self.config.model
        # pricing_plan selects the cost-table entry (e.g. "doubao-lite");
        # falls back to the endpoint id, which then won't match the table.
        self.pricing_plan = self.config.pricing_plan or self.model
        self.aclient = AsyncArk(**self._make_client_kwargs())

    def _make_client_kwargs(self) -> dict:
        """Assemble AsyncArk constructor kwargs, dropping unset credentials."""
        candidates = {
            "ak": self.config.access_key,
            "sk": self.config.secret_key,
            "api_key": self.config.api_key,
            "base_url": self.config.base_url,
        }
        kwargs = {key: value for key, value in candidates.items() if value}

        # openai-v1-style clients take proxy settings via an explicit http_client.
        if proxy_params := self._get_proxy_params():
            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)

        return kwargs

    def _update_costs(self, usage: Union[dict, BaseModel], model: Optional[str] = None, local_calc_usage: bool = True):
        """Record token costs for this call under the configured pricing plan.

        Args:
            usage: Token-usage payload from the API response.
            model: Unused; kept only for signature compatibility with the base
                class. Costs are always booked under ``self.pricing_plan``.
            local_calc_usage: Forwarded to the base implementation.
        """
        # Lazily merge Doubao prices into the shared cost table (idempotent:
        # only done when the first Doubao key is absent).
        if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs:
            self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS)
        # Unknown pricing plans are skipped silently so calls still succeed
        # without cost tracking.
        if self.pricing_plan in self.cost_manager.token_costs:
            super()._update_costs(usage, self.pricing_plan, local_calc_usage)
# https://console.volcengine.com/ark/region:ark+cn-beijing/model
# Doubao (Volcengine Ark) token prices, keyed by pricing plan name as set in
# LLM config `pricing_plan`; each entry gives prompt/completion unit prices.
# NOTE(review): units appear to be CNY per 1K tokens per the linked pricing
# page — confirm before relying on absolute cost figures.
DOUBAO_TOKEN_COSTS = {
    "doubao-lite": {"prompt": 0.0003, "completion": 0.0006},
    "doubao-lite-128k": {"prompt": 0.0008, "completion": 0.0010},
    "doubao-pro": {"prompt": 0.0008, "completion": 0.0020},
    "doubao-pro-128k": {"prompt": 0.0050, "completion": 0.0090},
}