lazy installation

This commit is contained in:
usamimeri_renko 2024-04-26 02:02:38 +08:00
parent 6561c7aa7e
commit cafe666bfd
2 changed files with 10 additions and 2 deletions

View file

@ -6,7 +6,11 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.bedrock.bedrock_provider import get_provider
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
import boto3
try:
import boto3
except ImportError:
raise ImportError(
"boto3 not found! please install it by `pip install boto3` first ")
@register_provider([LLMType.AMAZON_BEDROCK])
@ -97,7 +101,7 @@ class AmazonBedrockLLM(BaseLLM):
return full_text
# boto3 don't support support asynchronous calls.
# for asynchronous version of boto3,check out:
# for asynchronous version of boto3, check out:
# https://aioboto3.readthedocs.io/en/latest/usage.html
# However,aioboto3 doesn't support invoke model

View file

@ -16,6 +16,7 @@ class MistralProvider(BaseBedrockProvider):
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
def messages_to_prompt(self, messages: list[dict]) -> str:
return messages_to_prompt_claude(messages)
@ -37,6 +38,7 @@ class CohereProvider(BaseBedrockProvider):
class MetaProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
max_tokens_field_name = "max_gen_len"
def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None:
@ -54,6 +56,7 @@ class MetaProvider(BaseBedrockProvider):
class Ai21Provider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
max_tokens_field_name = "maxTokens"
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
@ -62,6 +65,7 @@ class Ai21Provider(BaseBedrockProvider):
class AmazonProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
max_tokens_field_name = "maxTokenCount"
def get_request_body(self, messages: list[dict], **generate_kwargs):