Add files via upload

This commit is contained in:
YangQianli92 2024-02-29 10:15:09 +08:00 committed by GitHub
parent fbd5d65e6b
commit 8f267b33d2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 36 additions and 23 deletions

View file

@ -21,11 +21,15 @@ class LLMProviderRegistry:
return self.providers[enum]
def register_provider(keys):
    """Build a class decorator that registers the decorated class in ``LLM_REGISTRY``.

    Args:
        keys: Either a single registry key, or a list/tuple/set/frozenset of keys that
            should all resolve to the decorated class.

    Returns:
        A decorator that registers the class under every given key and returns the class
        unchanged, so it can be stacked with other decorators.
    """

    def decorator(cls):
        # Wrap a lone key in a tuple so both call styles share one registration loop.
        key_group = keys if isinstance(keys, (list, tuple, set, frozenset)) else (keys,)
        for key in key_group:
            LLM_REGISTRY.register(key, cls)
        return cls

    return decorator

View file

@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs, TokenCostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -50,7 +50,7 @@ See FAQ 5.8
raise retry_state.outcome.exception()
@register_provider(LLMType.OPENAI)
@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT])
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""
@ -84,14 +84,33 @@ class OpenAILLM(BaseLLM):
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
    """Stream a chat completion, log every piece as it arrives, and return the whole reply.

    Args:
        messages: Chat messages forwarded to ``self._cons_kwargs`` to build the request.
        timeout: Timeout value forwarded to ``self._cons_kwargs``.

    Returns:
        The concatenation of all streamed content pieces.

    Side effects:
        Each piece is passed to ``log_llm_stream`` and the resulting usage is handed to
        ``self._update_costs``.
    """
    response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
        **self._cons_kwargs(messages, timeout=timeout), stream=True
    )
    usage = None
    collected_messages = []
    async for chunk in response:
        # A chunk may carry an empty `choices` list; treat it as having no text and no finish reason
        # instead of indexing into it.
        choice = chunk.choices[0] if chunk.choices else None
        chunk_message = (choice.delta.content or "") if choice is not None else ""  # extract the message
        finish_reason = getattr(choice, "finish_reason", None)
        log_llm_stream(chunk_message)
        collected_messages.append(chunk_message)
        if not finish_reason:
            continue

        # Some services have usage as an attribute of the chunk, such as Fireworks.
        reported_usage = getattr(chunk, "usage", None)
        if reported_usage is None:
            # The usage of some services is an attribute of chunk.choices[0], such as Moonshot.
            reported_usage = getattr(choice, "usage", None)
        if reported_usage is not None:
            # Accept both an already-built CompletionUsage and a plain mapping of its fields.
            if isinstance(reported_usage, CompletionUsage):
                usage = reported_usage
            else:
                usage = CompletionUsage(**reported_usage)

    log_llm_stream("\n")
    full_reply_content = "".join(collected_messages)
    if not usage:
        # Some services do not provide the usage attribute, such as OpenAI or OpenLLM
        usage = self._calc_usage(messages, full_reply_content)

    self._update_costs(usage)
    return full_reply_content
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
kwargs = {
@ -99,7 +118,7 @@ class OpenAILLM(BaseLLM):
"max_tokens": self._get_max_tokens(messages),
"n": 1,
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": self.config.temperature,
"temperature": 0.3,
"model": self.model,
"timeout": max(self.config.timeout, timeout),
}
@ -126,18 +145,7 @@ class OpenAILLM(BaseLLM):
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
    """Return the completion text for ``messages``; when streaming, print each token in place.

    Args:
        messages: Chat messages to send.
        stream: When true, delegate to ``self._achat_completion_stream`` and return its text.
        timeout: Timeout value forwarded to the underlying completion call.

    Returns:
        The reply text, taken from the streamed reply when ``stream`` is true and otherwise
        extracted from the non-streaming response via ``self.get_choice_text``.
    """
    if stream:
        # Return here: falling through would issue a second, non-streaming request and
        # discard the reply that was just streamed.
        return await self._achat_completion_stream(messages, timeout=timeout)

    rsp = await self._achat_completion(messages, timeout=timeout)
    return self.get_choice_text(rsp)
@ -261,11 +269,12 @@ class OpenAILLM(BaseLLM):
if not self.config.calc_usage:
return usage
model = self.model if not isinstance(self.cost_manager, TokenCostManager) else "open-llm-model"
try:
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
usage.prompt_tokens = count_message_tokens(messages, model)
usage.completion_tokens = count_string_tokens(rsp, model)
except Exception as e:
logger.warning(f"usage calculation failed: {e}")
logger.error(f"usage calculation failed: {e}")
return usage