diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index 67fb6afdb..e7c280ee3 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -33,7 +33,7 @@ class LLMType(Enum):
     YI = "yi"  # lingyiwanwu
     OPENROUTER = "openrouter"
     BEDROCK = "bedrock"
-    ARK = "ark"
+    ARK = "ark"  # https://www.volcengine.com/docs/82379/1263482#python-sdk
 
     def __missing__(self, key):
         return self.OPENAI
diff --git a/metagpt/provider/ark_api.py b/metagpt/provider/ark_api.py
index c24bd1ee9..0c5704b91 100644
--- a/metagpt/provider/ark_api.py
+++ b/metagpt/provider/ark_api.py
@@ -1,12 +1,33 @@
-from openai import AsyncStream
-from openai.types import CompletionUsage
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Provider for volcengine.
+See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model
+
+config2.yaml example:
+```yaml
+llm:
+  base_url: "https://ark.cn-beijing.volces.com/api/v3"
+  api_type: "ark"
+  endpoint: "ep-2024080514****-d****"
+  api_key: "d47****b-****-****-****-d6e****0fd77"
+  pricing_plan: "doubao-lite"
+```
+"""
+from typing import Optional, Union
+
+from pydantic import BaseModel
+from volcenginesdkarkruntime import AsyncArk
+from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper
+from volcenginesdkarkruntime._streaming import AsyncStream
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk
 
 from metagpt.configs.llm_config import LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.provider.openai_api import OpenAILLM
+from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS
 
 
 @register_provider(LLMType.ARK)
@@ -16,11 +37,45 @@ class ArkLLM(OpenAILLM):
     See: https://www.volcengine.com/docs/82379/1263482
     """
 
+    aclient: Optional[AsyncArk] = None
+
+    def _init_client(self):
+        """SDK: https://github.com/openai/openai-python#async-usage"""
+        self.model = (
+            self.config.endpoint or self.config.model
+        )  # endpoint name, see more: https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
+        self.pricing_plan = self.config.pricing_plan or self.model
+        kwargs = self._make_client_kwargs()
+        self.aclient = AsyncArk(**kwargs)
+
+    def _make_client_kwargs(self) -> dict:
+        kvs = {
+            "ak": self.config.access_key,
+            "sk": self.config.secret_key,
+            "api_key": self.config.api_key,
+            "base_url": self.config.base_url,
+        }
+        kwargs = {k: v for k, v in kvs.items() if v}
+
+        # to use proxy, openai v1 needs http_client
+        if proxy_params := self._get_proxy_params():
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+
+        return kwargs
+
+    def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
+        if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs:
+            self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS)
+        if model in self.cost_manager.token_costs:
+            self.pricing_plan = model
+        if self.pricing_plan in self.cost_manager.token_costs:
+            super()._update_costs(usage, self.pricing_plan, local_calc_usage)
+
     async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
             stream=True,
-            extra_body={"stream_options": {"include_usage": True}}  # usage is only returned in the final chunk of a streaming call when this parameter is set
+            extra_body={"stream_options": {"include_usage": True}},  # usage is only returned in the final chunk of a streaming call when this parameter is set
         )
         usage = None
         collected_messages = []
@@ -30,7 +85,7 @@ class ArkLLM(OpenAILLM):
             collected_messages.append(chunk_message)
             if chunk.usage:
                 # Volcengine Ark streaming calls return usage in the last chunk; that chunk's choices is []
-                usage = CompletionUsage(**chunk.usage)
+                usage = chunk.usage
 
         log_llm_stream("\n")
         full_reply_content = "".join(collected_messages)
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index 4643b0829..a7df27258 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -188,6 +188,14 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
     "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
 }
 
+# https://console.volcengine.com/ark/region:ark+cn-beijing/model
+DOUBAO_TOKEN_COSTS = {
+    "doubao-lite": {"prompt": 0.0003, "completion": 0.0006},
+    "doubao-lite-128k": {"prompt": 0.0008, "completion": 0.0010},
+    "doubao-pro": {"prompt": 0.0008, "completion": 0.0020},
+    "doubao-pro-128k": {"prompt": 0.0050, "completion": 0.0090},
+}
+
 # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
 TOKEN_MAX = {
     "gpt-4o-2024-05-13": 128000,
diff --git a/requirements.txt b/requirements.txt
index 42483f5fe..db0204862 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -72,9 +72,11 @@ qianfan~=0.3.16
 dashscope~=1.19.3
 rank-bm25==0.2.2  # for tool recommendation
 jieba==0.42.1  # for tool recommendation
+volcengine-python-sdk[ark]~=1.0.94
 # llama-index-vector-stores-elasticsearch~=0.2.5  # Used by `metagpt/memory/longterm_memory.py`
 # llama-index-vector-stores-chroma~=0.1.10  # Used by `metagpt/memory/longterm_memory.py`
 gymnasium==0.29.1
 boto3~=1.34.69
 spark_ai_python~=0.3.30
 agentops
+
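For reviewers who want to try the new provider end to end, here is a minimal, hypothetical usage sketch (not part of the diff). It assumes the `config2.yaml` "llm" block shown in the `ark_api.py` docstring above, and that `Config.default()`, `ArkLLM(config.llm)`, `CostManager`, and `aask()` behave as they do for the existing OpenAI-compatible providers.

```python
# Hypothetical usage sketch -- not part of this PR's diff.
import asyncio

from metagpt.config2 import Config
from metagpt.provider.ark_api import ArkLLM
from metagpt.utils.cost_manager import CostManager


async def main() -> None:
    config = Config.default()         # loads config2.yaml; api_type: "ark" maps to ArkLLM via the registry
    llm = ArkLLM(config.llm)          # _init_client() uses config.endpoint (the "ep-..." name) as the model
    llm.cost_manager = CostManager()  # ArkLLM._update_costs seeds it with DOUBAO_TOKEN_COSTS on first use
    print(await llm.aask("Say hello in one short sentence."))  # streams, usage arrives in the final chunk
    print(f"total cost: {llm.cost_manager.total_cost}")


asyncio.run(main())
```

Note that `_update_costs` only records cost when `pricing_plan` (or the endpoint name) matches a key in `DOUBAO_TOKEN_COSTS`, so `pricing_plan` should be set explicitly (e.g. `"doubao-lite"`) when the endpoint id is not itself a recognized model name.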