Merge pull request #1450 from JGalego/feat/bedrock-update

feat(bedrock): Temporary AWS credentials via env vars + supported models update
This commit is contained in:
Alexander Wu 2024-10-20 13:58:03 +08:00 committed by GitHub
commit 8f34c746a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 178 additions and 60 deletions

View file

@ -57,6 +57,7 @@ class LLMConfig(YamlModel):
# For Cloud Service Provider like Baidu/ Alibaba
access_key: Optional[str] = None
secret_key: Optional[str] = None
session_token: Optional[str] = None
endpoint: Optional[str] = None # for self-deployed model on the cloud
# For Spark(Xunfei), maybe remove later

View file

@ -57,15 +57,52 @@ class AnthropicProvider(BaseBedrockProvider):
class CohereProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
# For more information, see
# (Command) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
# (Command R/R+) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
def __init__(self, model_name: str) -> None:
self.model_name = model_name
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generations"][0]["text"]
def messages_to_prompt(self, messages: list[dict]) -> str:
if "command-r" in self.model_name:
role_map = {
"user": "USER",
"assistant": "CHATBOT",
"system": "USER"
}
messages = list(
map(
lambda message: {
"role": role_map[message["role"]],
"message": message["content"]
},
messages
)
)
return messages
else:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
)
prompt = self.messages_to_prompt(messages)
if "command-r" in self.model_name:
chat_history, message = prompt[:-1], prompt[-1]["message"]
body = json.dumps({
"message": message,
"chat_history": chat_history,
**generate_kwargs
})
else:
body = json.dumps({
"prompt": prompt,
"stream": kwargs.get("stream", False),
**generate_kwargs
})
return body
def get_choice_text_from_stream(self, event) -> str:
@ -95,10 +132,37 @@ class MetaProvider(BaseBedrockProvider):
class Ai21Provider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
max_tokens_field_name = "maxTokens"
def __init__(self, model_type: Literal["j2", "jamba"]) -> None:
self.model_type = model_type
if self.model_type == "j2":
self.max_tokens_field_name = "maxTokens"
else:
self.max_tokens_field_name = "max_tokens"
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
if self.model_type == "j2":
body = super().get_request_body(messages, generate_kwargs, *args, **kwargs)
else:
body = json.dumps(
{
"messages": messages,
**generate_kwargs,
}
)
return body
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
return completions
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["completions"][0]["data"]["text"]
if self.model_type == "j2":
# See https://docs.ai21.com/reference/j2-complete-ref
return rsp_dict["completions"][0]["data"]["text"]
else:
# See https://docs.ai21.com/reference/jamba-instruct-api
return rsp_dict["choices"][0]["message"]["content"]
class AmazonProvider(BaseBedrockProvider):
@ -136,4 +200,10 @@ def get_provider(model_id: str):
if provider == "meta":
# distinguish llama2 and llama3
return PROVIDERS[provider](model_name[:6])
elif provider == "ai21":
# distinguish between j2 and jamba
return PROVIDERS[provider](model_name.split("-")[0])
elif provider == "cohere":
# distinguish between R/R+ and older models
return PROVIDERS[provider](model_name)
return PROVIDERS[provider]()

View file

@ -1,52 +1,97 @@
from metagpt.logs import logger
# max_tokens for each model
NOT_SUUPORT_STREAM_MODELS = {
"ai21.j2-grande-instruct": 8000,
"ai21.j2-jumbo-instruct": 8000,
"ai21.j2-mid": 8000,
"ai21.j2-mid-v1": 8000,
"ai21.j2-ultra": 8000,
"ai21.j2-ultra-v1": 8000,
NOT_SUPPORT_STREAM_MODELS = {
# Jurassic-2 Mid-v1 and Ultra-v1
# + Legacy date: 2024-04-30 (us-west-2/Oregon)
# + EOL date: 2024-08-31 (us-west-2/Oregon)
"ai21.j2-mid-v1": 8191,
"ai21.j2-ultra-v1": 8191,
}
SUPPORT_STREAM_MODELS = {
"amazon.titan-tg1-large": 8000,
"amazon.titan-text-express-v1": 8000,
"amazon.titan-text-express-v1:0:8k": 8000,
"amazon.titan-text-lite-v1:0:4k": 4000,
"amazon.titan-text-lite-v1": 4000,
"anthropic.claude-instant-v1": 100000,
"anthropic.claude-instant-v1:2:100k": 100000,
"anthropic.claude-v1": 100000,
"anthropic.claude-v2": 100000,
"anthropic.claude-v2:1": 200000,
"anthropic.claude-v2:0:18k": 18000,
"anthropic.claude-v2:1:200k": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"cohere.command-text-v14": 4000,
"cohere.command-text-v14:7:4k": 4000,
"cohere.command-light-text-v14": 4000,
"cohere.command-light-text-v14:7:4k": 4000,
"meta.llama2-13b-chat-v1:0:4k": 4000,
"meta.llama2-13b-chat-v1": 2000,
"meta.llama2-70b-v1": 4000,
"meta.llama2-70b-v1:0:4k": 4000,
"meta.llama2-70b-chat-v1": 2000,
"meta.llama2-70b-chat-v1:0:4k": 2000,
"meta.llama3-8b-instruct-v1:0": 2000,
"meta.llama3-70b-instruct-v1:0": 2000,
"mistral.mistral-7b-instruct-v0:2": 32000,
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
"mistral.mistral-large-2402-v1:0": 32000,
# Jamba-Instruct
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jamba.html
"ai21.jamba-instruct-v1:0": 4096,
# Titan Text G1 - Lite
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
"amazon.titan-text-lite-v1:0:4k": 4096,
"amazon.titan-text-lite-v1": 4096,
# Titan Text G1 - Express
"amazon.titan-text-express-v1": 8192,
"amazon.titan-text-express-v1:0:8k": 8192,
# Titan Text Premier
"amazon.titan-text-premier-v1:0": 3072,
"amazon.titan-text-premier-v1:0:32k": 3072,
# Claude Instant v1
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
# https://docs.anthropic.com/en/docs/about-claude/models#model-comparison
"anthropic.claude-instant-v1": 4096,
"anthropic.claude-instant-v1:2:100k": 4096,
# Claude v2
"anthropic.claude-v2": 4096,
"anthropic.claude-v2:0:18k": 4096,
"anthropic.claude-v2:0:100k": 4096,
# Claude v2.1
"anthropic.claude-v2:1": 4096,
"anthropic.claude-v2:1:18k": 4096,
"anthropic.claude-v2:1:200k": 4096,
# Claude 3 Sonnet
"anthropic.claude-3-sonnet-20240229-v1:0": 4096,
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 4096,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 4096,
# Claude 3 Haiku
"anthropic.claude-3-haiku-20240307-v1:0": 4096,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 4096,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 4096,
# Claude 3 Opus
"anthropic.claude-3-opus-20240229-v1:0": 4096,
# Claude 3.5 Sonnet
"anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
# Command Text
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
"cohere.command-text-v14": 4096,
"cohere.command-text-v14:7:4k": 4096,
# Command Light Text
"cohere.command-light-text-v14": 4096,
"cohere.command-light-text-v14:7:4k": 4096,
# Command R
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
"cohere.command-r-v1:0": 4096,
# Command R+
"cohere.command-r-plus-v1:0": 4096,
# Llama 2 (--> Llama 3/3.1/3.2) !!!
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
# + Legacy: 2024-05-12
# + EOL: 2024-10-30
# "meta.llama2-13b-chat-v1": 2048,
# "meta.llama2-13b-chat-v1:0:4k": 2048,
# "meta.llama2-70b-v1": 2048,
# "meta.llama2-70b-v1:0:4k": 2048,
# "meta.llama2-70b-chat-v1": 2048,
# "meta.llama2-70b-chat-v1:0:4k": 2048,
# Llama 3 Instruct
# "meta.llama3-8b-instruct-v1:0": 2048,
"meta.llama3-70b-instruct-v1:0": 2048,
# Llama 3.1 Instruct
# "meta.llama3-1-8b-instruct-v1:0": 2048,
"meta.llama3-1-70b-instruct-v1:0": 2048,
"meta.llama3-1-405b-instruct-v1:0": 2048,
# Llama 3.2 Instruct
# "meta.llama3-2-3b-instruct-v1:0": 2048,
# "meta.llama3-2-11b-instruct-v1:0": 2048,
"meta.llama3-2-90b-instruct-v1:0": 2048,
# Mistral 7B Instruct
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html
# "mistral.mistral-7b-instruct-v0:2": 8192,
# Mixtral 8x7B Instruct
"mistral.mixtral-8x7b-instruct-v0:1": 4096,
# Mistral Small
"mistral.mistral-small-2402-v1:0": 8192,
# Mistral Large (24.02)
"mistral.mistral-large-2402-v1:0": 8192,
# Mistral Large 2 (24.07)
"mistral.mistral-large-2407-v1:0": 8192
}
# TODO:use a more general function for constructing chat templates.
@ -106,7 +151,7 @@ def messages_to_prompt_claude2(messages: list[dict]) -> str:
def get_max_tokens(model_id: str) -> int:
try:
max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
except KeyError:
logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048")
max_tokens = 2048

View file

@ -1,3 +1,4 @@
import os
import asyncio
import json
from functools import partial
@ -11,7 +12,7 @@ from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.bedrock.bedrock_provider import get_provider
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
from metagpt.provider.bedrock.utils import NOT_SUPPORT_STREAM_MODELS, get_max_tokens
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS
@ -24,18 +25,19 @@ class BedrockLLM(BaseLLM):
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
"""initialize boto3 client"""
# access key and secret key from https://us-east-1.console.aws.amazon.com/iam
self.__credentital_kwargs = {
"aws_secret_access_key": self.config.secret_key,
"aws_access_key_id": self.config.access_key,
"region_name": self.config.region_name,
self.__credential_kwargs = {
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY", self.config.secret_key),
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID", self.config.access_key),
"aws_session_token": os.environ.get("AWS_SESSION_TOKEN", self.config.session_token),
"region_name": os.environ.get("AWS_DEFAULT_REGION", self.config.region_name),
}
session = boto3.Session(**self.__credentital_kwargs)
session = boto3.Session(**self.__credential_kwargs)
client = session.client(service_name)
return client
@ -111,7 +113,7 @@ class BedrockLLM(BaseLLM):
return await self.acompletion(messages)
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
rsp = await self.acompletion(messages)
full_text = self.get_choice_text(rsp)
log_llm_stream(full_text)

View file

@ -3,7 +3,7 @@ import json
import pytest
from metagpt.provider.bedrock.utils import (
NOT_SUUPORT_STREAM_MODELS,
NOT_SUPPORT_STREAM_MODELS,
SUPPORT_STREAM_MODELS,
)
from metagpt.provider.bedrock_api import BedrockLLM
@ -14,7 +14,7 @@ from tests.metagpt.provider.req_resp_const import (
)
# all available model from bedrock
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
models = SUPPORT_STREAM_MODELS | NOT_SUPPORT_STREAM_MODELS
messages = [{"role": "user", "content": "Hi!"}]
usage = {
"prompt_tokens": 1000000,