Mirror of https://github.com/trustgraph-ai/trustgraph.git
Fix AWS bedrock issues with newer model invocation (#572)
- Fixed models so that global.* models work
- Fixed Claude 4.5 & 4.7 invocation by removing top_p/top_k params
parent 72cb1c98e0
commit c808d26b0b
1 changed file with 11 additions and 38 deletions
@@ -21,8 +21,6 @@ default_ident = "text-completion"
 default_model = 'mistral.mistral-large-2407-v1:0'
 default_temperature = 0.0
 default_max_output = 2048
-default_top_p = 0.99
-default_top_k = 40
 
 # Actually, these could all just be None, no need to get environment
 # variables, as Boto3 would pick all these up if not passed in as args
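
For context, these module defaults feed boto3's Bedrock runtime client. A minimal sketch of how a handler's encoded request body is typically sent with invoke_model (the client setup and helper below are illustrative, not part of this diff):

```python
import json
import boto3

# Assumes AWS credentials and region are configured in the environment;
# Boto3 picks these up itself, as the comment in the diff notes.
client = boto3.client("bedrock-runtime")

def invoke(model_id, body):
    # body is the JSON string produced by a handler's encode_request()
    response = client.invoke_model(
        modelId=model_id,
        body=body,
        contentType="application/json",
        accept="application/json",
    )
    return json.loads(response["body"].read())
```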
@@ -38,16 +36,10 @@ class ModelHandler:
     def __init__(self):
         self.temperature = default_temperature
         self.max_output = default_max_output
-        self.top_p = default_top_p
-        self.top_k = default_top_k
 
     def set_temperature(self, temperature):
         self.temperature = temperature
 
     def set_max_output(self, max_output):
         self.max_output = max_output
 
-    def set_top_p(self, top_p):
-        self.top_p = top_p
-
-    def set_top_k(self, top_k):
-        self.top_k = top_k
-
     def encode_request(self, system, prompt):
         raise RuntimeError("format_request not implemented")
 
     def decode_response(self, response):
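
With top_p/top_k gone from the base class, only the subclasses that still use those parameters define them. A hypothetical alternative is a "send only what is set" encoder; this sketch is not what the commit does (the commit hard-codes per-model values in each subclass __init__ instead), but it illustrates the same goal of never sending parameters a model rejects:

```python
import json

class OptionalSamplingHandler(ModelHandler):
    # Hypothetical: include sampling params only when a subclass sets
    # them, so models that reject top_p/top_k (newer Claude versions,
    # per the commit message) never receive them.
    def encode_request(self, system, prompt):
        body = {
            "prompt": f"{system}\n\n{prompt}",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
        }
        if getattr(self, "top_p", None) is not None:
            body["top_p"] = self.top_p
        if getattr(self, "top_k", None) is not None:
            body["top_k"] = self.top_k
        return json.dumps(body)
```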
@@ -57,15 +49,12 @@ class ModelHandler:
 class Mistral(ModelHandler):
     def __init__(self):
+        self.top_p = 0.99
+        self.top_k = 40
         pass
 
     def encode_request(self, system, prompt):
         return json.dumps({
             "prompt": f"{system}\n\n{prompt}",
             "max_tokens": self.max_output,
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
         })
 
     def decode_response(self, response):
         response_body = json.loads(response.get("body").read())
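
The decode_response body is truncated above. Mistral models on Bedrock return completions under an outputs list, so the remainder of that method plausibly reads as follows (a sketch assuming the standard Mistral-on-Bedrock response shape, not taken from the diff):

```python
def decode_response(self, response):
    # Bedrock returns a streaming body; Mistral places the completion
    # under outputs[0].text.
    response_body = json.loads(response.get("body").read())
    return response_body["outputs"][0]["text"]
```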
@@ -79,13 +68,12 @@ class Mistral(ModelHandler):
 # Llama 3
 class Meta(ModelHandler):
     def __init__(self):
+        self.top_p = 0.95
         pass
 
     def encode_request(self, system, prompt):
         return json.dumps({
             "prompt": f"{system}\n\n{prompt}",
             "max_gen_len": self.max_output,
             "temperature": self.temperature,
             "top_p": self.top_p,
         })
 
     def decode_response(self, response):
         model_response = json.loads(response["body"].read())
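
Likewise for the truncated Meta handler: Llama models on Bedrock return the completion under a generation key, so the decode step is presumably along these lines (a sketch, not taken from the diff):

```python
def decode_response(self, response):
    # Llama-on-Bedrock responses carry the completion in "generation".
    model_response = json.loads(response["body"].read())
    return model_response["generation"]
```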
@@ -96,13 +84,12 @@ class Meta(ModelHandler):
 class Anthropic(ModelHandler):
     def __init__(self):
+        self.top_p = 0.999
         pass
 
     def encode_request(self, system, prompt):
         return json.dumps({
             "anthropic_version": "bedrock-2023-05-31",
             "max_tokens": self.max_output,
             "temperature": self.temperature,
             "top_p": self.top_p,
             "messages": [
                 {
                     "role": "user",
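
The Anthropic payload is cut off above at the start of the messages list. A sketch of the complete request/response pair under the Bedrock messages API (folding the system text into the user turn and the decode body are assumptions, based on how the other handlers in this file work):

```python
import json

def encode_request(self, system, prompt):
    return json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": self.max_output,
        "temperature": self.temperature,
        "top_p": self.top_p,
        "messages": [
            # Assumption: system text is concatenated into the user
            # turn, matching the other handlers' f"{system}\n\n{prompt}".
            {"role": "user", "content": f"{system}\n\n{prompt}"},
        ],
    })

def decode_response(self, response):
    # Messages-API responses carry the text in content[0].text.
    body = json.loads(response["body"].read())
    return body["content"][0]["text"]
```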
@@ -127,12 +114,11 @@ class Anthropic(ModelHandler):
 class Ai21(ModelHandler):
     def __init__(self):
+        self.top_p = 0.9
         pass
 
     def encode_request(self, system, prompt):
         return json.dumps({
             "max_tokens": self.max_output,
             "temperature": self.temperature,
             "top_p": self.top_p,
             "messages": [
                 {
                     "role": "user",
@@ -230,30 +216,17 @@ class Processor(LlmService):
     def determine_variant(self, model):
 
         # FIXME: Missing, Amazon models, Deepseek
 
-        # This set of conditions deals with normal bedrock on-demand usage
-        if model.startswith("mistral"):
+        if ".anthropic." in model or model.startswith("anthropic"):
+            return Anthropic
+        elif ".meta." in model or model.startswith("meta"):
+            return Meta
+        elif ".mistral." in model or model.startswith("mistral"):
             return Mistral
-        elif model.startswith("meta"):
-            return Meta
-        elif model.startswith("anthropic"):
-            return Anthropic
-        elif model.startswith("ai21"):
+        elif ".ai21." in model or model.startswith("ai21"):
             return Ai21
-        elif model.startswith("cohere"):
+        elif ".cohere." in model or model.startswith("cohere"):
             return Cohere
-
-        # The inference profiles
-        if model.startswith("us.meta"):
-            return Meta
-        elif model.startswith("us.anthropic"):
-            return Anthropic
-        elif model.startswith("eu.meta"):
-            return Meta
-        elif model.startswith("eu.anthropic"):
-            return Anthropic
-
         return Default
 
     def _get_or_create_variant(self, model_name, temperature=None):
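
The substring test is what lets regional and global inference-profile IDs resolve to the right handler without enumerating every prefix; this is the change that makes global.* models work. A quick illustrative check of the new routing logic (the model IDs are schematic examples of the naming scheme, not guaranteed to be valid Bedrock IDs):

```python
variants = {
    "anthropic.claude-v2": "Anthropic",               # on-demand ID
    "us.anthropic.claude-v2": "Anthropic",            # US inference profile
    "global.anthropic.claude-v2": "Anthropic",        # global inference profile
    "eu.meta.llama3-70b-instruct-v1:0": "Meta",       # EU inference profile
    "mistral.mistral-large-2407-v1:0": "Mistral",     # on-demand ID
}

for model, expected in variants.items():
    if ".anthropic." in model or model.startswith("anthropic"):
        got = "Anthropic"
    elif ".meta." in model or model.startswith("meta"):
        got = "Meta"
    elif ".mistral." in model or model.startswith("mistral"):
        got = "Mistral"
    else:
        got = "Default"
    assert got == expected, (model, got, expected)
```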