add some type hint

This commit is contained in:
usamimeri_renko 2024-04-25 21:23:25 +08:00
parent f45a379183
commit aded1dc2ed
4 changed files with 11 additions and 14 deletions

View file

@ -1,5 +1,3 @@
import json
from typing import Literal
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.provider.llm_provider_registry import register_provider
@ -38,13 +36,13 @@ class AmazonBedrockLLM(BaseLLM):
logger.info("\n"+"\n".join(summaries))
@property
def _generate_kwargs(self):
def _generate_kwargs(self) -> dict:
# for now only use temperature due to the difference of request body
return {
"temperature": self.config.get("temperature", 0.1),
}
def completion(self, messages: list[dict]):
def completion(self, messages: list[dict]) -> str:
request_body = self.__provider.get_request_body(
messages, **self._generate_kwargs)
response = self.__client.invoke_model(
@ -53,7 +51,7 @@ class AmazonBedrockLLM(BaseLLM):
completions = self.__provider.get_choice_text(response)
return completions
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
logger.warning(
f"model {self.config.model} doesn't support streaming output!")
@ -84,5 +82,3 @@ class AmazonBedrockLLM(BaseLLM):
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
return self._chat_completion_stream(messages)

View file

@ -9,7 +9,7 @@ class BaseBedrockProvider(ABC):
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
...
def get_request_body(self, messages, **generate_kwargs):
def get_request_body(self, messages: list[dict], **generate_kwargs):
body = json.dumps(
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
return body
@ -19,7 +19,7 @@ class BaseBedrockProvider(ABC):
completions = self._get_completion_from_dict(response_body)
return completions
def get_choice_text_from_stream(self, event):
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = self._get_completion_from_dict(rsp_dict)
return completions
@ -28,6 +28,6 @@ class BaseBedrockProvider(ABC):
response_body = json.loads(response["body"].read())
return response_body
def messages_to_prompt(self, messages: list[dict]):
def messages_to_prompt(self, messages: list[dict]) -> str:
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])

View file

@ -16,7 +16,7 @@ class MistralProvider(BaseBedrockProvider):
class AnthropicProvider(BaseBedrockProvider):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
def get_request_body(self, messages, **generate_kwargs):
def get_request_body(self, messages: list[dict], **generate_kwargs):
body = json.dumps(
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
return body
@ -50,7 +50,9 @@ class Ai21Provider(BaseBedrockProvider):
class AmazonProvider(BaseBedrockProvider):
def get_request_body(self, messages, **generate_kwargs):
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
def get_request_body(self, messages: list[dict], **generate_kwargs):
body = json.dumps({
"inputText": self.messages_to_prompt(messages),
"textGenerationConfig": generate_kwargs
@ -60,7 +62,7 @@ class AmazonProvider(BaseBedrockProvider):
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
return rsp_dict['results'][0]['outputText'].strip()
def get_choice_text_from_stream(self, event):
def get_choice_text_from_stream(self, event) -> str:
rsp_dict = json.loads(event["chunk"]["bytes"])
completions = rsp_dict["outputText"]
return completions

View file

@ -1,6 +1,5 @@
from metagpt.logs import logger
def messages_to_prompt_llama(messages: list[dict]):
BOS, EOS = "<s>", "</s>"
B_INST, E_INST = "[INST]", "[/INST]"