From d65e5556dd6f9a6eec5f4e67ee5c0d59112bd22d Mon Sep 17 00:00:00 2001
From: JGalego
Date: Sat, 17 Aug 2024 01:38:17 +0100
Subject: [PATCH] Refactored Ai21Provider to add support for Jamba-Instruct

---
Review notes (text in this section is not part of the commit message):
  * Ai21Provider now takes a model_type of "j2" or "jamba" and sets
    max_tokens_field_name per instance: "maxTokens" for "j2",
    "max_tokens" otherwise.
  * For "j2" the request body is produced by the base class; otherwise it
    is a JSON object holding "messages" plus the generate_kwargs entries.
  * For "j2" the completion is read from completions[0].data.text;
    otherwise from choices[0].message.content.
  * get_provider passes the part of the model name before the first "-"
    as model_type when the provider is "ai21".
  * NOTE(review): every model_type other than "j2" takes the Jamba path,
    including unexpected values - confirm that this fallback is intended.
  * NOTE(review): the line structure of this patch was restored from a
    whitespace-collapsed copy. Indentation and the position of blank
    context lines were inferred from the hunk line counts; run
    `git apply --check` against the target tree before applying.

 metagpt/provider/bedrock/bedrock_provider.py | 29 ++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py
index 1236bf56b..6a7c91740 100644
--- a/metagpt/provider/bedrock/bedrock_provider.py
+++ b/metagpt/provider/bedrock/bedrock_provider.py
@@ -95,10 +95,32 @@ class MetaProvider(BaseBedrockProvider):
 class Ai21Provider(BaseBedrockProvider):
     # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
 
-    max_tokens_field_name = "maxTokens"
+    def __init__(self, model_type: Literal["j2", "jamba"]) -> None:
+        self.model_type = model_type
+        if self.model_type == "j2":
+            self.max_tokens_field_name = "maxTokens"
+        else:
+            self.max_tokens_field_name = "max_tokens"
+
+    def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
+        if self.model_type == "j2":
+            body = super().get_request_body(messages, generate_kwargs, *args, **kwargs)
+        else:
+            body = json.dumps(
+                {
+                    "messages": messages,
+                    **generate_kwargs,
+                }
+            )
+        return body
 
     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
-        return rsp_dict["completions"][0]["data"]["text"]
+        if self.model_type == "j2":
+            # See https://docs.ai21.com/reference/j2-complete-ref
+            return rsp_dict["completions"][0]["data"]["text"]
+        else:
+            # See https://docs.ai21.com/reference/jamba-instruct-api
+            return rsp_dict["choices"][0]["message"]["content"]
 
 
 class AmazonProvider(BaseBedrockProvider):
@@ -136,4 +158,7 @@ def get_provider(model_id: str):
     if provider == "meta":
         # distinguish llama2 and llama3
         return PROVIDERS[provider](model_name[:6])
+    elif provider == "ai21":
+        # distinguish between j2 and jamba
+        return PROVIDERS[provider](model_name.split("-")[0])
     return PROVIDERS[provider]()