pre-commit

This commit is contained in:
usamimeri_renko 2024-05-20 15:16:40 +08:00
parent 7426ebc25a
commit c220da280d
2 changed files with 10 additions and 9 deletions

View file

@ -33,5 +33,5 @@ __all__ = [
"DashScopeLLM",
"AnthropicLLM",
"BedrockLLM",
"ArkLLM"
"ArkLLM",
]

View file

@ -2,11 +2,12 @@ from openai import AsyncStream
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from metagpt.provider.openai_api import OpenAILLM
from metagpt.configs.llm_config import LLMType
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
@register_provider(LLMType.ARK)
class ArkLLM(OpenAILLM):
@ -14,10 +15,10 @@ class ArkLLM(OpenAILLM):
用于火山方舟的API
https://www.volcengine.com/docs/82379/1263482
"""
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
**self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
stream=True,
extra_body={"stream_options": {"include_usage": True}}
)
@ -29,15 +30,15 @@ class ArkLLM(OpenAILLM):
collected_messages.append(chunk_message)
if chunk.usage:
# the usage of ark when streaming is in the last chunk while others are null
usage=CompletionUsage(**chunk.usage)
usage = CompletionUsage(**chunk.usage)
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
self._update_costs(usage,chunk.model)
self._update_costs(usage, chunk.model)
return full_reply_content
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=self.get_timeout(timeout))
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage,rsp.model)
self._update_costs(rsp.usage, rsp.model)
return rsp