mirror of https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-15 10:32:37 +02:00
release/v1.4 -> master (#548)
parent 3ec2cd54f9
commit 2bd68ed7f4
94 changed files with 8571 additions and 1740 deletions
@@ -183,13 +183,13 @@ class Processor(LlmService):
             }
         )
 
-        self.model = model
+        # Store default configuration
+        self.default_model = model
         self.temperature = temperature
         self.max_output = max_output
 
-        self.variant = self.determine_variant(self.model)()
-        self.variant.set_temperature(temperature)
-        self.variant.set_max_output(max_output)
+        # Cache for model variants to avoid re-initialization
+        self.model_variants = {}
 
         self.session = boto3.Session(
             aws_access_key_id=aws_access_key_id,
@@ -208,47 +208,75 @@ class Processor(LlmService):
         # FIXME: Missing, Amazon models, Deepseek
 
         # This set of conditions deals with normal bedrock on-demand usage
-        if self.model.startswith("mistral"):
+        if model.startswith("mistral"):
             return Mistral
-        elif self.model.startswith("meta"):
+        elif model.startswith("meta"):
             return Meta
-        elif self.model.startswith("anthropic"):
+        elif model.startswith("anthropic"):
             return Anthropic
-        elif self.model.startswith("ai21"):
+        elif model.startswith("ai21"):
             return Ai21
-        elif self.model.startswith("cohere"):
+        elif model.startswith("cohere"):
             return Cohere
 
         # The inference profiles
-        if self.model.startswith("us.meta"):
+        if model.startswith("us.meta"):
             return Meta
-        elif self.model.startswith("us.anthropic"):
+        elif model.startswith("us.anthropic"):
             return Anthropic
-        elif self.model.startswith("eu.meta"):
+        elif model.startswith("eu.meta"):
             return Meta
-        elif self.model.startswith("eu.anthropic"):
+        elif model.startswith("eu.anthropic"):
             return Anthropic
 
         return Default
 
-    async def generate_content(self, system, prompt):
+    def _get_or_create_variant(self, model_name, temperature=None):
+        """Get cached model variant or create new one"""
+        # Use provided temperature or fall back to default
+        effective_temperature = temperature if temperature is not None else self.temperature
+
+        # Create a cache key that includes temperature to avoid conflicts
+        cache_key = f"{model_name}:{effective_temperature}"
+
+        if cache_key not in self.model_variants:
+            logger.info(f"Creating model variant for '{model_name}' with temperature {effective_temperature}")
+            variant_class = self.determine_variant(model_name)
+            variant = variant_class()
+            variant.set_temperature(effective_temperature)
+            variant.set_max_output(self.max_output)
+            self.model_variants[cache_key] = variant
+
+        return self.model_variants[cache_key]
+
+    async def generate_content(self, system, prompt, model=None, temperature=None):
 
+        # Use provided model or fall back to default
+        model_name = model or self.default_model
+        # Use provided temperature or fall back to default
+        effective_temperature = temperature if temperature is not None else self.temperature
+
+        logger.debug(f"Using model: {model_name}")
+        logger.debug(f"Using temperature: {effective_temperature}")
+
         try:
+            # Get the appropriate variant for this model
+            variant = self._get_or_create_variant(model_name, effective_temperature)
 
-            promptbody = self.variant.encode_request(system, prompt)
+            promptbody = variant.encode_request(system, prompt)
 
            accept = 'application/json'
            contentType = 'application/json'
 
            response = self.bedrock.invoke_model(
                body=promptbody,
-                modelId=self.model,
+                modelId=model_name,
                accept=accept,
                contentType=contentType
            )
 
            # Response structure decode
-            outputtext = self.variant.decode_response(response)
+            outputtext = variant.decode_response(response)
 
            metadata = response['ResponseMetadata']['HTTPHeaders']
            inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
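One detail in the new generate_content is worth noting: the model falls back with `or`, while the temperature uses an explicit `is not None` check. The difference matters because a temperature of 0.0 is falsy in Python, so an `or` fallback would silently replace a deliberate 0.0 with the default. A minimal standalone illustration (not TrustGraph code):

default_temperature = 0.7

def effective(requested):
    # Wrong: `requested or default` treats a deliberate 0.0 as "unset"
    broken = requested or default_temperature
    # Right: only substitute the default when no value was supplied
    correct = requested if requested is not None else default_temperature
    return broken, correct

print(effective(0.0))   # (0.7, 0.0) - `or` loses the caller's 0.0
print(effective(None))  # (0.7, 0.7) - both agree when nothing was passed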
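The cache in _get_or_create_variant keys on both model name and temperature, so two requests that share a model but differ in sampling settings never share a configured variant instance. A standalone sketch of the same pattern (the Variant class and model id below are illustrative, not the repository's actual variant classes):

class Variant:
    """Stand-in for a Bedrock request encoder/decoder variant."""
    def __init__(self):
        self.temperature = None

    def set_temperature(self, temperature):
        self.temperature = temperature

variants = {}

def get_or_create(model_name, temperature):
    # Temperature is part of the key, so 0.2 and 0.7 get distinct instances
    key = f"{model_name}:{temperature}"
    if key not in variants:
        v = Variant()
        v.set_temperature(temperature)
        variants[key] = v
    return variants[key]

a = get_or_create("mistral.mixtral-8x7b", 0.2)
b = get_or_create("mistral.mixtral-8x7b", 0.2)
c = get_or_create("mistral.mixtral-8x7b", 0.7)
assert a is b      # same key: instance reused
assert a is not c  # different temperature: separate instance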
@@ -262,7 +290,7 @@ class Processor(LlmService):
                 text = outputtext,
                 in_token = inputtokens,
                 out_token = outputtokens,
-                model = self.model
+                model = model_name
             )
 
             return resp
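Taken together, callers can now override the model and temperature per request, with the constructor values acting as defaults. An illustrative call, assuming an already-constructed Processor instance named processor and an example inference-profile id (neither is taken from this commit):

async def main(processor):
    # Per-request override; omitting model/temperature falls back to
    # the defaults configured when the Processor was constructed.
    resp = await processor.generate_content(
        system="You are a concise assistant.",
        prompt="Summarise this document in one paragraph.",
        model="eu.anthropic.claude-3-haiku-20240307-v1:0",  # example profile id
        temperature=0.3,
    )

    # No overrides: uses self.default_model and self.temperature
    resp = await processor.generate_content(
        system="You are a concise assistant.",
        prompt="And now in one sentence.",
    )
    return resp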