diff --git a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py index 3b2aced5..f87cf917 100755 --- a/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py +++ b/trustgraph-bedrock/trustgraph/model/text_completion/bedrock/llm.py @@ -5,6 +5,7 @@ Input is prompt, output is response. Mistral is default. """ import boto3 +from botocore.errorfactory import ThrottlingException import json from prometheus_client import Histogram import os @@ -24,32 +25,48 @@ default_subscriber = module default_model = 'mistral.mistral-large-2407-v1:0' default_temperature = 0.0 default_max_output = 2048 -default_aws_id_key = os.getenv("AWS_ID_KEY", None) -default_aws_secret = os.getenv("AWS_SECRET", None) -default_aws_region = os.getenv("AWS_REGION", 'us-west-2') + +# Actually, these could all just be None, no need to get environment +# variables, as Boto3 would pick all these up if not passed in as args +default_access_key_id = os.getenv("AWS_ACCESS_KEY_ID", None) +default_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", None) +default_session_token = os.getenv("AWS_SESSION_TOKEN", None) +default_profile = os.getenv("AWS_PROFILE", None) +default_region = os.getenv("AWS_DEFAULT_REGION", None) class Processor(ConsumerProducer): def __init__(self, **params): + + print(params) input_queue = params.get("input_queue", default_input_queue) output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) + model = params.get("model", default_model) - aws_id_key = params.get("aws_id_key", default_aws_id_key) - aws_secret = params.get("aws_secret", default_aws_secret) - aws_region = params.get("aws_region", default_aws_region) temperature = params.get("temperature", default_temperature) max_output = params.get("max_output", default_max_output) - if aws_id_key is None: - raise RuntimeError("AWS ID not specified") + aws_access_key_id = params.get( + "aws_access_key_id", default_access_key_id + ) - if aws_secret is None: - raise RuntimeError("AWS secret not specified") + aws_secret_access_key = params.get( + "aws_secret_access_key", default_secret_access_key + ) - if aws_region is None: - raise RuntimeError("AWS region not specified") + aws_session_token = params.get( + "aws_session_token", default_session_token + ) + + aws_region = params.get( + "aws_region", default_region + ) + + aws_profile = params.get( + "aws_profile", default_profile + ) super(Processor, self).__init__( **params | { @@ -82,9 +99,11 @@ class Processor(ConsumerProducer): self.max_output = max_output self.session = boto3.Session( - aws_access_key_id=aws_id_key, - aws_secret_access_key=aws_secret, - region_name=aws_region + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + profile_name=aws_profile, + region_name=aws_region, ) self.bedrock = self.session.client(service_name='bedrock-runtime') @@ -179,9 +198,6 @@ class Processor(ConsumerProducer): accept = 'application/json' contentType = 'application/json' - # FIXME: Consider catching request limits and raise TooManyRequests - # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html - with __class__.text_completion_metric.time(): response = self.bedrock.invoke_model( body=promptbody, modelId=self.model, accept=accept, @@ -243,10 +259,7 @@ class Processor(ConsumerProducer): print("Done.", flush=True) - - # FIXME: Wrong exception, don't know what Bedrock throws - # for a rate limit - except TooManyRequests: + except ThrottlingException: print("Send rate limit response...", flush=True) @@ -300,21 +313,27 @@ class Processor(ConsumerProducer): ) parser.add_argument( - '-z', '--aws-id-key', - default=default_aws_id_key, - help=f'AWS ID Key' + '-z', '--aws-access-key-id', + default=default_access_key_id, + help=f'AWS access key ID' ) parser.add_argument( - '-k', '--aws-secret', - default=default_aws_secret, - help=f'AWS Secret Key' + '-k', '--aws-secret-access-key', + default=default_secret_access_key, + help=f'AWS secret access key' ) parser.add_argument( '-r', '--aws-region', - default=default_aws_region, - help=f'AWS Region' + default=default_region, + help=f'AWS region' + ) + + parser.add_argument( + '--aws-profile', '--profile', + default=default_profile, + help=f'AWS profile name' ) parser.add_argument(