From 38d6d81a2036865ec87235c5ab251024466ff4d8 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 13 Jun 2025 22:12:55 +0800 Subject: [PATCH] update test_bedrock_api --- tests/metagpt/provider/req_resp_const.py | 57 ++++++++++++++++++++-- tests/metagpt/provider/test_bedrock_api.py | 36 ++++++++++++-- 2 files changed, 83 insertions(+), 10 deletions(-) diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 111b57f91..38fdd6063 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -191,7 +191,7 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[ BEDROCK_PROVIDER_REQUEST_BODY = { "mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0}, "meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0}, - "ai21": { + "ai21-j2": { "prompt": "", "temperature": 0.0, "topP": 0.0, @@ -201,6 +201,16 @@ BEDROCK_PROVIDER_REQUEST_BODY = { "presencePenalty": {"scale": 0.0}, "frequencyPenalty": {"scale": 0.0}, }, + "ai21-jamba": { + "messages": [], + "temperature": 0.0, + "topP": 0.0, + "max_tokens": 0, + "stopSequences": [], + "countPenalty": {"scale": 0.0}, + "presencePenalty": {"scale": 0.0}, + "frequencyPenalty": {"scale": 0.0}, + }, "cohere": { "prompt": "", "temperature": 0.0, @@ -214,6 +224,20 @@ BEDROCK_PROVIDER_REQUEST_BODY = { "logit_bias": {}, "truncate": "NONE", }, + "cohere-command-r": { + "message": [], + "chat_history": [], + "temperature": 0.0, + "p": 0.0, + "k": 0.0, + "max_tokens": 0, + "stop_sequences": [], + "return_likelihoods": "NONE", + "stream": False, + "num_generations": 0, + "logit_bias": {}, + "truncate": "NONE", + }, "anthropic": { "anthropic_version": "bedrock-2023-05-31", "max_tokens": 0, @@ -233,12 +257,20 @@ BEDROCK_PROVIDER_REQUEST_BODY = { BEDROCK_PROVIDER_RESPONSE_BODY = { "mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]}, "meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""}, - "ai21": { + "ai21-jamba": { "id": "", "prompt": {"text": "Hello World", "tokens": []}, - "completions": [ - {"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}} - ], + "choices": [{"message": {"content": "Hello World"}}], + }, + "ai21-jamba-stream": { + "id": "", + "prompt": {"text": "Hello World", "tokens": []}, + "choices": [{"delta": {"content": "Hello World"}}], + }, + "ai21-j2": { + "id": "", + "prompt": {"text": "Hello World", "tokens": []}, + "completions": [{"data": {"text": "Hello World"}, "finishReason": {"reason": "length", "length": 2}}], }, "cohere": { "generations": [ @@ -255,6 +287,21 @@ BEDROCK_PROVIDER_RESPONSE_BODY = { "id": "", "prompt": "", }, + "cohere-command-r": { + "generations": [ + { + "finish_reason": "", + "id": "", + "text": "Hello World", + "likelihood": 0.0, + "token_likelihoods": [{"token": 0.0}], + "is_finished": True, + "index": 0, + } + ], + "id": "", + "prompt": "", + }, "anthropic": { "id": "", "model": "", diff --git a/tests/metagpt/provider/test_bedrock_api.py b/tests/metagpt/provider/test_bedrock_api.py index 28d1d7008..1dc73d024 100644 --- a/tests/metagpt/provider/test_bedrock_api.py +++ b/tests/metagpt/provider/test_bedrock_api.py @@ -22,18 +22,42 @@ usage = { } -def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict: - provider = self.config.model.split(".")[0] +def get_provider_name(model: str) -> str: + arr = model.split(".") + if len(arr) == 2: + provider, model_name = arr # meta、mistral…… + elif len(arr) == 3: + # some model_ids may contain country like us.xx.xxx + _, provider, model_name = arr + return provider + + +def deal_special_provider(provider: str, model: str, stream: bool = False) -> str: + # for ai21 + if "j2-" in model: + provider = f"{provider}-j2" + elif "jamba-" in model: + provider = f"{provider}-jamba" + elif "command-r" in model: + provider = f"{provider}-command-r" + if stream and "ai21" in model: + provider = f"{provider}-stream" + return provider + + +async def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict: + provider = get_provider_name(self.config.model) self._update_costs(usage, self.config.model) + provider = deal_special_provider(provider, self.config.model) return BEDROCK_PROVIDER_RESPONSE_BODY[provider] -def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: +async def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: # use json object to mock EventStream def dict2bytes(x): return json.dumps(x).encode("utf-8") - provider = self.config.model.split(".")[0] + provider = get_provider_name(self.config.model) if provider == "amazon": response_body_bytes = dict2bytes({"outputText": "Hello World"}) @@ -44,6 +68,7 @@ def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: elif provider == "cohere": response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"}) else: + provider = deal_special_provider(provider, self.config.model, stream=True) response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider]) response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]} @@ -52,7 +77,8 @@ def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict: def get_bedrock_request_body(model_id) -> dict: - provider = model_id.split(".")[0] + provider = get_provider_name(model_id) + provider = deal_special_provider(provider, model_id) return BEDROCK_PROVIDER_REQUEST_BODY[provider]