add pre-commit

This commit is contained in:
usamimeri_renko 2024-04-29 15:04:33 +08:00
parent 3f108abd06
commit f14a1f63ef
7 changed files with 132 additions and 97 deletions

View file

@ -65,20 +65,17 @@ def get_openai_chat_completion(name: str) -> ChatCompletion:
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
role="assistant", content=resp_cont_tmpl.format(name=name)),
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110,
prompt_tokens=92, total_tokens=202),
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
return openai_chat_completion
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
usage = CompletionUsage(completion_tokens=110,
prompt_tokens=92, total_tokens=202)
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
usage = usage if not usage_as_dict else usage.model_dump()
openai_chat_completion_chunk = ChatCompletionChunk(
id="cmpl-a6652c1bb181caae8dd19ad8",
@ -87,8 +84,7 @@ def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) ->
created=1703300855,
choices=[
AChoice(
delta=ChoiceDelta(role="assistant",
content=resp_cont_tmpl.format(name=name)),
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
finish_reason="stop",
index=0,
logprobs=None,
@ -137,8 +133,7 @@ def get_dashscope_response(name: str) -> GenerationResponse:
],
}
),
usage=GenerationUsage(
**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
)
)
@ -170,8 +165,7 @@ def get_anthropic_response(name: str, stream: bool = False) -> Message:
model=name,
role="assistant",
type="message",
content=[ContentBlock(
text=resp_cont_tmpl.format(name=name), type="text")],
content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
usage=AnthropicUsage(input_tokens=10, output_tokens=10),
)
@ -198,43 +192,81 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
"ai21": {
"prompt": "", "temperature": 0.0, "topP": 0.0, "maxTokens": 0,
"stopSequences": [], "countPenalty": {"scale": 0.0},
"presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0}
"prompt": "",
"temperature": 0.0,
"topP": 0.0,
"maxTokens": 0,
"stopSequences": [],
"countPenalty": {"scale": 0.0},
"presencePenalty": {"scale": 0.0},
"frequencyPenalty": {"scale": 0.0},
},
"cohere": {
"prompt": "", "temperature": 0.0, "p": 0.0, "k": 0.0, "max_tokens": 0, "stop_sequences": [],
"return_likelihoods": "NONE", "stream": False, "num_generations": 0, "logit_bias": {}, "truncate": "NONE"
"prompt": "",
"temperature": 0.0,
"p": 0.0,
"k": 0.0,
"max_tokens": 0,
"stop_sequences": [],
"return_likelihoods": "NONE",
"stream": False,
"num_generations": 0,
"logit_bias": {},
"truncate": "NONE",
},
"anthropic": {
"anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, "system": "",
"messages": [{"role": "", "content": ""}], "temperature": 0.0, "top_p": 0.0, "top_k": 0, "stop_sequences": []
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 0,
"system": "",
"messages": [{"role": "", "content": ""}],
"temperature": 0.0,
"top_p": 0.0,
"top_k": 0,
"stop_sequences": [],
},
"amazon": {
"inputText": "", "textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []}
}
"inputText": "",
"textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
},
}
BEDROCK_PROVIDER_RESPONSE_BODY = {
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
"ai21": {
"id": "", "prompt": {"text": "Hello World", "tokens": []},
"completions": [{"data": {"text": "Hello World", "tokens": []},
"finishReason": {"reason": "length", "length": 2}}]
"id": "",
"prompt": {"text": "Hello World", "tokens": []},
"completions": [
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
],
},
"cohere": {
"generations": [
{"finish_reason": "", "id": "", "text": "Hello World", "likelihood": 0.0,
"token_likelihoods": [{"token": 0.0}], "is_finished": True, "index": 0}], "id": "", "prompt": ""
{
"finish_reason": "",
"id": "",
"text": "Hello World",
"likelihood": 0.0,
"token_likelihoods": [{"token": 0.0}],
"is_finished": True,
"index": 0,
}
],
"id": "",
"prompt": "",
},
"anthropic": {
"id": "", "model": "", "type": "message", "role": "assistant", "content": [{"type": "text", "text": "Hello World"}],
"stop_reason": "", "stop_sequence": "", "usage": {"input_tokens": 0, "output_tokens": 0}
"id": "",
"model": "",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello World"}],
"stop_reason": "",
"stop_sequence": "",
"usage": {"input_tokens": 0, "output_tokens": 0},
},
"amazon": {
"inputTextTokenCount": 0,
"results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
},
"amazon": {"inputTextTokenCount": 0, "results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}]}
}

View file

@ -1,12 +1,21 @@
import pytest
import json
import pytest
from metagpt.provider.bedrock.utils import (
NOT_SUUPORT_STREAM_MODELS,
SUPPORT_STREAM_MODELS,
get_max_tokens,
)
from metagpt.provider.bedrock_api import BedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from metagpt.provider.bedrock.utils import get_max_tokens, SUPPORT_STREAM_MODELS, NOT_SUUPORT_STREAM_MODELS
from tests.metagpt.provider.req_resp_const import BEDROCK_PROVIDER_REQUEST_BODY, BEDROCK_PROVIDER_RESPONSE_BODY
from tests.metagpt.provider.req_resp_const import (
BEDROCK_PROVIDER_REQUEST_BODY,
BEDROCK_PROVIDER_RESPONSE_BODY,
)
# all available model from bedrock
models = (SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS)
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
messages = [{"role": "user", "content": "Hi!"}]
@ -19,22 +28,21 @@ def mock_bedrock_provider_stream_response(self, *args, **kwargs) -> dict:
# use json object to mock EventStream
def dict2bytes(x):
return json.dumps(x).encode("utf-8")
provider = self.config.model.split(".")[0]
if provider == "amazon":
response_body_bytes = dict2bytes({"outputText": "Hello World"})
elif provider == "anthropic":
response_body_bytes = dict2bytes({"type": "content_block_delta", "index": 0,
"delta": {"type": "text_delta", "text": "Hello World"}})
response_body_bytes = dict2bytes(
{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
)
elif provider == "cohere":
response_body_bytes = dict2bytes(
{"is_finished": False, "text": "Hello World"})
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
else:
response_body_bytes = dict2bytes(
BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_stream = {
"body": [{"chunk": {"bytes": response_body_bytes}}]}
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
return response_body_stream
@ -77,19 +85,19 @@ class TestBedrockAPI:
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_bedrock_provider_response)
def _patch_invoke_model_stream(self, mocker):
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response)
mocker.patch(
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
mock_bedrock_provider_stream_response,
)
def test_const_kwargs(self, bedrock_api: BedrockLLM):
provider = bedrock_api.provider
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(
bedrock_api.config.model)
assert bedrock_api._const_kwargs[provider.max_tokens_field_name] <= get_max_tokens(bedrock_api.config.model)
def test_get_request_body(self, bedrock_api: BedrockLLM):
"""Ensure request body has correct format"""
provider = bedrock_api.provider
request_body = json.loads(provider.get_request_body(
messages, bedrock_api._const_kwargs))
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))