plano/model_server/src/core/guardrails.py
2024-12-10 19:31:02 -08:00

170 lines
5.4 KiB
Python

import time
import torch
import numpy as np
import src.commons.utils as utils
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from src.core.model_utils import GuardRequest, GuardResponse
class ArchGuardHanlder:
def __init__(self, model_dict):
"""
Initializes the ArchGuardHanlder with the given model dictionary.
Args:
model_dict (dict): A dictionary containing the model, tokenizer, and device information.
"""
self.model = model_dict["model"]
self.model_name = model_dict["model_name"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}
def _split_text_into_chunks(self, text, max_num_words=300):
"""
Splits the input text into chunks of up to `max_num_words` words.
Args:
text (str): The input text to be split.
max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.
Returns:
List[str]: A list of text chunks.
"""
words = text.split()
chunks = [
" ".join(words[i : i + max_num_words])
for i in range(0, len(words), max_num_words)
]
return chunks
@staticmethod
def softmax(x):
"""
Computes the softmax of the input array.
Args:
x (np.ndarray): The input array.
Returns:
np.ndarray: The softmax of the input.
"""
return np.exp(x) / np.exp(x).sum(axis=0)
def _predict_text(self, task, text, max_length=512) -> GuardResponse:
"""
Predicts the result for the provided text for a specific task.
Args:
task (str): The task to perform (e.g., "jailbreak").
text (str): The input text to classify.
max_length (int, optional): The maximum length for tokenization. Defaults to 512.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
"""
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
start_time = time.perf_counter()
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
latency = time.perf_counter() - start_time
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
else:
verdict = False
sentence = None
return GuardResponse(
prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency
)
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
"""
Makes a prediction based on the GuardRequest input.
Args:
req (GuardRequest): The GuardRequest object containing the input text and task.
max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
Note:
currently only support jailbreak check
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
if len(req.input.split()) < max_num_words:
return self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
prob, verdict, sentence, latency = [], False, [], 0
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result.verdict:
prob.append(chunk_result.prob[0])
verdict = True
sentence.append(chunk_result.sentence[0])
latency += chunk_result.latency
return GuardResponse(
prob=prob, verdict=verdict, sentence=sentence, latency=latency
)
def get_guardrail_handler(device: str = None):
"""
Initializes and returns an instance of ArchGuardHanlder based on the specified device.
Args:
device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None.
Returns:
ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device.
"""
if device is None:
device = utils.get_device()
model_class, model_name = None, None
if device == "cpu":
model_class = OVModelForSequenceClassification
model_name = "katanemo/Arch-Guard-cpu"
else:
model_class = AutoModelForSequenceClassification
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)