mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge pull request #1450 from JGalego/feat/bedrock-update
feat(bedrock): Temporary AWS credentials via env vars + supported models update
This commit is contained in:
commit
8f34c746a7
5 changed files with 178 additions and 60 deletions
|
|
@ -57,6 +57,7 @@ class LLMConfig(YamlModel):
|
|||
# For Cloud Service Provider like Baidu/ Alibaba
|
||||
access_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
session_token: Optional[str] = None
|
||||
endpoint: Optional[str] = None # for self-deployed model on the cloud
|
||||
|
||||
# For Spark(Xunfei), maybe remove later
|
||||
|
|
|
|||
|
|
@ -57,15 +57,52 @@ class AnthropicProvider(BaseBedrockProvider):
|
|||
|
||||
|
||||
class CohereProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
# For more information, see
|
||||
# (Command) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
# (Command R/R+) https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||
|
||||
def __init__(self, model_name: str) -> None:
|
||||
self.model_name = model_name
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["generations"][0]["text"]
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
if "command-r" in self.model_name:
|
||||
role_map = {
|
||||
"user": "USER",
|
||||
"assistant": "CHATBOT",
|
||||
"system": "USER"
|
||||
}
|
||||
messages = list(
|
||||
map(
|
||||
lambda message: {
|
||||
"role": role_map[message["role"]],
|
||||
"message": message["content"]
|
||||
},
|
||||
messages
|
||||
)
|
||||
)
|
||||
return messages
|
||||
else:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
|
||||
)
|
||||
prompt = self.messages_to_prompt(messages)
|
||||
if "command-r" in self.model_name:
|
||||
chat_history, message = prompt[:-1], prompt[-1]["message"]
|
||||
body = json.dumps({
|
||||
"message": message,
|
||||
"chat_history": chat_history,
|
||||
**generate_kwargs
|
||||
})
|
||||
else:
|
||||
body = json.dumps({
|
||||
"prompt": prompt,
|
||||
"stream": kwargs.get("stream", False),
|
||||
**generate_kwargs
|
||||
})
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
|
|
@ -95,10 +132,37 @@ class MetaProvider(BaseBedrockProvider):
|
|||
class Ai21Provider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
|
||||
|
||||
max_tokens_field_name = "maxTokens"
|
||||
def __init__(self, model_type: Literal["j2", "jamba"]) -> None:
|
||||
self.model_type = model_type
|
||||
if self.model_type == "j2":
|
||||
self.max_tokens_field_name = "maxTokens"
|
||||
else:
|
||||
self.max_tokens_field_name = "max_tokens"
|
||||
|
||||
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
|
||||
if self.model_type == "j2":
|
||||
body = super().get_request_body(messages, generate_kwargs, *args, **kwargs)
|
||||
else:
|
||||
body = json.dumps(
|
||||
{
|
||||
"messages": messages,
|
||||
**generate_kwargs,
|
||||
}
|
||||
)
|
||||
return body
|
||||
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
||||
return completions
|
||||
|
||||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict["completions"][0]["data"]["text"]
|
||||
if self.model_type == "j2":
|
||||
# See https://docs.ai21.com/reference/j2-complete-ref
|
||||
return rsp_dict["completions"][0]["data"]["text"]
|
||||
else:
|
||||
# See https://docs.ai21.com/reference/jamba-instruct-api
|
||||
return rsp_dict["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
class AmazonProvider(BaseBedrockProvider):
|
||||
|
|
@ -136,4 +200,10 @@ def get_provider(model_id: str):
|
|||
if provider == "meta":
|
||||
# distinguish llama2 and llama3
|
||||
return PROVIDERS[provider](model_name[:6])
|
||||
elif provider == "ai21":
|
||||
# distinguish between j2 and jamba
|
||||
return PROVIDERS[provider](model_name.split("-")[0])
|
||||
elif provider == "cohere":
|
||||
# distinguish between R/R+ and older models
|
||||
return PROVIDERS[provider](model_name)
|
||||
return PROVIDERS[provider]()
|
||||
|
|
|
|||
|
|
@ -1,52 +1,97 @@
|
|||
from metagpt.logs import logger
|
||||
|
||||
# max_tokens for each model
|
||||
NOT_SUUPORT_STREAM_MODELS = {
|
||||
"ai21.j2-grande-instruct": 8000,
|
||||
"ai21.j2-jumbo-instruct": 8000,
|
||||
"ai21.j2-mid": 8000,
|
||||
"ai21.j2-mid-v1": 8000,
|
||||
"ai21.j2-ultra": 8000,
|
||||
"ai21.j2-ultra-v1": 8000,
|
||||
NOT_SUPPORT_STREAM_MODELS = {
|
||||
# Jurassic-2 Mid-v1 and Ultra-v1
|
||||
# + Legacy date: 2024-04-30 (us-west-2/Oregon)
|
||||
# + EOL date: 2024-08-31 (us-west-2/Oregon)
|
||||
"ai21.j2-mid-v1": 8191,
|
||||
"ai21.j2-ultra-v1": 8191,
|
||||
}
|
||||
|
||||
SUPPORT_STREAM_MODELS = {
|
||||
"amazon.titan-tg1-large": 8000,
|
||||
"amazon.titan-text-express-v1": 8000,
|
||||
"amazon.titan-text-express-v1:0:8k": 8000,
|
||||
"amazon.titan-text-lite-v1:0:4k": 4000,
|
||||
"amazon.titan-text-lite-v1": 4000,
|
||||
"anthropic.claude-instant-v1": 100000,
|
||||
"anthropic.claude-instant-v1:2:100k": 100000,
|
||||
"anthropic.claude-v1": 100000,
|
||||
"anthropic.claude-v2": 100000,
|
||||
"anthropic.claude-v2:1": 200000,
|
||||
"anthropic.claude-v2:0:18k": 18000,
|
||||
"anthropic.claude-v2:1:200k": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
|
||||
# currently (2024-4-29) only available at US West (Oregon) AWS Region.
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 200000,
|
||||
"cohere.command-text-v14": 4000,
|
||||
"cohere.command-text-v14:7:4k": 4000,
|
||||
"cohere.command-light-text-v14": 4000,
|
||||
"cohere.command-light-text-v14:7:4k": 4000,
|
||||
"meta.llama2-13b-chat-v1:0:4k": 4000,
|
||||
"meta.llama2-13b-chat-v1": 2000,
|
||||
"meta.llama2-70b-v1": 4000,
|
||||
"meta.llama2-70b-v1:0:4k": 4000,
|
||||
"meta.llama2-70b-chat-v1": 2000,
|
||||
"meta.llama2-70b-chat-v1:0:4k": 2000,
|
||||
"meta.llama3-8b-instruct-v1:0": 2000,
|
||||
"meta.llama3-70b-instruct-v1:0": 2000,
|
||||
"mistral.mistral-7b-instruct-v0:2": 32000,
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
|
||||
"mistral.mistral-large-2402-v1:0": 32000,
|
||||
# Jamba-Instruct
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jamba.html
|
||||
"ai21.jamba-instruct-v1:0": 4096,
|
||||
# Titan Text G1 - Lite
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
|
||||
"amazon.titan-text-lite-v1:0:4k": 4096,
|
||||
"amazon.titan-text-lite-v1": 4096,
|
||||
# Titan Text G1 - Express
|
||||
"amazon.titan-text-express-v1": 8192,
|
||||
"amazon.titan-text-express-v1:0:8k": 8192,
|
||||
# Titan Text Premier
|
||||
"amazon.titan-text-premier-v1:0": 3072,
|
||||
"amazon.titan-text-premier-v1:0:32k": 3072,
|
||||
# Claude Instant v1
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-text-completion.html
|
||||
# https://docs.anthropic.com/en/docs/about-claude/models#model-comparison
|
||||
"anthropic.claude-instant-v1": 4096,
|
||||
"anthropic.claude-instant-v1:2:100k": 4096,
|
||||
# Claude v2
|
||||
"anthropic.claude-v2": 4096,
|
||||
"anthropic.claude-v2:0:18k": 4096,
|
||||
"anthropic.claude-v2:0:100k": 4096,
|
||||
# Claude v2.1
|
||||
"anthropic.claude-v2:1": 4096,
|
||||
"anthropic.claude-v2:1:18k": 4096,
|
||||
"anthropic.claude-v2:1:200k": 4096,
|
||||
# Claude 3 Sonnet
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": 4096,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 4096,
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 4096,
|
||||
# Claude 3 Haiku
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": 4096,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:48k": 4096,
|
||||
"anthropic.claude-3-haiku-20240307-v1:0:200k": 4096,
|
||||
# Claude 3 Opus
|
||||
"anthropic.claude-3-opus-20240229-v1:0": 4096,
|
||||
# Claude 3.5 Sonnet
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
|
||||
# Command Text
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
|
||||
"cohere.command-text-v14": 4096,
|
||||
"cohere.command-text-v14:7:4k": 4096,
|
||||
# Command Light Text
|
||||
"cohere.command-light-text-v14": 4096,
|
||||
"cohere.command-light-text-v14:7:4k": 4096,
|
||||
# Command R
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html
|
||||
"cohere.command-r-v1:0": 4096,
|
||||
# Command R+
|
||||
"cohere.command-r-plus-v1:0": 4096,
|
||||
# Llama 2 (--> Llama 3/3.1/3.2) !!!
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
|
||||
# + Legacy: 2024-05-12
|
||||
# + EOL: 2024-10-30
|
||||
# "meta.llama2-13b-chat-v1": 2048,
|
||||
# "meta.llama2-13b-chat-v1:0:4k": 2048,
|
||||
# "meta.llama2-70b-v1": 2048,
|
||||
# "meta.llama2-70b-v1:0:4k": 2048,
|
||||
# "meta.llama2-70b-chat-v1": 2048,
|
||||
# "meta.llama2-70b-chat-v1:0:4k": 2048,
|
||||
# Llama 3 Instruct
|
||||
# "meta.llama3-8b-instruct-v1:0": 2048,
|
||||
"meta.llama3-70b-instruct-v1:0": 2048,
|
||||
# Llama 3.1 Instruct
|
||||
# "meta.llama3-1-8b-instruct-v1:0": 2048,
|
||||
"meta.llama3-1-70b-instruct-v1:0": 2048,
|
||||
"meta.llama3-1-405b-instruct-v1:0": 2048,
|
||||
# Llama 3.2 Instruct
|
||||
# "meta.llama3-2-3b-instruct-v1:0": 2048,
|
||||
# "meta.llama3-2-11b-instruct-v1:0": 2048,
|
||||
"meta.llama3-2-90b-instruct-v1:0": 2048,
|
||||
# Mistral 7B Instruct
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral-text-completion.html
|
||||
# "mistral.mistral-7b-instruct-v0:2": 8192,
|
||||
# Mixtral 8x7B Instruct
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": 4096,
|
||||
# Mistral Small
|
||||
"mistral.mistral-small-2402-v1:0": 8192,
|
||||
# Mistral Large (24.02)
|
||||
"mistral.mistral-large-2402-v1:0": 8192,
|
||||
# Mistral Large 2 (24.07)
|
||||
"mistral.mistral-large-2407-v1:0": 8192
|
||||
}
|
||||
|
||||
# TODO:use a more general function for constructing chat templates.
|
||||
|
|
@ -106,7 +151,7 @@ def messages_to_prompt_claude2(messages: list[dict]) -> str:
|
|||
|
||||
def get_max_tokens(model_id: str) -> int:
|
||||
try:
|
||||
max_tokens = (NOT_SUUPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
|
||||
except KeyError:
|
||||
logger.warning(f"Couldn't find model:{model_id} , max tokens has been set to 2048")
|
||||
max_tokens = 2048
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
|
|
@ -11,7 +12,7 @@ from metagpt.const import USE_CONFIG_TIMEOUT
|
|||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.bedrock.bedrock_provider import get_provider
|
||||
from metagpt.provider.bedrock.utils import NOT_SUUPORT_STREAM_MODELS, get_max_tokens
|
||||
from metagpt.provider.bedrock.utils import NOT_SUPPORT_STREAM_MODELS, get_max_tokens
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS
|
||||
|
|
@ -24,18 +25,19 @@ class BedrockLLM(BaseLLM):
|
|||
self.__client = self.__init_client("bedrock-runtime")
|
||||
self.__provider = get_provider(self.config.model)
|
||||
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
logger.warning(f"model {self.config.model} doesn't support streaming output!")
|
||||
|
||||
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
|
||||
"""initialize boto3 client"""
|
||||
# access key and secret key from https://us-east-1.console.aws.amazon.com/iam
|
||||
self.__credentital_kwargs = {
|
||||
"aws_secret_access_key": self.config.secret_key,
|
||||
"aws_access_key_id": self.config.access_key,
|
||||
"region_name": self.config.region_name,
|
||||
self.__credential_kwargs = {
|
||||
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY", self.config.secret_key),
|
||||
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID", self.config.access_key),
|
||||
"aws_session_token": os.environ.get("AWS_SESSION_TOKEN", self.config.session_token),
|
||||
"region_name": os.environ.get("AWS_DEFAULT_REGION", self.config.region_name),
|
||||
}
|
||||
session = boto3.Session(**self.__credentital_kwargs)
|
||||
session = boto3.Session(**self.__credential_kwargs)
|
||||
client = session.client(service_name)
|
||||
return client
|
||||
|
||||
|
|
@ -111,7 +113,7 @@ class BedrockLLM(BaseLLM):
|
|||
return await self.acompletion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
|
||||
rsp = await self.acompletion(messages)
|
||||
full_text = self.get_choice_text(rsp)
|
||||
log_llm_stream(full_text)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import json
|
|||
import pytest
|
||||
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
NOT_SUUPORT_STREAM_MODELS,
|
||||
NOT_SUPPORT_STREAM_MODELS,
|
||||
SUPPORT_STREAM_MODELS,
|
||||
)
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
|
|
@ -14,7 +14,7 @@ from tests.metagpt.provider.req_resp_const import (
|
|||
)
|
||||
|
||||
# all available model from bedrock
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUPPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
usage = {
|
||||
"prompt_tokens": 1000000,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue