This commit is contained in:
usamimeri_renko 2024-04-26 17:15:24 +08:00
parent 8fafa2eb4e
commit 83c8ccb6b9
2 changed files with 2 additions and 11 deletions

View file

@ -17,11 +17,6 @@ except ImportError:
@register_provider([LLMType.AMAZON_BEDROCK])
class AmazonBedrockLLM(BaseLLM):
"""
check out:
https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html
"""
def __init__(self, config: LLMConfig):
self.config = config
self.__client = self.__init_client("bedrock-runtime")

View file

@ -31,7 +31,7 @@ def get_bedrock_request_body(model_id) -> dict:
return BEDROCK_PROVIDER_REQUEST_BODY[provider]
def is_subset(subset, superset):
def is_subset(subset, superset) -> bool:
"""Ensure all fields in request body are allowed.
```python
@ -71,9 +71,6 @@ class TestAPI:
provider = bedrock_api._get_provider()
request_body = json.loads(provider.get_request_body(
messages, **bedrock_api._generate_kwargs))
print(get_bedrock_request_body(
bedrock_api.config.model))
print(request_body)
assert is_subset(request_body, get_bedrock_request_body(
bedrock_api.config.model))
@ -88,5 +85,4 @@ class TestAPI:
mock_bedrock_provider_response)
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response)
assert bedrock_api._chat_completion_stream(
messages) == "Hello World"
assert bedrock_api._chat_completion_stream(messages) == "Hello World"