mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
|
|
import os
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import patch, MagicMock
|
||
|
|
import app.commons.globals as glb
|
||
|
|
|
||
|
|
# Test fixtures shared by the tests in this module.

# Device the loader is exercised against; change this to run the test for
# another backend.
glb.DEVICE = "cpu"

# Guard model identifier to load for each supported device. Both accelerator
# backends ("cuda" and "mps") map to the same identifier.
arch_guard_model_type = {
    "cpu": "katanemo/Arch-Guard-cpu",
    **dict.fromkeys(("cuda", "mps"), "katanemo/Arch-Guard"),
}
|
||
|
|
|
||
|
|
|
||
|
|
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps`
|
||
|
|
# Test for get_prompt_guard function
|
||
|
|
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Verify get_prompt_guard loads the tokenizer and model for glb.DEVICE.

    The three ``from_pretrained`` loaders are patched, so nothing is
    downloaded. Stacked ``@patch`` decorators inject their mocks bottom-up,
    which is why the parameter order is the reverse of the decorator order.

    Args:
        mock_auto_model: Mock of
            ``AutoModelForSequenceClassification.from_pretrained``.
        mock_ov_model: Mock of
            ``OVModelForSequenceClassification.from_pretrained``.
        mock_tokenizer: Mock of ``AutoTokenizer.from_pretrained``.
    """
    # The function under test was never imported by this module, so calling
    # it raised NameError. Imported locally to keep the fix self-contained.
    # NOTE(review): assumes get_prompt_guard is defined in app.loader, the
    # module the patches above target -- confirm.
    from app.loader import get_prompt_guard

    model_name = arch_guard_model_type[glb.DEVICE]

    # On "cpu" the OpenVINO loader is the one expected to be called; every
    # other device is expected to go through the generic auto-model loader.
    expected_loader = mock_ov_model if glb.DEVICE == "cpu" else mock_auto_model
    expected_loader.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    prompt_guard = get_prompt_guard(model_name)

    # Assertions
    assert prompt_guard["model_name"] == model_name
    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
    expected_loader.assert_called_once_with(
        model_name,
        device_map=glb.DEVICE,
        low_cpu_mem_usage=True,
    )
|