mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
add some type hint
This commit is contained in:
parent
f45a379183
commit
aded1dc2ed
4 changed files with 11 additions and 14 deletions
|
|
@ -1,5 +1,3 @@
|
|||
|
||||
import json
|
||||
from typing import Literal
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
|
@ -38,13 +36,13 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
logger.info("\n"+"\n".join(summaries))
|
||||
|
||||
@property
|
||||
def _generate_kwargs(self):
|
||||
def _generate_kwargs(self) -> dict:
|
||||
# for now only use temperature due to the difference of request body
|
||||
return {
|
||||
"temperature": self.config.get("temperature", 0.1),
|
||||
}
|
||||
|
||||
def completion(self, messages: list[dict]):
|
||||
def completion(self, messages: list[dict]) -> str:
|
||||
request_body = self.__provider.get_request_body(
|
||||
messages, **self._generate_kwargs)
|
||||
response = self.__client.invoke_model(
|
||||
|
|
@ -53,7 +51,7 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
completions = self.__provider.get_choice_text(response)
|
||||
return completions
|
||||
|
||||
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
|
||||
if self.config.model in NOT_SUUPORT_STREAM_MODELS:
|
||||
logger.warning(
|
||||
f"model {self.config.model} doesn't support streaming output!")
|
||||
|
|
@ -84,5 +82,3 @@ class AmazonBedrockLLM(BaseLLM):
|
|||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return self._chat_completion_stream(messages)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ class BaseBedrockProvider(ABC):
|
|||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
...
|
||||
|
||||
def get_request_body(self, messages, **generate_kwargs):
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
body = json.dumps(
|
||||
{"prompt": self.messages_to_prompt(messages), **generate_kwargs})
|
||||
return body
|
||||
|
|
@ -19,7 +19,7 @@ class BaseBedrockProvider(ABC):
|
|||
completions = self._get_completion_from_dict(response_body)
|
||||
return completions
|
||||
|
||||
def get_choice_text_from_stream(self, event):
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = self._get_completion_from_dict(rsp_dict)
|
||||
return completions
|
||||
|
|
@ -28,6 +28,6 @@ class BaseBedrockProvider(ABC):
|
|||
response_body = json.loads(response["body"].read())
|
||||
return response_body
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
def messages_to_prompt(self, messages: list[dict]) -> str:
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i['role']}: {i['content']}" for i in messages])
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ class MistralProvider(BaseBedrockProvider):
|
|||
class AnthropicProvider(BaseBedrockProvider):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
|
||||
def get_request_body(self, messages, **generate_kwargs):
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
body = json.dumps(
|
||||
{"messages": messages, "anthropic_version": "bedrock-2023-05-31", **generate_kwargs})
|
||||
return body
|
||||
|
|
@ -50,7 +50,9 @@ class Ai21Provider(BaseBedrockProvider):
|
|||
|
||||
|
||||
class AmazonProvider(BaseBedrockProvider):
|
||||
def get_request_body(self, messages, **generate_kwargs):
|
||||
# See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
|
||||
|
||||
def get_request_body(self, messages: list[dict], **generate_kwargs):
|
||||
body = json.dumps({
|
||||
"inputText": self.messages_to_prompt(messages),
|
||||
"textGenerationConfig": generate_kwargs
|
||||
|
|
@ -60,7 +62,7 @@ class AmazonProvider(BaseBedrockProvider):
|
|||
def _get_completion_from_dict(self, rsp_dict: dict) -> str:
|
||||
return rsp_dict['results'][0]['outputText'].strip()
|
||||
|
||||
def get_choice_text_from_stream(self, event):
|
||||
def get_choice_text_from_stream(self, event) -> str:
|
||||
rsp_dict = json.loads(event["chunk"]["bytes"])
|
||||
completions = rsp_dict["outputText"]
|
||||
return completions
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def messages_to_prompt_llama(messages: list[dict]):
|
||||
BOS, EOS = "<s>", "</s>"
|
||||
B_INST, E_INST = "[INST]", "[/INST]"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue