From aded1dc2ed3000ae5e0be39344935c01f17d8781 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 25 Apr 2024 21:23:25 +0800 Subject: [PATCH] add some type hint --- metagpt/provider/bedrock/amazon_bedrock_api.py | 10 +++------- metagpt/provider/bedrock/base_provider.py | 6 +++--- metagpt/provider/bedrock/bedrock_provider.py | 8 +++++--- metagpt/provider/bedrock/utils.py | 1 - 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/metagpt/provider/bedrock/amazon_bedrock_api.py b/metagpt/provider/bedrock/amazon_bedrock_api.py index 3b2ea0b81..2a72de019 100644 --- a/metagpt/provider/bedrock/amazon_bedrock_api.py +++ b/metagpt/provider/bedrock/amazon_bedrock_api.py @@ -1,5 +1,3 @@ - -import json from typing import Literal from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.provider.llm_provider_registry import register_provider @@ -38,13 +36,13 @@ class AmazonBedrockLLM(BaseLLM): logger.info("\n"+"\n".join(summaries)) @property - def _generate_kwargs(self): + def _generate_kwargs(self) -> dict: # for now only use temperature due to the difference of request body return { "temperature": self.config.get("temperature", 0.1), } - def completion(self, messages: list[dict]): + def completion(self, messages: list[dict]) -> str: request_body = self.__provider.get_request_body( messages, **self._generate_kwargs) response = self.__client.invoke_model( @@ -53,7 +51,7 @@ class AmazonBedrockLLM(BaseLLM): completions = self.__provider.get_choice_text(response) return completions - def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): + def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str: if self.config.model in NOT_SUUPORT_STREAM_MODELS: logger.warning( f"model {self.config.model} doesn't support streaming output!") @@ -84,5 +82,3 @@ class AmazonBedrockLLM(BaseLLM): async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT): return self._chat_completion_stream(messages) - - diff --git a/metagpt/provider/bedrock/base_provider.py b/metagpt/provider/bedrock/base_provider.py index c591549ce..724cbb669 100644 --- a/metagpt/provider/bedrock/base_provider.py +++ b/metagpt/provider/bedrock/base_provider.py @@ -9,7 +9,7 @@ class BaseBedrockProvider(ABC): def _get_completion_from_dict(self, rsp_dict: dict) -> str: ... - def get_request_body(self, messages, **generate_kwargs): + def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( {"prompt": self.messages_to_prompt(messages), **generate_kwargs}) return body @@ -19,7 +19,7 @@ class BaseBedrockProvider(ABC): completions = self._get_completion_from_dict(response_body) return completions - def get_choice_text_from_stream(self, event): + def get_choice_text_from_stream(self, event) -> str: rsp_dict = json.loads(event["chunk"]["bytes"]) completions = self._get_completion_from_dict(rsp_dict) return completions @@ -28,6 +28,6 @@ class BaseBedrockProvider(ABC): response_body = json.loads(response["body"].read()) return response_body - def messages_to_prompt(self, messages: list[dict]): + def messages_to_prompt(self, messages: list[dict]) -> str: """[{"role": "user", "content": msg}] to user: etc.""" return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) diff --git a/metagpt/provider/bedrock/bedrock_provider.py b/metagpt/provider/bedrock/bedrock_provider.py index e2dba9223..47348c083 100644 --- a/metagpt/provider/bedrock/bedrock_provider.py +++ b/metagpt/provider/bedrock/bedrock_provider.py @@ -16,7 +16,7 @@ class MistralProvider(BaseBedrockProvider): class AnthropicProvider(BaseBedrockProvider): # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - def get_request_body(self, messages, **generate_kwargs): + def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps( {"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs}) return body @@ -50,7 +50,9 @@ class Ai21Provider(BaseBedrockProvider): class AmazonProvider(BaseBedrockProvider): - def get_request_body(self, messages, **generate_kwargs): + # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html + + def get_request_body(self, messages: list[dict], **generate_kwargs): body = json.dumps({ "inputText": self.messages_to_prompt(messages), "textGenerationConfig": generate_kwargs @@ -60,7 +62,7 @@ class AmazonProvider(BaseBedrockProvider): def _get_completion_from_dict(self, rsp_dict: dict) -> str: return rsp_dict['results'][0]['outputText'].strip() - def get_choice_text_from_stream(self, event): + def get_choice_text_from_stream(self, event) -> str: rsp_dict = json.loads(event["chunk"]["bytes"]) completions = rsp_dict["outputText"] return completions diff --git a/metagpt/provider/bedrock/utils.py b/metagpt/provider/bedrock/utils.py index 57b83681c..58236157d 100644 --- a/metagpt/provider/bedrock/utils.py +++ b/metagpt/provider/bedrock/utils.py @@ -1,6 +1,5 @@ from metagpt.logs import logger - def messages_to_prompt_llama(messages: list[dict]): BOS, EOS = "", "" B_INST, E_INST = "[INST]", "[/INST]"