change provider to private

usamimeri_renko 2024-04-25 20:58:50 +08:00
parent 6355c5c0ed
commit a6058ca629


@@ -16,7 +16,7 @@ class AmazonBedrockLLM(BaseLLM):
     def __init__(self, config: LLMConfig):
         self.config = config
         self.__client = self.__init_client("bedrock-runtime")
-        self.provider = get_provider(self.config.model)
+        self.__provider = get_provider(self.config.model)

     def __init_client(self, service_name: Literal["bedrock-runtime", "bedrock"]):
         # access key from https://us-east-1.console.aws.amazon.com/iam
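
An aside on the rename above: the double leading underscore makes the attribute private via Python name mangling, so __provider is stored as _AmazonBedrockLLM__provider and is no longer reachable as instance.provider from outside the class. A minimal standalone sketch (the Provider stub is hypothetical, not code from this repository):

# Minimal sketch of the name mangling this commit relies on.
# Provider is a hypothetical stand-in for whatever get_provider()
# returns; it is not the repository's actual class.
class Provider:
    pass

class AmazonBedrockLLM:
    def __init__(self):
        # Mangled to _AmazonBedrockLLM__provider at class-definition time.
        self.__provider = Provider()

    def has_provider(self):
        # Inside the class body the mangled name resolves normally.
        return self.__provider is not None

llm = AmazonBedrockLLM()
print(llm.has_provider())          # True
print(hasattr(llm, "provider"))    # False -- old public name is gone
print(hasattr(llm, "__provider"))  # False -- stored under the mangled name
print(hasattr(llm, "_AmazonBedrockLLM__provider"))  # True -- still reachable

Name mangling is a convention-level guard, not access control: the attribute stays reachable through the mangled name, but it no longer advertises itself as public API.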
@@ -33,8 +33,8 @@ class AmazonBedrockLLM(BaseLLM):
         client = self.__init_client("bedrock")
         # only output text-generation models
         response = client.list_foundation_models(byOutputModality='TEXT')
-        summaries = [f'{summary.get("modelId", ""):50} Support Streaming:{summary.get("responseStreamingSupported","")}'
-                     for summary in response.get("modelSummaries", {})]
+        summaries = [f'{summary["modelId"]:50} Support Streaming:{summary["responseStreamingSupported"]}'
+                     for summary in response["modelSummaries"]]
         logger.info("\n"+"\n".join(summaries))

     @property
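
On the hunk above: the rewrite drops the .get() fallbacks and indexes the response directly, so a summary missing either key would now raise KeyError instead of silently printing an empty field. A standalone sketch of the same listing call via boto3 (the region name and configured credentials are assumptions for illustration):

import boto3

# Hedged sketch mirroring the list_models logic above; assumes AWS
# credentials are configured and the region supports Bedrock.
client = boto3.client("bedrock", region_name="us-east-1")
response = client.list_foundation_models(byOutputModality="TEXT")
for summary in response["modelSummaries"]:
    # Direct indexing, as in the new code; swap back to summary.get(...)
    # if you need to tolerate summaries missing these keys.
    print(f'{summary["modelId"]:50} '
          f'Support Streaming:{summary["responseStreamingSupported"]}')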
@@ -45,12 +45,12 @@ class AmazonBedrockLLM(BaseLLM):
         }

     def completion(self, messages: list[dict]):
-        request_body = self.provider.get_request_body(
+        request_body = self.__provider.get_request_body(
             messages, **self._generate_kwargs)
         response = self.__client.invoke_model(
             modelId=self.config.model, body=request_body
         )
-        completions = self.provider.get_choice_text(response)
+        completions = self.__provider.get_choice_text(response)
         return completions

     def _chat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
@@ -59,7 +59,7 @@ class AmazonBedrockLLM(BaseLLM):
                 f"model {self.config.model} doesn't support streaming output!")
             return self.completion(messages)

-        request_body = self.provider.get_request_body(
+        request_body = self.__provider.get_request_body(
             messages, **self._generate_kwargs)
         response = self.__client.invoke_model_with_response_stream(
             modelId=self.config.model, body=request_body
@@ -67,7 +67,7 @@ class AmazonBedrockLLM(BaseLLM):
         collected_content = []
         for event in response["body"]:
-            chunk_text = self.provider.get_choice_text_from_stream(event)
+            chunk_text = self.__provider.get_choice_text_from_stream(event)
             collected_content.append(chunk_text)
             log_llm_stream(chunk_text)
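
For the streaming hunks: invoke_model_with_response_stream returns an event stream under response["body"], and each event wraps a JSON payload in event["chunk"]["bytes"]; the provider object decodes the model-specific field carrying the text delta. A hypothetical sketch of what get_choice_text_from_stream might look like for an Anthropic-style content_block_delta payload (the real per-provider implementations live elsewhere in the codebase):

import json

def get_choice_text_from_stream(event: dict) -> str:
    # Hypothetical provider method, assuming Anthropic-style streaming
    # chunks; other Bedrock model families use different field names.
    chunk = json.loads(event["chunk"]["bytes"])
    if chunk.get("type") == "content_block_delta":
        return chunk.get("delta", {}).get("text", "")
    return ""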