mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
Update guardrail_handler and its associated tests
This commit is contained in:
parent
b686cf8b87
commit
09f7e1e604
7 changed files with 115 additions and 1091 deletions
|
|
@ -1,10 +1,7 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.model_handler.guardrails import get_guardrail_handler
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cpu" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
|
|
@ -12,36 +9,71 @@ arch_guard_model_type = {
|
|||
}
|
||||
|
||||
|
||||
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for `get_guardrail_handler()` function on `cpu`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cpu"
|
||||
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_ov_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `cuda`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cuda"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `mps`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "mps"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
mock_auto_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue