add generate_kwargs

This commit is contained in:
usamimeri_renko 2024-04-25 14:31:22 +08:00
parent ec7df8acdf
commit 4d1fb20785
3 changed files with 20 additions and 10 deletions

View file

@ -41,7 +41,6 @@ class AmazonBedrockLLM(BaseLLM):
@property
def _generate_kwargs(self):
return {
"max_token": self.config.get("max_token", 1024),
"temperature": self.config.get("temperature", 0.3),
"top_p": self.config.get("top_p", 0.95),
"top_k": self.config.get("top_k", 1),

View file

@ -3,8 +3,8 @@ import json
class BaseBedrockProvider(object):
# to handle different generation kwargs
def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
return json.dumps({"prompt": self.messages_to_prompt(messages)})
def get_request_body(self, messages, **generate_kwargs):
return json.dumps({"prompt": self.messages_to_prompt(messages)} | generate_kwargs)
def get_choice_text(self, response) -> str:
response_body = json.loads(response["body"].read())

View file

@ -8,13 +8,24 @@ class MistralProvider(BaseBedrockProvider):
# for mixtral and llama
return f"<s>[INST]{prompt}[/INST]"
def get_request_body(self, messages, max_token=None, temperature=None, top_p=None, top_k=None, **kwargs):
return json.dumps({
"prompt": self.format_prompt(self.messages_to_prompt(messages)),
"max_tokens": max_token,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k, })
def get_request_body(self, messages, **generate_kwargs):
return json.dumps({"prompt": self.format_prompt(self.messages_to_prompt(messages))} | generate_kwargs)
class AnthropicProvider(BaseBedrockProvider):
pass
class CohereProvider(BaseBedrockProvider):
pass
class MetaProvider(BaseBedrockProvider):
pass
class Ai21Provider(BaseBedrockProvider):
pass
PROVIDERS = {