plano/model_server/app/prompt_guard/model_handler.py
Shuguang Chen 3b7c58698f
Update model_server (#164)
* Update model server

* Delete model_server/.vscode/settings.json

* Update loader.py

* Fix errors

* Update log mode
2024-10-09 18:04:52 -07:00

43 lines
1.2 KiB
Python

import time
import torch
import app.prompt_guard.model_utils as model_utils
class ArchGuardHanlder:
    """Score text with a classification model and flag it when the
    probability of one target class exceeds a threshold.

    NOTE(review): the class name is misspelled ("Hanlder"). It is kept for
    backward compatibility; ``ArchGuardHandler`` below is a correctly
    spelled alias.
    """

    def __init__(
        self,
        model_dict,
        threshold=0.5,
        task="jailbreak",
        positive_class=2,
        max_length=512,
    ):
        """Store the model components and the decision parameters.

        Args:
            model_dict: Mapping that must contain the keys ``"model"``,
                ``"tokenizer"``, ``"device"`` and ``"hardware_config"``.
                The tokenizer/model presumably follow the Hugging Face
                ``transformers`` calling convention -- TODO confirm
                against the loader that builds this dict.
            threshold: Input is flagged when the target-class probability
                is strictly greater than this value.
            task: Prefix for the keys of the dict returned by
                ``guard_predict``.
            positive_class: Index of the logit treated as the flagged class
                (2 by default for the "jailbreak" task).
            max_length: Token limit passed to the tokenizer; longer input
                is truncated.
        """
        self.task = task
        self.positive_class = positive_class
        self.max_length = max_length
        self.model = model_dict["model"]
        self.tokenizer = model_dict["tokenizer"]
        self.device = model_dict["device"]
        # Stored but not read by the methods of this class.
        self.hardware_config = model_dict["hardware_config"]
        self.threshold = threshold

    def guard_predict(self, input_text):
        """Classify ``input_text`` and report whether it is flagged.

        Args:
            input_text: Text to classify. Only the first row of the model
                output is used, so a single string is expected.

        Returns:
            dict with keys:
              ``"<task>_prob"``: probability (float) of ``positive_class``;
              ``"<task>_verdict"``: True when that probability exceeds
              ``threshold``;
              ``"<task>_sentence"``: ``input_text`` when flagged, else None;
              ``"time"``: wall-clock seconds spent in this call.
        """
        start_time = time.perf_counter()

        inputs = self.tokenizer(
            input_text,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Take the first (only) row, then upcast before leaving torch: NumPy
        # has no bfloat16 dtype, so ``.numpy()`` fails for bf16 models.
        # ``.float()`` is a no-op for fp32 tensors.
        row_logits = logits[0].float().cpu().numpy()
        prob = model_utils.softmax(row_logits)[self.positive_class]

        # Plain Python bool (not numpy.bool_) so the result is JSON-serialisable.
        verdict = bool(prob > self.threshold)

        return {
            f"{self.task}_prob": float(prob),
            f"{self.task}_verdict": verdict,
            f"{self.task}_sentence": input_text if verdict else None,
            "time": time.perf_counter() - start_time,
        }


# Correctly spelled alias; the original (misspelled) name stays for existing importers.
ArchGuardHandler = ArchGuardHanlder