mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 02:23:52 +02:00
Refactored Ai21Provider to add support for Jamba-Instruct
This commit is contained in:
parent
9f9e405e86
commit
d65e5556dd
1 changed files with 27 additions and 2 deletions
|
|
@ -95,10 +95,32 @@ class MetaProvider(BaseBedrockProvider):
|
|||
class Ai21Provider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
|
||||
|
||||
max_tokens_field_name = "maxTokens"
|
||||
def __init__(self, model_type: Literal["j2", "jamba"]) -> None:
|
||||
self.model_type = model_type
|
||||
if self.model_type == "j2":
|
||||
self.max_tokens_field_name = "maxTokens"
|
||||
else:
|
||||
self.max_tokens_field_name = "max_tokens"
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
|
||||
if self.model_type == "j2":
|
||||
body = super().get_request_body(messages, generate_kwargs, *args, **kwargs)
|
||||
else:
|
||||
body = json.dumps(
|
||||
{
|
||||
"messages": messages,
|
||||
**generate_kwargs,
|
||||
}
|
||||
)
|
||||
return body
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["completions"][0]["data"]["text"]
|
||||
if self.model_type == "j2":
|
||||
# See https://docs.ai21.com/reference/j2-complete-ref
|
||||
return rsp_dict["completions"][0]["data"]["text"]
|
||||
else:
|
||||
# See https://docs.ai21.com/reference/jamba-instruct-api
|
||||
return rsp_dict["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
class AmazonProvider(BaseBedrockProvider):
|
||||
|
|
@ -136,4 +158,7 @@ def get_provider(model_id: str):
|
|||
if provider == "meta":
|
||||
# distinguish llama2 and llama3
|
||||
return PROVIDERS[provider](model_name[:6])
|
||||
elif provider == "ai21":
|
||||
# distinguish between j2 and jamba
|
||||
return PROVIDERS[provider](model_name.split("-")[0])
|
||||
return PROVIDERS[provider]()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue