diff --git a/config/examples/huoshan_ark.yaml b/config/examples/huoshan_ark.yaml
new file mode 100644
index 000000000..b0516359b
--- /dev/null
+++ b/config/examples/huoshan_ark.yaml
@@ -0,0 +1,5 @@
+llm:
+  api_type: "ark"
+  model: ""  # your model endpoint like ep-xxx
+  base_url: "https://ark.cn-beijing.volces.com/api/v3"
+  api_key: ""  # your api-key like ey……
\ No newline at end of file
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index dbf04dac6..12bb8541e 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -33,6 +33,7 @@ class LLMType(Enum):
     YI = "yi"  # lingyiwanwu
     OPENROUTER = "openrouter"
     BEDROCK = "bedrock"
+    ARK = "ark"
 
     def __missing__(self, key):
         return self.OPENAI
diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index fcb5fa32a..c90f5774a 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -18,6 +18,7 @@ from metagpt.provider.qianfan_api import QianFanLLM
 from metagpt.provider.dashscope_api import DashScopeLLM
 from metagpt.provider.anthropic_api import AnthropicLLM
 from metagpt.provider.bedrock_api import BedrockLLM
+from metagpt.provider.ark_api import ArkLLM
 
 __all__ = [
     "GeminiLLM",
@@ -32,4 +33,5 @@ __all__ = [
     "DashScopeLLM",
     "AnthropicLLM",
     "BedrockLLM",
+    "ArkLLM",
 ]
diff --git a/metagpt/provider/ark_api.py b/metagpt/provider/ark_api.py
new file mode 100644
index 000000000..c24bd1ee9
--- /dev/null
+++ b/metagpt/provider/ark_api.py
@@ -0,0 +1,51 @@
+from openai import AsyncStream
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+
+from metagpt.configs.llm_config import LLMType
+from metagpt.const import USE_CONFIG_TIMEOUT
+from metagpt.logs import log_llm_stream
+from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.provider.openai_api import OpenAILLM
+
+
+@register_provider(LLMType.ARK)
+class ArkLLM(OpenAILLM):
+    """
+    LLM provider for the Volcengine Ark (Huoshan Fangzhou) API.
+    See: https://www.volcengine.com/docs/82379/1263482
+    """
+
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+        """Stream a chat completion, log every delta and return the concatenated reply text."""
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
+            **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
+            stream=True,
+            # Usage is only returned at the end of a stream when this option is set.
+            extra_body={"stream_options": {"include_usage": True}},
+        )
+        usage = None
+        model = None  # stays None when the stream yields no chunks at all
+        collected_messages = []
+        async for chunk in response:
+            chunk_message = (chunk.choices[0].delta.content or "") if chunk.choices else ""  # extract the message
+            log_llm_stream(chunk_message)
+            collected_messages.append(chunk_message)
+            model = chunk.model
+            chunk_usage = getattr(chunk, "usage", None)
+            if chunk_usage:
+                # Ark returns usage in the last chunk of a stream; that chunk's choices is [].
+                # usage may arrive as a plain dict or as an already-built CompletionUsage; only a dict is unpacked.
+                usage = CompletionUsage(**chunk_usage) if isinstance(chunk_usage, dict) else chunk_usage
+
+        log_llm_stream("\n")
+        full_reply_content = "".join(collected_messages)
+        self._update_costs(usage, model)
+        return full_reply_content
+
+    async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
+        """Request a non-streaming chat completion, record its cost and return the raw response."""
+        kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
+        self._update_costs(rsp.usage, rsp.model)
+        return rsp
diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py
index e813127a4..ef6f886e2 100644
--- a/metagpt/utils/token_counter.py
+++ b/metagpt/utils/token_counter.py
@@ -68,6 +68,15 @@ TOKEN_COSTS = {
     "openai/gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
     "deepseek-chat": {"prompt": 0.00014, "completion": 0.00028},
     "deepseek-coder": {"prompt": 0.00014, "completion": 0.00028},
+    # For ark model https://www.volcengine.com/docs/82379/1099320
+    "doubao-lite-4k-240515": {"prompt": 0.000042, "completion": 0.000084},
+    "doubao-lite-32k-240515": {"prompt": 0.000042, "completion": 0.000084},
+    "doubao-lite-128k-240515": {"prompt": 0.00011, "completion": 0.00013},
+    "doubao-pro-4k-240515": {"prompt": 0.00011, "completion": 0.00028},
+    "doubao-pro-32k-240515": {"prompt": 0.00011, "completion": 0.00028},
+    "doubao-pro-128k-240515": {"prompt": 0.0007, "completion": 0.0012},
+    "llama3-70b-llama3-70b-instruct": {"prompt": 0.0, "completion": 0.0},
+    "llama3-8b-llama3-8b-instruct": {"prompt": 0.0, "completion": 0.0},
 }
 
 
@@ -209,6 +218,12 @@ TOKEN_MAX = {
     "openai/gpt-4-turbo-preview": 128000,
     "deepseek-chat": 32768,
     "deepseek-coder": 16385,
+    "doubao-lite-4k-240515": 4000,
+    "doubao-lite-32k-240515": 32000,
+    "doubao-lite-128k-240515": 128000,
+    "doubao-pro-4k-240515": 4000,
+    "doubao-pro-32k-240515": 32000,
+    "doubao-pro-128k-240515": 128000,
 }
 
 # For Amazon Bedrock US region
diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py
index 8f2baea10..f563dccad 100644
--- a/tests/metagpt/provider/mock_llm_config.py
+++ b/tests/metagpt/provider/mock_llm_config.py
@@ -69,3 +69,5 @@ mock_llm_config_bedrock = LLMConfig(
     secret_key="123abc",
     max_token=10000,
 )
+
+mock_llm_config_ark = LLMConfig(api_type="ark", api_key="eyxxx", base_url="xxx", model="ep-xxx")
diff --git a/tests/metagpt/provider/test_ark.py b/tests/metagpt/provider/test_ark.py
new file mode 100644
index 000000000..c3fb25846
--- /dev/null
+++ b/tests/metagpt/provider/test_ark.py
@@ -0,0 +1,85 @@
+"""
+Test cases for the Volcengine Ark (Huoshan Fangzhou) Python SDK V3.
+API docs: https://www.volcengine.com/docs/82379/1263482
+"""
+
+from typing import AsyncIterator, List, Union
+
+import pytest
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+
+from metagpt.provider.ark_api import ArkLLM
+from tests.metagpt.provider.mock_llm_config import mock_llm_config_ark
+from tests.metagpt.provider.req_resp_const import (
+    get_openai_chat_completion,
+    llm_general_chat_funcs_test,
+    messages,
+    prompt,
+    resp_cont_tmpl,
+)
+
+name = "AI assistant"
+resp_cont = resp_cont_tmpl.format(name=name)
+USAGE = {"completion_tokens": 1000, "prompt_tokens": 1000, "total_tokens": 2000}
+default_resp = get_openai_chat_completion(name)
+default_resp.model = "doubao-pro-32k-240515"
+default_resp.usage = USAGE
+
+
+def create_chat_completion_chunk(
+    content: str, finish_reason: str = None, choices: List[Choice] = None
+) -> ChatCompletionChunk:
+    if choices is None:
+        choices = [
+            Choice(
+                delta=ChoiceDelta(content=content, function_call=None, role="assistant", tool_calls=None),
+                finish_reason=finish_reason,
+                index=0,
+                logprobs=None,
+            )
+        ]
+
+    return ChatCompletionChunk(
+        id="012",
+        choices=choices,
+        created=1716278586,
+        model="doubao-pro-32k-240515",
+        object="chat.completion.chunk",
+        system_fingerprint=None,
+        usage=None if choices else USAGE,
+    )
+
+
+ark_resp_chunk = create_chat_completion_chunk(content="")
+ark_resp_chunk_finish = create_chat_completion_chunk(content=resp_cont, finish_reason="stop")
+ark_resp_chunk_last = create_chat_completion_chunk(content="", choices=[])
+
+
+async def chunk_iterator(chunks: List[ChatCompletionChunk]) -> AsyncIterator[ChatCompletionChunk]:
+    for chunk in chunks:
+        yield chunk
+
+
+async def mock_ark_acompletions_create(
+    self, stream: bool = False, **kwargs
+) -> Union[ChatCompletionChunk, ChatCompletion]:
+    if stream:
+        chunks = [ark_resp_chunk, ark_resp_chunk_finish, ark_resp_chunk_last]
+        return chunk_iterator(chunks)
+    else:
+        return default_resp
+
+
+@pytest.mark.asyncio
+async def test_ark_acompletion(mocker):
+    mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_ark_acompletions_create)
+
+    llm = ArkLLM(mock_llm_config_ark)
+
+    resp = await llm.acompletion(messages)
+    assert resp.choices[0].finish_reason == "stop"
+    assert resp.choices[0].message.content == resp_cont
+    assert resp.usage == USAGE
+
+    await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)