mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix test
This commit is contained in:
parent
2405fb36e3
commit
b2ef3f7266
4 changed files with 32 additions and 39 deletions
|
|
@ -11,9 +11,9 @@ arch_guard_model_type = {
|
|||
|
||||
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for `get_guardrail_handler()` function on `cpu`
|
||||
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cpu"
|
||||
|
||||
|
|
@ -22,21 +22,19 @@ def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer
|
|||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||
|
||||
mock_ov_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
guardrail.model_name,
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `cuda`
|
||||
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cuda"
|
||||
|
||||
|
|
@ -45,21 +43,19 @@ def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenize
|
|||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
guardrail.model_name,
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `mps`
|
||||
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "mps"
|
||||
|
||||
|
|
@ -68,12 +64,10 @@ def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer
|
|||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
guardrail.model_name,
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue