Merge pull request #1848 from better629/main

update test_bedrock_api
This commit is contained in:
better629 2025-06-13 22:15:25 +08:00 committed by GitHub
commit a855e66d04
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 83 additions and 10 deletions

View file

@@ -191,7 +191,7 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
BEDROCK_PROVIDER_REQUEST_BODY = {
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
"ai21": {
"ai21-j2": {
"prompt": "",
"temperature": 0.0,
"topP": 0.0,
@@ -201,6 +201,16 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
"presencePenalty": {"scale": 0.0},
"frequencyPenalty": {"scale": 0.0},
},
"ai21-jamba": {
"messages": [],
"temperature": 0.0,
"topP": 0.0,
"max_tokens": 0,
"stopSequences": [],
"countPenalty": {"scale": 0.0},
"presencePenalty": {"scale": 0.0},
"frequencyPenalty": {"scale": 0.0},
},
"cohere": {
"prompt": "",
"temperature": 0.0,
@@ -214,6 +224,20 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
"logit_bias": {},
"truncate": "NONE",
},
"cohere-command-r": {
"message": [],
"chat_history": [],
"temperature": 0.0,
"p": 0.0,
"k": 0.0,
"max_tokens": 0,
"stop_sequences": [],
"return_likelihoods": "NONE",
"stream": False,
"num_generations": 0,
"logit_bias": {},
"truncate": "NONE",
},
"anthropic": {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 0,
@@ -233,12 +257,20 @@ BEDROCK_PROVIDER_REQUEST_BODY = {
BEDROCK_PROVIDER_RESPONSE_BODY = {
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
"ai21": {
"ai21-jamba": {
"id": "",
"prompt": {"text": "Hello World", "tokens": []},
"completions": [
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
],
"choices": [{"message": {"content": "Hello World"}}],
},
"ai21-jamba-stream": {
"id": "",
"prompt": {"text": "Hello World", "tokens": []},
"choices": [{"delta": {"content": "Hello World"}}],
},
"ai21-j2": {
"id": "",
"prompt": {"text": "Hello World", "tokens": []},
"completions": [{"data": {"text": "Hello World"}, "finishReason": {"reason": "length", "length": 2}}],
},
"cohere": {
"generations": [
@@ -255,6 +287,21 @@ BEDROCK_PROVIDER_RESPONSE_BODY = {
"id": "",
"prompt": "",
},
"cohere-command-r": {
"generations": [
{
"finish_reason": "",
"id": "",
"text": "Hello World",
"likelihood": 0.0,
"token_likelihoods": [{"token": 0.0}],
"is_finished": True,
"index": 0,
}
],
"id": "",
"prompt": "",
},
"anthropic": {
"id": "",
"model": "",

View file

@@ -22,18 +22,42 @@ usage = {
}
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
provider = self.config.model.split(".")[0]
def get_provider_name(model: str) -> str:
    """Extract the provider segment from a Bedrock model id.

    Bedrock model ids are dot-separated: ``provider.model-name`` (e.g.
    ``meta.llama3-...``, ``mistral.mistral-...``) or, for cross-region
    inference profiles, ``region.provider.model-name`` (e.g.
    ``us.anthropic.claude-...``).

    Args:
        model: the Bedrock model id string.

    Returns:
        The provider name, e.g. ``"meta"`` or ``"anthropic"``.

    Raises:
        ValueError: if *model* does not split into 2 or 3 dot-separated
            parts. (The original fell through with ``provider`` unbound
            and raised an opaque ``UnboundLocalError``.)
    """
    parts = model.split(".")
    if len(parts) == 2:
        # provider.model-name, e.g. meta, mistral, ...
        return parts[0]
    if len(parts) == 3:
        # some model_ids may contain a country prefix like us.xx.xxx
        return parts[1]
    raise ValueError(f"unexpected Bedrock model id format: {model!r}")
def deal_special_provider(provider: str, model: str, stream: bool = False) -> str:
    """Refine *provider* into the mock-fixture key for model sub-families.

    Some providers ship several model families with different request/response
    schemas (ai21 j2 vs. jamba, cohere command vs. command-r), and ai21's
    streaming responses differ from its non-streaming ones; the fixture dicts
    therefore key on a suffixed provider name.
    """
    # First matching marker wins, mirroring the original if/elif chain.
    suffix_markers = (
        ("j2-", "-j2"),
        ("jamba-", "-jamba"),
        ("command-r", "-command-r"),
    )
    for marker, suffix in suffix_markers:
        if marker in model:
            provider = f"{provider}{suffix}"
            break
    # ai21 streaming payloads have their own fixture entry.
    if stream and "ai21" in model:
        provider = f"{provider}-stream"
    return provider
async def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
    """Mock of BedrockLLM's invoke_model: record usage costs and return the
    canned (non-streaming) response body for the configured model's provider."""
    model_id = self.config.model
    self._update_costs(usage, model_id)
    fixture_key = deal_special_provider(get_provider_name(model_id), model_id)
    return BEDROCK_PROVIDER_RESPONSE_BODY[fixture_key]
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
# use json object to mock EventStream
def dict2bytes(x):
return json.dumps(x).encode("utf-8")
provider = self.config.model.split(".")[0]
provider = get_provider_name(self.config.model)
if provider == "amazon":
response_body_bytes = dict2bytes({"outputText": "Hello World"})
@@ -44,6 +68,7 @@ def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
elif provider == "cohere":
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
else:
provider = deal_special_provider(provider, self.config.model, stream=True)
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
@@ -52,7 +77,8 @@ def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
def get_bedrock_request_body(model_id) -> dict:
    """Return the mock request-body template for *model_id*.

    Resolves the provider from the (possibly region-prefixed) model id, then
    refines it to the sub-family fixture key before looking up the template.
    The dead first assignment (`provider = model_id.split(".")[0]`), which was
    immediately overwritten, has been removed.
    """
    provider = get_provider_name(model_id)
    provider = deal_special_provider(provider, model_id)
    return BEDROCK_PROVIDER_REQUEST_BODY[provider]