resolve problems

This commit is contained in:
usamimeri_renko 2024-04-28 19:18:12 +08:00
parent 98cb452911
commit 79251cd3cd
9 changed files with 25 additions and 216 deletions

View file

@@ -62,7 +62,7 @@ mock_llm_config_anthropic = LLMConfig(
)
mock_llm_config_bedrock = LLMConfig(
api_type="amazon_bedrock",
api_type="bedrock",
model="gpt-100",
region_name="somewhere",
access_key="123abc",

View file

@@ -160,8 +160,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
),
ContentBlockDeltaEvent(
index=0,
delta=TextDelta(text=resp_cont_tmpl.format(
name=name), type="text_delta"),
delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
type="content_block_delta",
),
]
@@ -237,5 +236,5 @@ BEDROCK_PROVIDER_RESPONSE_BODY = {
"stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0}
},
"amazon": {'inputTextTokenCount': 0, 'results': [{'tokenCount': 0, 'outputText': 'Hello World', 'completionReason': ""}]}
}
"amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]}
}

View file

@@ -1,6 +1,6 @@
import pytest
import json
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
from metagpt.provider.bedrock_api import AmazonBedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
@@ -34,7 +34,7 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_stream = {
"body": [{'chunk': {'bytes': response_body_bytes}}]}
"body": [{"chunk": {"bytes": response_body_bytes}}]}
return response_body_stream
@@ -74,13 +74,13 @@ def bedrock_api(request) -> AmazonBedrockLLM:
class TestAPI:
def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM):
provider = bedrock_api._get_provider()
provider = bedrock_api.provider
assert bedrock_api._generate_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
bedrock_api.config.model)
def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
"""Ensure request body has correct format"""
provider = bedrock_api._get_provider()
provider = bedrock_api.provider
request_body = json.loads(provider.get_request_body(
messages, bedrock_api._generate_kwargs))
@@ -88,13 +88,13 @@ class TestAPI:
bedrock_api.config.model))
def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
assert bedrock_api.completion(messages) == "Hello World"
def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model",
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model",
mock_bedrock_provider_response)
mocker.patch("metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
mocker.patch("metagpt.provider.bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response)
assert bedrock_api._chat_completion_stream(messages) == "Hello World"