mirror of
https://github.com/katanemo/plano.git
synced 2026-04-28 18:36:34 +02:00
Recent changes:
* Update model server
* Delete model_server/.vscode/settings.json
* Update loader.py
* Fix errors
* Update log mode
43 lines
1.2 KiB
Python
import time
|
|
import torch
|
|
import app.prompt_guard.model_utils as model_utils
|
|
|
|
|
|
class ArchGuardHanlder:  # NOTE(review): "Hanlder" is a typo for "Handler"; kept as-is for caller compatibility
    """Jailbreak-detection guard around a sequence-classification model.

    Wraps a pre-loaded model/tokenizer pair and exposes ``guard_predict``,
    which scores a prompt and flags it when the jailbreak-class
    probability exceeds the configured threshold.
    """

    def __init__(self, model_dict, threshold=0.5):
        """Initialize the guard from a prepared model bundle.

        Args:
            model_dict: dict with keys ``"model"``, ``"tokenizer"``,
                ``"device"`` and ``"hardware_config"`` (as produced by
                the model loader — TODO confirm against loader.py).
            threshold: probability above which input is flagged as a
                jailbreak attempt.
        """
        self.task = "jailbreak"
        # Index of the jailbreak class in the model's output logits.
        self.positive_class = 2

        self.model = model_dict["model"]
        self.tokenizer = model_dict["tokenizer"]
        self.device = model_dict["device"]
        self.hardware_config = model_dict["hardware_config"]

        self.threshold = threshold

    def guard_predict(self, input_text):
        """Score ``input_text`` for jailbreak intent.

        Args:
            input_text: the prompt text to classify.

        Returns:
            dict with keys ``"<task>_prob"`` (float), ``"<task>_verdict"``
            (bool), ``"<task>_sentence"`` (the input text when flagged,
            otherwise ``None``) and ``"time"`` (elapsed seconds).
        """
        start_time = time.perf_counter()

        # Long prompts are truncated to the model's 512-token window.
        inputs = self.tokenizer(
            input_text, truncation=True, max_length=512, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            # Under no_grad() the output carries no autograd graph, so no
            # detach() is needed before converting to numpy.
            # [0]: single-example batch -> 1-D logits vector.
            logits = self.model(**inputs).logits.cpu().numpy()[0]
            prob = model_utils.softmax(logits)[self.positive_class]

        # bool(...) converts numpy.bool_ (the result of comparing a numpy
        # scalar) to a native bool so the result dict stays
        # JSON-serializable downstream.
        verdict = bool(prob > self.threshold)
        sentence = input_text if verdict else None

        return {
            f"{self.task}_prob": prob.item(),
            f"{self.task}_verdict": verdict,
            f"{self.task}_sentence": sentence,
            "time": time.perf_counter() - start_time,
        }