mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
from unittest.mock import patch, MagicMock
|
|
from app.model_handler.guardrails import get_guardrail_handler
|
|
|
|
# Mock constants
|
|
# Arch-Guard model identifiers keyed by device name ("cpu", "cuda", "mps").
# Only the "cpu" entry names a distinct model; "cuda" and "mps" share one.
arch_guard_model_type = dict(
    cpu="katanemo/Arch-Guard-cpu",
    cuda="katanemo/Arch-Guard",
    mps="katanemo/Arch-Guard",
)
|
|
|
|
|
|
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
|
# Test for `get_guardrail_handler()` function on `cpu`
|
|
# Patch targets use `app.model_handler.guardrails`, the same module that
# `get_guardrail_handler` is imported from, so the patched loaders are the
# ones the handler actually looks up.
@patch("app.model_handler.guardrails.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check which loaders `get_guardrail_handler` calls for the `cpu` device.

    All three `from_pretrained` loaders are replaced with mocks. The test
    asserts that the tokenizer loader and the
    `OVModelForSequenceClassification` loader are each called exactly once
    with the handler's `model_name`.

    Decorators are applied bottom-up, so the mock arguments arrive in the
    reverse of the decorator order.
    """
    device = "cpu"

    mock_ov_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)

    mock_tokenizer.assert_called_once_with(
        guardrail["model_name"], trust_remote_code=True
    )

    # On `cpu` the OV model loader is the one expected to be used.
    mock_ov_model.assert_called_once_with(
        guardrail["model_name"],
        device_map=device,
        low_cpu_mem_usage=True,
    )
|
|
|
|
|
|
# Test for `get_guardrail_handler()` function on `cuda`
|
|
# Patch targets use `app.model_handler.guardrails`, the same module that
# `get_guardrail_handler` is imported from, so the patched loaders are the
# ones the handler actually looks up.
@patch("app.model_handler.guardrails.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check which loaders `get_guardrail_handler` calls for the `cuda` device.

    All three `from_pretrained` loaders are replaced with mocks. The test
    asserts that the tokenizer loader and the
    `AutoModelForSequenceClassification` loader are each called exactly once
    with the handler's `model_name`.

    Decorators are applied bottom-up, so the mock arguments arrive in the
    reverse of the decorator order.
    """
    device = "cuda"

    mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)

    mock_tokenizer.assert_called_once_with(
        guardrail["model_name"], trust_remote_code=True
    )

    # On `cuda` the regular auto-model loader is the one expected to be used.
    mock_auto_model.assert_called_once_with(
        guardrail["model_name"],
        device_map=device,
        low_cpu_mem_usage=True,
    )
|
|
|
|
|
|
# Test for `get_guardrail_handler()` function on `mps`
|
|
# Patch targets use `app.model_handler.guardrails`, the same module that
# `get_guardrail_handler` is imported from, so the patched loaders are the
# ones the handler actually looks up.
@patch("app.model_handler.guardrails.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check which loaders `get_guardrail_handler` calls for the `mps` device.

    All three `from_pretrained` loaders are replaced with mocks. The test
    asserts that the tokenizer loader and the
    `AutoModelForSequenceClassification` loader are each called exactly once
    with the handler's `model_name`.

    Decorators are applied bottom-up, so the mock arguments arrive in the
    reverse of the decorator order.
    """
    device = "mps"

    mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)

    mock_tokenizer.assert_called_once_with(
        guardrail["model_name"], trust_remote_code=True
    )

    # On `mps` the regular auto-model loader is the one expected to be used.
    mock_auto_model.assert_called_once_with(
        guardrail["model_name"],
        device_map=device,
        low_cpu_mem_usage=True,
    )
|