This commit is contained in:
usamimeri_renko 2024-04-26 16:08:39 +08:00
parent cafe666bfd
commit 4c394a1cac
7 changed files with 198 additions and 29 deletions

View file

@ -60,3 +60,12 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model
# Mock LLM config for the Anthropic provider; the API key is a dummy value.
mock_llm_config_anthropic = LLMConfig(
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
)
# Mock LLM config for the Amazon Bedrock provider.
# NOTE(review): model, region and both keys look like dummy placeholders, not
# real AWS values — confirm nothing relies on them being valid.
mock_llm_config_bedrock = LLMConfig(
api_type="amazon_bedrock",
model="gpt-100",
region_name="somewhere",
access_key="123abc",
secret_key="123abc",
max_token=10000,
)

View file

@ -65,17 +65,20 @@ def get_openai_chat_completion(name: str) -> ChatCompletion:
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
message=ChatCompletionMessage(
role="assistant", content=resp_cont_tmpl.format(name=name)),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
usage=CompletionUsage(completion_tokens=110,
prompt_tokens=92, total_tokens=202),
)
return openai_chat_completion
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
usage = CompletionUsage(completion_tokens=110,
prompt_tokens=92, total_tokens=202)
usage = usage if not usage_as_dict else usage.model_dump()
openai_chat_completion_chunk = ChatCompletionChunk(
id="cmpl-a6652c1bb181caae8dd19ad8",
@ -84,7 +87,8 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) ->
created=1703300855,
choices=[
AChoice(
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
delta=ChoiceDelta(role="assistant",
content=resp_cont_tmpl.format(name=name)),
finish_reason="stop",
index=0,
logprobs=None,
@ -133,7 +137,8 @@ def get_dashscope_response(name: str) -> GenerationResponse:
],
}
),
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
usage=GenerationUsage(
**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
)
)
@ -155,7 +160,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
),
ContentBlockDeltaEvent(
index=0,
delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
delta=TextDelta(text=resp_cont_tmpl.format(
name=name), type="text_delta"),
type="content_block_delta",
),
]
@ -165,7 +171,8 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
model=name,
role="assistant",
type="message",
content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
content=[ContentBlock(
text=resp_cont_tmpl.format(name=name), type="text")],
usage=AnthropicUsage(input_tokens=10, output_tokens=10),
)
@ -183,3 +190,52 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
resp = await llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
# For Amazon Bedrock
# Request-body templates keyed by provider name. Every value is a zero/empty
# placeholder; the (nested) keys are what matters.
BEDROCK_PROVIDER_REQUEST_BODY = {
    "mistral": {
        "prompt": "",
        "max_tokens": 0,
        "stop": [],
        "temperature": 0.0,
        "top_p": 0.0,
        "top_k": 0,
    },
    "meta": {
        "prompt": "",
        "temperature": 0.0,
        "top_p": 0.0,
        "max_gen_len": 0,
    },
    "ai21": {
        "prompt": "",
        "temperature": 0.0,
        "topP": 0.0,
        "maxTokens": 0,
        "stopSequences": [],
        "countPenalty": {"scale": 0.0},
        "presencePenalty": {"scale": 0.0},
        "frequencyPenalty": {"scale": 0.0},
    },
    "cohere": {
        "prompt": "",
        "temperature": 0.0,
        "p": 0.0,
        "k": 0.0,
        "max_tokens": 0,
        "stop_sequences": [],
        "return_likelihoods": "NONE",
        "stream": False,
        "num_generations": 0,
        "logit_bias": {},
        "truncate": "NONE",
    },
    "anthropic": {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 0,
        "system": "",
        "messages": [{"role": "", "content": ""}],
        "temperature": 0.0,
        "top_p": 0.0,
        "top_k": 0,
        "stop_sequences": [],
    },
    "amazon": {
        "inputText": "",
        "textGenerationConfig": {
            "temperature": 0.0,
            "topP": 0.0,
            "maxTokenCount": 0,
            "stopSequences": [],
        },
    },
}
# Canned (non-streaming) response payloads keyed by provider name. Each one
# carries the text "Hello World" in that provider's own response layout.
BEDROCK_PROVIDER_RESPONSE_BODY = {
    "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
    "meta": {
        "generation": "Hello World",
        "prompt_token_count": 0,
        "generation_token_count": 0,
        "stop_reason": "",
    },
    "ai21": {
        "id": "",
        "prompt": {"text": "Hello World", "tokens": []},
        "completions": [
            {
                "data": {"text": "Hello World", "tokens": []},
                "finishReason": {"reason": "length", "length": 2},
            }
        ],
    },
    "cohere": {
        "generations": [
            {
                "finish_reason": "",
                "id": "",
                "text": "Hello World",
                "likelihood": 0.0,
                "token_likelihoods": [{"token": 0.0}],
                "is_finished": True,
                "index": 0,
            }
        ],
        "id": "",
        "prompt": "",
    },
    "anthropic": {
        "id": "",
        "model": "",
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": "Hello World"}],
        "stop_reason": "",
        "stop_sequence": "",
        "usage": {"input_tokens": 0, "output_tokens": 0},
    },
    "amazon": {
        "inputTextTokenCount": 0,
        "results": [
            {"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}
        ],
    },
}
# Streaming response payloads per provider. Still empty: no streaming
# fixtures have been defined yet.
BEDROCK_PROVIDER_STREAM_RESPONSE = {}

View file

@ -0,0 +1,91 @@
import pytest
import json
from metagpt.provider.bedrock.amazon_bedrock_api import AmazonBedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
from tests.metagpt.provider.req_resp_const import (
BEDROCK_PROVIDER_REQUEST_BODY,
BEDROCK_PROVIDER_RESPONSE_BODY,
BEDROCK_PROVIDER_STREAM_RESPONSE)
from botocore.response import StreamingBody
# Every Bedrock model id listed in the provider utils, whether or not it
# supports streaming; used to parametrize the `bedrock_api` fixture.
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
# Minimal single-message chat history passed to the provider in the tests.
messages = [{"role": "user", "content": "Hi!"}]
def mock_bedrock_provider_response(self, *args, **kwargs) -> dict:
    """Return the canned response body for the provider of ``self.config.model``.

    The provider is the part of the model id before the first ``"."``.
    Any extra positional or keyword arguments are accepted and ignored.

    Raises:
        KeyError: if no canned response exists for that provider.
    """
    provider_name, _, _ = self.config.model.partition(".")
    return BEDROCK_PROVIDER_RESPONSE_BODY[provider_name]
def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> StreamingBody:
    """Placeholder mock for the streaming Bedrock call.

    Looks up the streaming payload for the provider prefix of
    ``self.config.model`` (the text before the first ``"."``), but does not
    use it yet: the function currently returns ``None``.

    Raises:
        KeyError: if no streaming payload exists for that provider.
    """
    provider_name = self.config.model.partition(".")[0]
    # NOTE(review): the payload is fetched but never wrapped or returned, so
    # the result does not match the ``StreamingBody`` annotation — unfinished.
    _pending_payload = BEDROCK_PROVIDER_STREAM_RESPONSE[provider_name]
    return None
def get_bedrock_request_body(model_id) -> dict:
    """Return the request-body template for the provider of ``model_id``.

    The provider is the part of ``model_id`` before the first ``"."``,
    e.g. ``"amazon"`` for ``"amazon.titan-tg1-large"``.

    Raises:
        KeyError: if no template exists for that provider.
    """
    provider_name, _, _ = model_id.partition(".")
    return BEDROCK_PROVIDER_REQUEST_BODY[provider_name]
def is_subset(subset, superset):
    """Check that every (nested) key of ``subset`` also exists in ``superset``.

    Only keys are compared, never values. When a value in ``subset`` is a
    dict, the matching entry in ``superset`` must also be a dict and is
    checked recursively.

    >>> is_subset({"prompt": "hello", "kwargs": {"temperature": 0.9, "p": 0.0}},
    ...           {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}})
    False
    """
    for field, candidate in subset.items():
        if field not in superset:
            return False
        if not isinstance(candidate, dict):
            # Leaf value: the key being present is all that is required.
            continue
        allowed = superset[field]
        if not (isinstance(allowed, dict) and is_subset(candidate, allowed)):
            return False
    return True
@pytest.fixture(scope="class", params=models)
def bedrock_api(request) -> AmazonBedrockLLM:
    """Build an ``AmazonBedrockLLM`` for each model id in ``models``.

    The fixture is class-scoped and parametrized, so ``request.param`` is one
    model id per run.

    Returns:
        An ``AmazonBedrockLLM`` built from a copy of the shared mock config
        with ``model`` set to the current model id.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import copy

    # Work on a copy: assigning ``model`` on the shared module-level config
    # would leak the last model id into every other test that imports it.
    config = copy.deepcopy(mock_llm_config_bedrock)
    config.model = request.param
    return AmazonBedrockLLM(config)
class TestAPI:
    """Tests for ``AmazonBedrockLLM``, run once per parametrized ``bedrock_api``."""

    def test_generate_kwargs(self, bedrock_api: AmazonBedrockLLM):
        """The configured max-token value must not exceed the model's limit."""
        provider = bedrock_api._get_provider()
        configured = bedrock_api._generate_kwargs[provider.max_tokens_field_name]
        limit = get_max_tokens(bedrock_api.config.model)
        assert configured <= limit

    def test_get_request_body(self, bedrock_api: AmazonBedrockLLM):
        """The generated request body may only use keys found in the template."""
        provider = bedrock_api._get_provider()
        raw_body = provider.get_request_body(
            messages, **bedrock_api._generate_kwargs)
        actual_body = json.loads(raw_body)
        allowed_body = get_bedrock_request_body(bedrock_api.config.model)
        # Print both so a failing run shows which key is unexpected.
        print(allowed_body)
        print(actual_body)
        assert is_subset(actual_body, allowed_body)

    def test_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
        """With ``invoke_model`` patched, ``completion`` returns the canned text."""
        invoke_target = "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model"
        mocker.patch(invoke_target, mock_bedrock_provider_response)
        assert bedrock_api.completion(messages) == "Hello World"
# def test_stream_completion(self, bedrock_api: AmazonBedrockLLM, mocker):
# mocker.patch(
# "metagpt.provider.bedrock.amazon_bedrock_api.AmazonBedrockLLM.invoke_model_with_response_stream", mock_bedrock_provider_response)
# assert bedrock_api._chat_completion_stream(messages) == "Hello World"
if __name__ == '__main__':
    # Manual smoke check: print the request-body template for one model id.
    sample_model_id = "amazon.titan-tg1-large"
    print(get_bedrock_request_body(sample_model_id))