Refactored Ai21Provider to add support for Jamba-Instruct

This commit is contained in:
JGalego 2024-08-17 01:38:17 +01:00
parent 9f9e405e86
commit d65e5556dd

View file

@ -95,10 +95,32 @@ class MetaProvider(BaseBedrockProvider):
class Ai21Provider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
max_tokens_field_name = "maxTokens"
def __init__(self, model_type: Literal["j2", "jamba"]) -> None:
self.model_type = model_type
if self.model_type == "j2":
self.max_tokens_field_name = "maxTokens"
else:
self.max_tokens_field_name = "max_tokens"
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
if self.model_type == "j2":
body = super().get_request_body(messages, generate_kwargs, *args, **kwargs)
else:
body = json.dumps(
{
"messages": messages,
**generate_kwargs,
}
)
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["completions"][0]["data"]["text"]
if self.model_type == "j2":
# See https://docs.ai21.com/reference/j2-complete-ref
return rsp_dict["completions"][0]["data"]["text"]
else:
# See https://docs.ai21.com/reference/jamba-instruct-api
return rsp_dict["choices"][0]["message"]["content"]
class AmazonProvider(BaseBedrockProvider):
@ -136,4 +158,7 @@ def get_provider(model_id: str):
if provider == "meta":
# distinguish llama2 and llama3
return PROVIDERS[provider](model_name[:6])
elif provider == "ai21":
# distinguish between j2 and jamba
return PROVIDERS[provider](model_name.split("-")[0])
return PROVIDERS[provider]()