"""
|
|
Simple LLM service, performs text prompt completion using AWS Bedrock.
|
|
Input is prompt, output is response. Mistral is default.
|
|
"""
|
|
|
|
import boto3
import json
import os
import enum
import logging

from .... exceptions import TooManyRequests
from .... base import LlmService, LlmResult, LlmChunk

# Module logger
logger = logging.getLogger(__name__)

default_ident = "text-completion"

default_model = 'mistral.mistral-large-2407-v1:0'
default_temperature = 0.0
default_max_output = 2048
default_top_p = 0.99
default_top_k = 40

# Actually, these could all just be None, no need to get environment
# variables, as Boto3 would pick all these up if not passed in as args
default_access_key_id = os.getenv("AWS_ACCESS_KEY_ID", None)
default_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", None)
default_session_token = os.getenv("AWS_SESSION_TOKEN", None)
default_profile = os.getenv("AWS_PROFILE", None)
default_region = os.getenv("AWS_DEFAULT_REGION", None)

# Variant API handling depends on the model type

class ModelHandler:

    def __init__(self):
        self.temperature = default_temperature
        self.max_output = default_max_output
        self.top_p = default_top_p
        self.top_k = default_top_k

    def set_temperature(self, temperature):
        self.temperature = temperature

    def set_max_output(self, max_output):
        self.max_output = max_output

    def set_top_p(self, top_p):
        self.top_p = top_p

    def set_top_k(self, top_k):
        self.top_k = top_k

    def encode_request(self, system, prompt):
        raise RuntimeError("encode_request not implemented")

    def decode_response(self, response):
        raise RuntimeError("decode_response not implemented")

    def decode_stream_chunk(self, chunk):
        raise RuntimeError("decode_stream_chunk not implemented")

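# Each handler below encodes the request body and decodes the (streaming)
# response in the JSON shape its Bedrock model family expects
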
class Mistral(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.99
        self.top_k = 40

    def encode_request(self, system, prompt):
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        })

    def decode_response(self, response):
        response_body = json.loads(response.get("body").read())
        return response_body['outputs'][0]['text']

    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        if 'outputs' in chunk_obj and len(chunk_obj['outputs']) > 0:
            return chunk_obj['outputs'][0].get('text', '')
        return ''

# Llama 3
class Meta(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.95

    def encode_request(self, system, prompt):
        return json.dumps({
            "prompt": f"{system}\n\n{prompt}",
            "max_gen_len": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
        })

    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response["generation"]

    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        return chunk_obj.get('generation', '')

class Anthropic(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.999

    def encode_request(self, system, prompt):
        return json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{system}\n\n{prompt}",
                        }
                    ]
                }
            ]
        })

    def decode_response(self, response):
        model_response = json.loads(response["body"].read())
        return model_response['content'][0]['text']

    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        if chunk_obj.get('type') == 'content_block_delta':
            if 'delta' in chunk_obj and 'text' in chunk_obj['delta']:
                return chunk_obj['delta']['text']
        return ''

class Ai21(ModelHandler):

    def __init__(self):
        super().__init__()
        self.top_p = 0.9

    def encode_request(self, system, prompt):
        return json.dumps({
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "messages": [
                {
                    "role": "user",
                    "content": f"{system}\n\n{prompt}"
                }
            ]
        })

    def decode_response(self, response):
        content = response['body'].read()
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['choices'][0]['message']['content']

    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        if 'choices' in chunk_obj and len(chunk_obj['choices']) > 0:
            delta = chunk_obj['choices'][0].get('delta', {})
            return delta.get('content', '')
        return ''

class Cohere(ModelHandler):

    def encode_request(self, system, prompt):
        return json.dumps({
            "max_tokens": self.max_output,
            "temperature": self.temperature,
            "message": f"{system}\n\n{prompt}",
        })

    def decode_response(self, response):
        content = response['body'].read()
        content_str = content.decode('utf-8')
        content_json = json.loads(content_str)
        return content_json['text']

    def decode_stream_chunk(self, chunk):
        chunk_obj = json.loads(chunk.get('chunk').get('bytes').decode())
        return chunk_obj.get('text', '')

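# Model identifiers not matched by determine_variant fall back to the
# Mistral request format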
Default = Mistral

class Processor(LlmService):

    def __init__(self, **params):

        logger.debug(f"Bedrock LLM initialized with params: {params}")

        model = params.get("model", default_model)
        temperature = params.get("temperature", default_temperature)
        max_output = params.get("max_output", default_max_output)

        aws_access_key_id = params.get(
            "aws_access_key_id", default_access_key_id
        )

        aws_secret_access_key = params.get(
            "aws_secret_access_key", default_secret_access_key
        )

        aws_session_token = params.get(
            "aws_session_token", default_session_token
        )

        aws_region = params.get(
            "aws_region", default_region
        )

        aws_profile = params.get(
            "aws_profile", default_profile
        )

        super(Processor, self).__init__(
            **params | {
                "model": model,
                "temperature": temperature,
                "max_output": max_output,
            }
        )

        # Store default configuration
        self.default_model = model
        self.temperature = temperature
        self.max_output = max_output

        # Cache for model variants to avoid re-initialization
        self.model_variants = {}

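        # Credential arguments left as None fall back to Boto3's standard
        # credential resolution (environment, shared config/profile, or
        # instance role)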
        self.session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            profile_name=aws_profile,
            region_name=aws_region,
        )

        self.bedrock = self.session.client(service_name='bedrock-runtime')

        logger.info("Bedrock LLM service initialized")

    def determine_variant(self, model):

        # FIXME: Missing: Amazon models, Deepseek

        # This set of conditions deals with normal Bedrock on-demand usage
        if model.startswith("mistral"):
            return Mistral
        elif model.startswith("meta"):
            return Meta
        elif model.startswith("anthropic"):
            return Anthropic
        elif model.startswith("ai21"):
            return Ai21
        elif model.startswith("cohere"):
            return Cohere

        # The inference profiles
        if model.startswith("us.meta"):
            return Meta
        elif model.startswith("us.anthropic"):
            return Anthropic
        elif model.startswith("eu.meta"):
            return Meta
        elif model.startswith("eu.anthropic"):
            return Anthropic

        return Default

    def _get_or_create_variant(self, model_name, temperature=None):
        """Get cached model variant or create a new one"""
        # Use provided temperature or fall back to default
        effective_temperature = temperature if temperature is not None else self.temperature

        # Create a cache key that includes temperature to avoid conflicts
        cache_key = f"{model_name}:{effective_temperature}"

        if cache_key not in self.model_variants:
            logger.info(f"Creating model variant for '{model_name}' with temperature {effective_temperature}")
            variant_class = self.determine_variant(model_name)
            variant = variant_class()
            variant.set_temperature(effective_temperature)
            variant.set_max_output(self.max_output)
            self.model_variants[cache_key] = variant

        return self.model_variants[cache_key]

    async def generate_content(self, system, prompt, model=None, temperature=None):

        # Use provided model or fall back to default
        model_name = model or self.default_model
        # Use provided temperature or fall back to default
        effective_temperature = temperature if temperature is not None else self.temperature

        logger.debug(f"Using model: {model_name}")
        logger.debug(f"Using temperature: {effective_temperature}")

        try:
            # Get the appropriate variant for this model
            variant = self._get_or_create_variant(model_name, effective_temperature)

            promptbody = variant.encode_request(system, prompt)

            accept = 'application/json'
            contentType = 'application/json'

            response = self.bedrock.invoke_model(
                body=promptbody,
                modelId=model_name,
                accept=accept,
                contentType=contentType
            )

            # Response structure decode
            outputtext = variant.decode_response(response)

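            # Bedrock reports token usage in the HTTP response headers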
            metadata = response['ResponseMetadata']['HTTPHeaders']
            inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
            outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])

            logger.debug(f"LLM output: {outputtext}")
            logger.info(f"Input Tokens: {inputtokens}")
            logger.info(f"Output Tokens: {outputtokens}")

            resp = LlmResult(
                text = outputtext,
                in_token = inputtokens,
                out_token = outputtokens,
                model = model_name
            )

            return resp

        except self.bedrock.exceptions.ThrottlingException as e:

            logger.warning(f"Hit rate limit: {e}")

            # Leave rate limit retries to the base handler
            raise TooManyRequests()

        except Exception as e:

            # Apart from rate limits, treat all exceptions as unrecoverable

            logger.error(f"Bedrock LLM exception ({type(e).__name__}): {e}", exc_info=True)
            raise e

    def supports_streaming(self):
        """Bedrock supports streaming"""
        return True

    async def generate_content_stream(self, system, prompt, model=None, temperature=None):
        """Stream content generation from Bedrock"""
        model_name = model or self.default_model
        effective_temperature = temperature if temperature is not None else self.temperature

        logger.debug(f"Using model (streaming): {model_name}")
        logger.debug(f"Using temperature: {effective_temperature}")

        try:
            variant = self._get_or_create_variant(model_name, effective_temperature)
            promptbody = variant.encode_request(system, prompt)

            accept = 'application/json'
            contentType = 'application/json'

            response = self.bedrock.invoke_model_with_response_stream(
                body=promptbody,
                modelId=model_name,
                accept=accept,
                contentType=contentType
            )

            total_input_tokens = 0
            total_output_tokens = 0

            stream = response.get('body')

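            # Each stream event may carry a 'chunk' with model output bytes
            # and, for some models, a 'metadata' section with token usage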
            if stream:
                for event in stream:
                    chunk = event.get('chunk')
                    if chunk:
                        # Decode the chunk text
                        text = variant.decode_stream_chunk(event)
                        if text:
                            yield LlmChunk(
                                text=text,
                                in_token=None,
                                out_token=None,
                                model=model_name,
                                is_final=False
                            )

                    # Try to extract metadata from the event
                    metadata = event.get('metadata')
                    if metadata:
                        usage = metadata.get('usage')
                        if usage:
                            total_input_tokens = usage.get('inputTokens', 0)
                            total_output_tokens = usage.get('outputTokens', 0)

            # Send final chunk with token counts
            yield LlmChunk(
                text="",
                in_token=total_input_tokens,
                out_token=total_output_tokens,
                model=model_name,
                is_final=True
            )

            logger.debug("Streaming complete")

        except self.bedrock.exceptions.ThrottlingException as e:
            logger.warning(f"Hit rate limit during streaming: {e}")
            raise TooManyRequests()

        except Exception as e:
            logger.error(f"Bedrock streaming exception ({type(e).__name__}): {e}", exc_info=True)
            raise e

    @staticmethod
    def add_args(parser):

        LlmService.add_args(parser)

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'Bedrock model (default: {default_model})'
        )

        parser.add_argument(
            '-z', '--aws-access-key-id',
            default=default_access_key_id,
            help='AWS access key ID'
        )

        parser.add_argument(
            '-k', '--aws-secret-access-key',
            default=default_secret_access_key,
            help='AWS secret access key'
        )

        parser.add_argument(
            '-r', '--aws-region',
            default=default_region,
            help='AWS region'
        )

        parser.add_argument(
            '--aws-profile', '--profile',
            default=default_profile,
            help='AWS profile name'
        )

        parser.add_argument(
            '-t', '--temperature',
            type=float,
            default=default_temperature,
            help=f'LLM temperature parameter (default: {default_temperature})'
        )

        parser.add_argument(
            '-x', '--max-output',
            type=int,
            default=default_max_output,
            help=f'LLM max output tokens (default: {default_max_output})'
        )

def run():

    Processor.launch(default_ident, __doc__)

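# Entry-point sketch: allows the module to be executed directly. This is an
# assumption about deployment; the packaged launcher may invoke run() itself.
if __name__ == '__main__':
    run()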