plano/model_server/app/prompt_guard/model_handler.py
CTran 8e54ac20d8
Refactor model server hardware config + add unit tests to load/request to the server (#189)
* remove mode/hardware

* add test and pre commit hook

* add pytest dependencies

* fix format

* fix lint

* fix precommit

* fix pre commit

* fix pre commit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit

* fix precommit
2024-10-16 16:58:10 -07:00

42 lines
1.2 KiB
Python

import time
import torch
import app.prompt_guard.model_utils as model_utils
class ArchGuardHanlder:
    """Score text with a classification model and flag it against a threshold.

    The class name (including its spelling) is kept unchanged for existing callers.
    """

    def __init__(self, model_dict, threshold=0.5, task="jailbreak", positive_class=2):
        """Store the model components and the decision settings.

        Args:
            model_dict: Mapping that must provide the "model", "tokenizer" and
                "device" entries used by ``guard_predict``.
            threshold: Text is flagged when the positive-class probability is
                strictly greater than this value.
            task: Prefix for the keys of the dict returned by ``guard_predict``
                (e.g. ``"jailbreak_prob"``). Defaults to the original
                hard-coded value.
            positive_class: Index into the model's output logits whose
                probability is compared against ``threshold``. The default 2
                is presumably the model's jailbreak label; confirm against the
                model's label config.

        Raises:
            KeyError: If ``model_dict`` lacks any of the required entries.
        """
        self.task = task
        self.positive_class = positive_class
        self.model = model_dict["model"]
        self.tokenizer = model_dict["tokenizer"]
        self.device = model_dict["device"]
        self.threshold = threshold

    def guard_predict(self, input_text, max_length=512):
        """Classify ``input_text`` and report the positive-class probability.

        Args:
            input_text: Text passed to the tokenizer. Only the first row of
                the resulting logits is scored, so this is expected to be a
                single string.
            max_length: Tokenizer truncation length.

        Returns:
            dict with keys:
              - f"{task}_prob": positive-class probability as a Python scalar.
              - f"{task}_verdict": True if that probability exceeds the threshold.
              - f"{task}_sentence": ``input_text`` when flagged, otherwise None.
              - "time": elapsed wall-clock seconds (``time.perf_counter``).
        """
        start_time = time.perf_counter()
        # The tokenizer is asked for PyTorch tensors ("pt"); the batch is then
        # moved to the configured device.
        inputs = self.tokenizer(
            input_text, truncation=True, max_length=max_length, return_tensors="pt"
        ).to(self.device)
        # Inference only: disable autograd tracking.
        with torch.no_grad():
            # Bring the logits back to host memory as numpy and keep row 0.
            logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
        prob = model_utils.softmax(logits)[self.positive_class]

        # bool() converts numpy's bool_ into a plain Python bool, matching the
        # literal True/False the result dict has always carried.
        verdict = bool(prob > self.threshold)
        sentence = input_text if verdict else None

        return {
            f"{self.task}_prob": prob.item(),
            f"{self.task}_verdict": verdict,
            f"{self.task}_sentence": sentence,
            "time": time.perf_counter() - start_time,
        }