This commit is contained in:
cotran 2024-12-10 19:31:02 -08:00
parent 2405fb36e3
commit b2ef3f7266
4 changed files with 32 additions and 39 deletions

View file

@ -11,9 +11,9 @@ arch_guard_model_type = {
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cpu"
@ -22,21 +22,19 @@ def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(
guardrail["model_name"], trust_remote_code=True
)
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
mock_ov_model.assert_called_once_with(
guardrail["model_name"],
guardrail.model_name,
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `cuda`
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cuda"
@ -45,21 +43,19 @@ def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenize
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(
guardrail["model_name"], trust_remote_code=True
)
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
mock_auto_model.assert_called_once_with(
guardrail["model_name"],
guardrail.model_name,
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `mps`
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "mps"
@ -68,12 +64,10 @@ def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(
guardrail["model_name"], trust_remote_code=True
)
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
mock_auto_model.assert_called_once_with(
guardrail["model_name"],
guardrail.model_name,
device_map=device,
low_cpu_mem_usage=True,
)