Remove the Opus model since it is unavailable now, and fix a bug

This commit is contained in:
usamimeri_renko 2024-04-27 13:09:04 +08:00
parent a05d25757c
commit 98cb452911
5 changed files with 42 additions and 16 deletions

View file

@@ -86,7 +86,7 @@ class AmazonBedrockLLM(BaseLLM):
def completion(self, messages: list[dict]) -> str:
request_body = self.__provider.get_request_body(
messages, **self._generate_kwargs)
messages, self._generate_kwargs)
response_body = self.invoke_model(request_body)
completions = self.__provider.get_choice_text(response_body)
return completions
@@ -98,7 +98,7 @@ class AmazonBedrockLLM(BaseLLM):
return self.completion(messages)
request_body = self.__provider.get_request_body(
messages, **self._generate_kwargs)
messages, self._generate_kwargs, stream=True)
response = self.invoke_model_with_response_stream(request_body)
collected_content = []

View file

@@ -10,7 +10,7 @@ class BaseBedrockProvider(ABC):
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
...
def get_request_body(self, messages: list[dict], **generate_kwargs):
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
return body

View file

@@ -17,7 +17,7 @@ class MistralProvider(BaseBedrockProvider):
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
def get_request_body(self, messages: list[dict], **generate_kwargs):
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps(
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
return body
@@ -41,6 +41,16 @@ class CohereProvider(BaseBedrockProvider):
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict["generations"][0]["text"]
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), "stream": kwargs.get("stream", False), **generate_kwargs})
return body
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict.get("text", "")
return completions
class MetaProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
@@ -74,7 +84,7 @@ class AmazonProvider(BaseBedrockProvider):
max_tokens_field_name = "maxTokenCount"
def get_request_body(self, messages: list[dict], **generate_kwargs):
def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs):
body = json.dumps({
"inputText": self.messages_to_prompt(messages),
"textGenerationConfig": generate_kwargs

View file

@@ -14,15 +14,24 @@ SUPPORT_STREAM_MODELS = {
"amazon.titan-tg1-large": 8000,
"amazon.titan-text-express-v1": 8000,
"anthropic.claude-instant-v1": 100000,
"anthropic.claude-instant-v1:2:100k": 100000,
"anthropic.claude-v1": 100000,
"anthropic.claude-v2": 100000,
"anthropic.claude-v2:1": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0:28k": 28000,
"anthropic.claude-3-sonnet-20240229-v1:0:200k": 200000,
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"cohere.command-text-v14": 4096,
"cohere.command-light-text-v14": 4096,
"meta.llama2-70b-v1": 4096,
"anthropic.claude-3-haiku-20240307-v1:0:48k": 48000,
"anthropic.claude-3-haiku-20240307-v1:0:200k": 200000,
"cohere.command-text-v14": 4000,
"cohere.command-text-v14:7:4k": 4000,
"cohere.command-light-text-v14": 4000,
"cohere.command-light-text-v14:7:4k": 4000,
"meta.llama2-70b-v1": 4000,
"meta.llama2-13b-chat-v1:0:4k": 4000,
"meta.llama2-13b-chat-v1": 2000,
"meta.llama2-70b-v1:0:4k": 4000,
"meta.llama3-8b-instruct-v1:0": 2000,
"meta.llama3-70b-instruct-v1:0": 2000,
"mistral.mistral-7b-instruct-v0:2": 32000,

View file

@@ -20,14 +20,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
def dict2bytes(x):
return json.dumps(x).encode("utf-8")
provider = self.config.model.split(".")[0]
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
# decoded bytes share the same format as non-stream response_body except for titan
if provider == "amazon":
response_body_stream = {
"body": [{'chunk': {'bytes': dict2bytes({"outputText": "Hello World"})}}]}
response_body_bytes = dict2bytes({"outputText": "Hello World"})
elif provider == "anthropic":
response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0,
"delta": {"type": "text_delta", "text": "Hello World"}})
elif provider == "cohere":
response_body_bytes = dict2bytes(
{"is_finished": False, "text": "Hello World"})
else:
response_body_stream = {
"body": [{'chunk': {'bytes': response_body_bytes}}]}
response_body_bytes = dict2bytes(
BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_stream = {
"body": [{'chunk': {'bytes': response_body_bytes}}]}
return response_body_stream
@@ -75,7 +82,7 @@ class TestAPI:
"""Ensure request body has correct format"""
provider = bedrock_api._get_provider()
request_body = json.loads(provider.get_request_body(
messages, **bedrock_api._generate_kwargs))
messages, bedrock_api._generate_kwargs))
assert is_subset(request_body, get_bedrock_request_body(
bedrock_api.config.model))