Merge pull request #1231 from usamimeri/main

Feat: add Amazon Bedrock support
Alexander Wu 2024-05-17 10:56:48 +08:00 committed by GitHub
commit ca67c8f9ea
13 changed files with 665 additions and 6 deletions

View file

@@ -1,6 +1,4 @@
 {
-  "executablePath": "/usr/bin/chromium",
-  "args": [
-    "--no-sandbox"
-  ]
+  "executablePath": "/usr/bin/chromium",
+  "args": ["--no-sandbox"]
 }

View file

@@ -32,6 +32,7 @@ class LLMType(Enum):
MISTRAL = "mistral"
YI = "yi" # lingyiwanwu
OPENROUTER = "openrouter"
BEDROCK = "bedrock"
def __missing__(self, key):
return self.OPENAI
@@ -74,10 +75,14 @@ class LLMConfig(YamlModel):
     best_of: Optional[int] = None
     n: Optional[int] = None
     stream: bool = False
-    logprobs: Optional[bool] = None  # https://cookbook.openai.com/examples/using_logprobs
+    # https://cookbook.openai.com/examples/using_logprobs
+    logprobs: Optional[bool] = None
     top_logprobs: Optional[int] = None
     timeout: int = 600
+    # For Amazon Bedrock
+    region_name: Optional[str] = None
+    # For Network
     proxy: Optional[str] = None
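For reference, a minimal sketch (not part of this diff) of how the new field might be used when constructing a Bedrock config; the model id and region are placeholders, and `access_key`/`secret_key` are the existing LLMConfig credential fields:

```python
from metagpt.configs.llm_config import LLMConfig

bedrock_config = LLMConfig(
    api_type="bedrock",
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    region_name="us-east-1",            # new Bedrock-specific field
    access_key="<AWS_ACCESS_KEY_ID>",   # from the AWS IAM console
    secret_key="<AWS_SECRET_ACCESS_KEY>",
)
```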

View file

@@ -17,6 +17,7 @@ from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
from metagpt.provider.anthropic_api import AnthropicLLM
from metagpt.provider.bedrock_api import BedrockLLM
__all__ = [
"GeminiLLM",
@@ -30,4 +31,5 @@ __all__ = [
"QianFanLLM",
"DashScopeLLM",
"AnthropicLLM",
"BedrockLLM",
]

View file

View file

@@ -0,0 +1,28 @@
import json
from abc import ABC, abstractmethod
class BaseBedrockProvider(ABC):
# to handle different generation kwargs
max_tokens_field_name = "max_tokens"
@abstractmethod
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
...
def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs) -> str:
body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
return body
def get_choice_text(self, response_body: dict) -> str:
completions = self._get_completion_from_dict(response_body)
return completions
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = self._get_completion_from_dict(rsp_dict)
return completions
def messages_to_prompt(self, messages: list[dict]) -> str:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])

View file

@@ -0,0 +1,121 @@
import json
from typing import Literal
from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
from metagpt.provider.bedrock.utils import (
messages_to_prompt_llama2,
messages_to_prompt_llama3,
)
class MistralProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
def messages_to_prompt(self, messages: list[dict]):
return messages_to_prompt_llama2(messages)
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["outputs"][0]["text"]
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps({"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["content"][0]["text"]
def get_choice_text_from_stream(self, event) -> str:
# https://docs.anthropic.com/claude/reference/messages-streaming
rsp_dict = json.loads(event["chunk"]["bytes"])
if rsp_dict["type"] == "content_block_delta":
completions = rsp_dict["delta"]["text"]
return completions
else:
return ""
class CohereProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generations"][0]["text"]
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs}
)
return body
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict.get("text", "")
return completions
class MetaProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
max_tokens_field_name = "max_gen_len"
def __init__(self, llama_version: Literal["llama2", "llama3"]) -> None:
self.llama_version = llama_version
def messages_to_prompt(self, messages: list[dict]):
if self.llama_version == "llama2":
return messages_to_prompt_llama2(messages)
else:
return messages_to_prompt_llama3(messages)
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generation"]
class Ai21Provider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-jurassic2.html
max_tokens_field_name = "maxTokens"
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["completions"][0]["data"]["text"]
class AmazonProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
max_tokens_field_name = "maxTokenCount"
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps({"inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs})
return body
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["results"][0]["outputText"]
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict["outputText"]
return completions
PROVIDERS = {
"mistral": MistralProvider,
"meta": MetaProvider,
"ai21": Ai21Provider,
"cohere": CohereProvider,
"anthropic": AnthropicProvider,
"amazon": AmazonProvider,
}
def get_provider(model_id: str):
provider, model_name = model_id.split(".")[0:2]  # e.g. "meta", "mistral", ...
if provider not in PROVIDERS:
raise KeyError(f"{provider} is not supported!")
if provider == "meta":
# distinguish llama2 and llama3
return PROVIDERS[provider](model_name[:6])
return PROVIDERS[provider]()
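A usage sketch (not in the diff) of the dispatch above:

```python
from metagpt.provider.bedrock.bedrock_provider import get_provider

provider = get_provider("meta.llama3-8b-instruct-v1:0")
# -> MetaProvider("llama3"): "llama3" is sliced from the model name
provider = get_provider("anthropic.claude-3-sonnet-20240229-v1:0")
# -> AnthropicProvider()
# get_provider("unknown.model")  # would raise KeyError("unknown is not supported!")
```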

View file

@@ -0,0 +1,112 @@
from metagpt.logs import logger
# max_tokens for each model
NOT_SUPPORT_STREAM_MODELS = {
"ai21.j2-grande-instruct": 8000,
"ai21.j2-jumbo-instruct": 8000,
"ai21.j2-mid": 8000,
"ai21.j2-mid-v1": 8000,
"ai21.j2-ultra": 8000,
"ai21.j2-ultra-v1": 8000,
}
SUPPORT_STREAM_MODELS = {
"amazon.titan-tg1-large": 8000,
"amazon.titan-text-express-v1": 8000,
"amazon.titan-text-express-v1:0:8k": 8000,
"amazon.titan-text-lite-v1:0:4k": 4000,
"amazon.titan-text-lite-v1": 4000,
"anthropic.claude-instant-v1": 100000,
"anthropic.claude-instant-v1:2:100k": 100000,
"anthropic.claude-v1": 100000,
"anthropic.claude-v2": 100000,
"anthropic.claude-v2:1": 200000,
"anthropic.claude-v2:0:18k": 18000,
"anthropic.claude-v2:1:200k": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
# currently (2024-04-29) only available in the US West (Oregon) AWS Region.
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"cohere.command-text-v14": 4000,
"cohere.command-text-v14:7:4k": 4000,
"cohere.command-light-text-v14": 4000,
"cohere.command-light-text-v14:7:4k": 4000,
"meta.llama2-13b-chat-v1:0:4k": 4000,
"meta.llama2-13b-chat-v1": 2000,
"meta.llama2-70b-v1": 4000,
"meta.llama2-70b-v1:0:4k": 4000,
"meta.llama2-70b-chat-v1": 4000,
"meta.llama2-70b-chat-v1:0:4k": 4000,
"meta.llama3-8b-instruct-v1:0": 2000,
"meta.llama3-70b-instruct-v1:0": 2000,
"mistral.mistral-7b-instruct-v0:2": 32000,
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
"mistral.mistral-large-2402-v1:0": 32000,
}
# TODO: use a more general function for constructing chat templates.
def messages_to_prompt_llama2(messages: list[dict]) -> str:
BOS = "<s>"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
prompt = f"{BOS}"
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
if role == "system":
prompt += f"{B_SYS} {content} {E_SYS}"
elif role == "user":
prompt += f"{B_INST} {content} {E_INST}"
elif role == "assistant":
prompt += f"{content}"
else:
logger.warning(f"Unknown role name {role} when formatting messages")
prompt += f"{content}"
return prompt
def messages_to_prompt_llama3(messages: list[dict]) -> str:
BOS = "<|begin_of_text|>"
GENERAL_TEMPLATE = "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
prompt = f"{BOS}"
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
if role != "assistant":
prompt += "<|start_header_id|>assistant<|end_header_id|>"
return prompt
def messages_to_prompt_claude2(messages: list[dict]) -> str:
GENERAL_TEMPLATE = "\n\n{role}: {content}"
prompt = ""
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
prompt += GENERAL_TEMPLATE.format(role=role, content=content)
if role != "assistant":
prompt += "\n\nAssistant:"
return prompt
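A sketch (not in the diff) of the llama3 template output for a single user turn:

```python
from metagpt.provider.bedrock.utils import messages_to_prompt_llama3

msgs = [{"role": "user", "content": "Hi!"}]
print(messages_to_prompt_llama3(msgs))
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# Hi!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```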
def get_max_tokens(model_id: str) -> int:
try:
max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]
except KeyError:
logger.warning(f"Couldn't find model {model_id}; max_tokens has been set to 2048")
max_tokens = 2048
return max_tokens
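For example (a sketch, with values taken from the tables above):

```python
from metagpt.provider.bedrock.utils import get_max_tokens

get_max_tokens("anthropic.claude-v2:1")   # -> 200000
get_max_tokens("example.unknown-model")   # -> 2048, after logging a warning
```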

View file

@@ -0,0 +1,140 @@
import json
from typing import Literal
import boto3
from botocore.eventstream import EventStream
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.bedrock.bedrock_provider import get_provider
from metagpt.provider.bedrock.utils import NOT_SUPPORT_STREAM_MODELS, get_max_tokens
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS
@register_provider([LLMType.BEDROCK])
class BedrockLLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self.__client = self.__init_client("bedrock-runtime")
self.__provider = get_provider(self.config.model)
self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
logger.warning("Amazon bedrock doesn't support asynchronous now")
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
logger.warning(f"model {self.config.model} doesn't support streaming output!")
def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
"""initialize boto3 client"""
# access key and secret key from https://us-east-1.console.aws.amazon.com/iam
self.__credential_kwargs = {
"aws_secret_access_key": self.config.secret_key,
"aws_access_key_id": self.config.access_key,
"region_name": self.config.region_name,
}
session = boto3.Session(**self.__credential_kwargs)
client = session.client(service_name)
return client
@property
def client(self):
return self.__client
@property
def provider(self):
return self.__provider
def list_models(self):
"""list all available text-generation models
```shell
ai21.j2-ultra-v1 Support Streaming:False
meta.llama3-70b-instruct-v1:0 Support Streaming:True
```
"""
client = self.__init_client("bedrock")
# only output text-generation models
response = client.list_foundation_models(byOutputModality="TEXT")
summaries = [
f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
for summary in response["modelSummaries"]
]
logger.info("\n" + "\n".join(summaries))
def invoke_model(self, request_body: str) -> dict:
response = self.__client.invoke_model(modelId=self.config.model, body=request_body)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
response_body = self._get_response_body(response)
return response_body
def invoke_model_with_response_stream(self, request_body: str) -> EventStream:
response = self.__client.invoke_model_with_response_stream(modelId=self.config.model, body=request_body)
usage = self._get_usage(response)
self._update_costs(usage, self.config.model)
return response
@property
def _const_kwargs(self) -> dict:
model_max_tokens = get_max_tokens(self.config.model)
if self.config.max_token > model_max_tokens:
max_tokens = model_max_tokens
else:
max_tokens = self.config.max_token
return {self.__provider.max_tokens_field_name: max_tokens, "temperature": self.config.temperature}
# boto3 doesn't support asynchronous calls.
# For an asynchronous version of boto3, check out:
# https://aioboto3.readthedocs.io/en/latest/usage.html
# However, aioboto3 doesn't support invoking Bedrock models.
def get_choice_text(self, rsp: dict) -> str:
return self.__provider.get_choice_text(rsp)
async def acompletion(self, messages: list[dict]) -> dict:
request_body = self.__provider.get_request_body(messages, self._const_kwargs)
response_body = self.invoke_model(request_body)
return response_body
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
return await self.acompletion(messages)
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
if self.config.model in NOT_SUPPORT_STREAM_MODELS:
rsp = await self.acompletion(messages)
full_text = self.get_choice_text(rsp)
log_llm_stream(full_text)
return full_text
request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []
for event in response["body"]:
chunk_text = self.__provider.get_choice_text_from_stream(event)
collected_content.append(chunk_text)
log_llm_stream(chunk_text)
log_llm_stream("\n")
full_text = ("".join(collected_content)).lstrip()
return full_text
def _get_response_body(self, response) -> dict:
response_body = json.loads(response["body"].read())
return response_body
def _get_usage(self, response) -> dict[str, int]:
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0))
completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0))
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
}
return usage
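An end-to-end usage sketch (not part of the diff; assumes valid AWS credentials and the `bedrock_config` built in the earlier LLMConfig sketch):

```python
import asyncio

from metagpt.provider.bedrock_api import BedrockLLM

llm = BedrockLLM(bedrock_config)
llm.list_models()                    # logs "modelId ... Support Streaming:<bool>" per model
print(asyncio.run(llm.aask("Hi!")))  # aask is inherited from BaseLLM
```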

View file

@@ -210,6 +210,53 @@ TOKEN_MAX = {
"deepseek-coder": 16385,
}
# For Amazon Bedrock US region
# See https://aws.amazon.com/cn/bedrock/pricing/
BEDROCK_TOKEN_COSTS = {
"amazon.titan-tg1-large": {"prompt": 0.0008, "completion": 0.0008},
"amazon.titan-text-express-v1": {"prompt": 0.0008, "completion": 0.0008},
"amazon.titan-text-express-v1:0:8k": {"prompt": 0.0008, "completion": 0.0008},
"amazon.titan-text-lite-v1:0:4k": {"prompt": 0.0003, "completion": 0.0004},
"amazon.titan-text-lite-v1": {"prompt": 0.0003, "completion": 0.0004},
"anthropic.claude-instant-v1": {"prompt": 0.0008, "completion": 0.00024},
"anthropic.claude-instant-v1:2:100k": {"prompt": 0.0008, "completion": 0.00024},
"anthropic.claude-v1": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2:1": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2:0:18k": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-v2:1:200k": {"prompt": 0.008, "completion": 0.0024},
"anthropic.claude-3-sonnet-20240229-v1:0": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-sonnet-20240229-v1:0:28k": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-sonnet-20240229-v1:0:200k": {"prompt": 0.003, "completion": 0.015},
"anthropic.claude-3-haiku-20240307-v1:0": {"prompt": 0.00025, "completion": 0.00125},
"anthropic.claude-3-haiku-20240307-v1:0:48k": {"prompt": 0.00025, "completion": 0.00125},
"anthropic.claude-3-haiku-20240307-v1:0:200k": {"prompt": 0.00025, "completion": 0.00125},
# currently (2024-04-29) only available in the US West (Oregon) AWS Region.
"anthropic.claude-3-opus-20240229-v1:0": {"prompt": 0.015, "completion": 0.075},
"cohere.command-text-v14": {"prompt": 0.0015, "completion": 0.0015},
"cohere.command-text-v14:7:4k": {"prompt": 0.0015, "completion": 0.0015},
"cohere.command-light-text-v14": {"prompt": 0.0003, "completion": 0.0003},
"cohere.command-light-text-v14:7:4k": {"prompt": 0.0003, "completion": 0.0003},
"meta.llama2-13b-chat-v1:0:4k": {"prompt": 0.00075, "completion": 0.001},
"meta.llama2-13b-chat-v1": {"prompt": 0.00075, "completion": 0.001},
"meta.llama2-70b-v1": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama2-70b-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama2-70b-chat-v1": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama2-70b-chat-v1:0:4k": {"prompt": 0.00195, "completion": 0.00256},
"meta.llama3-8b-instruct-v1:0": {"prompt": 0.0004, "completion": 0.0006},
"meta.llama3-70b-instruct-v1:0": {"prompt": 0.00265, "completion": 0.0035},
"mistral.mistral-7b-instruct-v0:2": {"prompt": 0.00015, "completion": 0.0002},
"mistral.mixtral-8x7b-instruct-v0:1": {"prompt": 0.00045, "completion": 0.0007},
"mistral.mistral-large-2402-v1:0": {"prompt": 0.008, "completion": 0.024},
"ai21.j2-grande-instruct": {"prompt": 0.0125, "completion": 0.0125},
"ai21.j2-jumbo-instruct": {"prompt": 0.0188, "completion": 0.0188},
"ai21.j2-mid": {"prompt": 0.0125, "completion": 0.0125},
"ai21.j2-mid-v1": {"prompt": 0.0125, "completion": 0.0125},
"ai21.j2-ultra": {"prompt": 0.0188, "completion": 0.0188},
"ai21.j2-ultra-v1": {"prompt": 0.0188, "completion": 0.0188},
}
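Assuming these prices are, like the other *_TOKEN_COSTS tables in this module, USD per 1K tokens, a rough cost sketch:

```python
from metagpt.utils.token_counter import BEDROCK_TOKEN_COSTS

costs = BEDROCK_TOKEN_COSTS["anthropic.claude-3-sonnet-20240229-v1:0"]
# 2,000 prompt tokens + 1,000 completion tokens:
total = 2000 / 1000 * costs["prompt"] + 1000 / 1000 * costs["completion"]
# total ≈ 0.021 USD
```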
def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
"""Return the number of tokens used by a list of messages."""

View file

@@ -70,3 +70,4 @@ qianfan==0.3.2
dashscope==1.14.1
rank-bm25==0.2.2 # for tool recommendation
gymnasium==0.29.1
boto3==1.34.92

View file

@@ -60,3 +60,12 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model
mock_llm_config_anthropic = LLMConfig(
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
)
mock_llm_config_bedrock = LLMConfig(
api_type="bedrock",
model="gpt-100",
region_name="somewhere",
access_key="123abc",
secret_key="123abc",
max_token=10000,
)

View file

@@ -183,3 +183,90 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
resp = await llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
# For Amazon Bedrock
# Check the API documentation of each model
# https://docs.aws.amazon.com/bedrock/latest/userguide
BEDROCK_PROVIDER_REQUEST_BODY = {
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
"ai21": {
"prompt": "",
"temperature": 0.0,
"topP": 0.0,
"maxTokens": 0,
"stopSequences": [],
"countPenalty": {"scale": 0.0},
"presencePenalty": {"scale": 0.0},
"frequencyPenalty": {"scale": 0.0},
},
"cohere": {
"prompt": "",
"temperature": 0.0,
"p": 0.0,
"k": 0.0,
"max_tokens": 0,
"stop_sequences": [],
"return_likelihoods": "NONE",
"stream": False,
"num_generations": 0,
"logit_bias": {},
"truncate": "NONE",
},
"anthropic": {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 0,
"system": "",
"messages": [{"role": "", "content": ""}],
"temperature": 0.0,
"top_p": 0.0,
"top_k": 0,
"stop_sequences": [],
},
"amazon": {
"inputText": "",
"textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
},
}
BEDROCK_PROVIDER_RESPONSE_BODY = {
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
"ai21": {
"id": "",
"prompt": {"text": "Hello World", "tokens": []},
"completions": [
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
],
},
"cohere": {
"generations": [
{
"finish_reason": "",
"id": "",
"text": "Hello World",
"likelihood": 0.0,
"token_likelihoods": [{"token": 0.0}],
"is_finished": True,
"index": 0,
}
],
"id": "",
"prompt": "",
},
"anthropic": {
"id": "",
"model": "",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello World"}],
"stop_reason": "",
"stop_sequence": "",
"usage": {"input_tokens": 0, "output_tokens": 0},
},
"amazon": {
"inputTextTokenCount": 0,
"results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
},
}

View file

@@ -0,0 +1,109 @@
import json
import pytest
from metagpt.provider.bedrock.utils import (
NOT_SUPPORT_STREAM_MODELS,
SUPPORT_STREAM_MODELS,
)
from metagpt.provider.bedrock_api import BedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from tests.metagpt.provider.req_resp_const import (
BEDROCK_PROVIDER_REQUEST_BODY,
BEDROCK_PROVIDER_RESPONSE_BODY,
)
# all available models from Bedrock
models = SUPPORT_STREAM_MODELS | NOT_SUPPORT_STREAM_MODELS
messages = [{"role": "user", "content": "Hi!"}]
usage = {
"prompt_tokens": 1000000,
"completion_tokens": 1000000,
}
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
provider = self.config.model.split(".")[0]
self._update_costs(usage, self.config.model)
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
# use json object to mock EventStream
def dict2bytes(x):
return json.dumps(x).encode("utf-8")
provider = self.config.model.split(".")[0]
if provider == "amazon":
response_body_bytes = dict2bytes({"outputText": "Hello World"})
elif provider == "anthropic":
response_body_bytes = dict2bytes(
{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
)
elif provider == "cohere":
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
else:
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
self._update_costs(usage, self.config.model)
return response_body_stream
def get_bedrock_request_body(model_id) -> dict:
provider = model_id.split(".")[0]
return BEDROCK_PROVIDER_REQUEST_BODY[provider]
def is_subset(subset, superset) -> bool:
"""Ensure all fields in request body are allowed.
```python
subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}}
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
is_subset(subset, superset)
```
>>>False
"""
for key, value in subset.items():
if key not in superset:
return False
if isinstance(value, dict):
if not isinstance(superset[key], dict):
return False
if not is_subset(value, superset[key]):
return False
return True
@pytest.fixture(scope="class", params=models)
def bedrock_api(request) -> BedrockLLM:
model_id = request.param
mock_llm_config_bedrock.model = model_id
api = BedrockLLM(mock_llm_config_bedrock)
return api
class TestBedrockAPI:
def _patch_invoke_model(self, mocker):
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model)
def _patch_invoke_model_stream(self, mocker):
mocker.patch(
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
mock_invoke_model_stream,
)
def test_get_request_body(self, bedrock_api: BedrockLLM):
"""Ensure request body has correct format"""
provider = bedrock_api.provider
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
@pytest.mark.asyncio
async def test_aask(self, bedrock_api: BedrockLLM, mocker):
self._patch_invoke_model(mocker)
self._patch_invoke_model_stream(mocker)
assert await bedrock_api.aask(messages, stream=False) == "Hello World"
assert await bedrock_api.aask(messages, stream=True) == "Hello World"