Update guardrail_handler and its associated tests

This commit is contained in:
Shuguang Chen 2024-12-05 11:30:58 -08:00
parent b686cf8b87
commit 09f7e1e604
7 changed files with 115 additions and 1091 deletions

View file

@ -1,10 +1,7 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.model_handler.guardrails import get_guardrail_handler
# Mock constants
glb.DEVICE = "cpu" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
@ -12,36 +9,71 @@ arch_guard_model_type = {
}
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps`
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cpu"
mock_ov_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
guardrail = get_guardrail_handler(device=device)
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
guardrail["model_name"], trust_remote_code=True
)
mock_ov_model.assert_called_once_with(
guardrail["model_name"],
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `cuda`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cuda"
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(
guardrail["model_name"], trust_remote_code=True
)
mock_auto_model.assert_called_once_with(
guardrail["model_name"],
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `mps`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "mps"
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(
guardrail["model_name"], trust_remote_code=True
)
mock_auto_model.assert_called_once_with(
guardrail["model_name"],
device_map=device,
low_cpu_mem_usage=True,
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)