diff --git a/model_server/app/tests/test_loaders_gpu.py b/model_server/app/tests/test_loaders_gpu.py
index 46f73b49..9244c85c 100644
--- a/model_server/app/tests/test_loaders_gpu.py
+++ b/model_server/app/tests/test_loaders_gpu.py
@@ -5,7 +5,17 @@ import app.commons.globals as glb
 from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
 
 # Mock constants
-glb.DEVICE = "cuda"  # Adjust as needed for your test case
+import torch
+
+# Pick the best device actually available so this test runs unchanged on
+# CPU-only CI, Apple-silicon (MPS) and CUDA hosts alike.
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    DEVICE = "mps"
+else:
+    DEVICE = "cpu"
+glb.DEVICE = DEVICE
 arch_guard_model_type = {
     "cpu": "katanemo/Arch-Guard-cpu",
     "cuda": "katanemo/Arch-Guard",
diff --git a/model_server/app/tests/test_loaders_mps.py b/model_server/app/tests/test_loaders_mps.py
deleted file mode 100644
index 3bc76eb5..00000000
--- a/model_server/app/tests/test_loaders_mps.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import os
-import pytest
-from unittest.mock import patch, MagicMock
-import app.commons.globals as glb
-from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
-
-# Mock constants
-glb.DEVICE = "mps"  # Adjust as needed for your test case
-arch_guard_model_type = {
-    "cpu": "katanemo/Arch-Guard-cpu",
-    "cuda": "katanemo/Arch-Guard",
-    "mps": "katanemo/Arch-Guard",
-}
-
-
-@pytest.fixture
-def mock_env():
-    # Mock environment variables
-    os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
-    os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
-
-
-# Test for get_embedding_model function
-@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
-@patch("app.loader.AutoModel.from_pretrained")
-@patch("app.loader.AutoTokenizer.from_pretrained")
-def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
-    mock_automodel.return_value = MagicMock()
-    mock_ort_model.return_value = MagicMock()
-    mock_tokenizer.return_value = MagicMock()
-
-    embedding_model = get_embedding_model()
-
-    # Assertions
-    assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
-    assert mock_tokenizer.called_once_with(
-        "katanemo/bge-large-en-v1.5", trust_remote_code=True
-    )
-    if glb.DEVICE != "cuda":
-        assert mock_ort_model.called_once_with(
-            "katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
-        )
-    else:
-        assert mock_automodel.called_once_with(
-            "katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
-        )
-
-
-# Test for get_zero_shot_model function
-@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
-@patch("app.loader.pipeline")
-@patch("app.loader.AutoTokenizer.from_pretrained")
-def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
-    mock_pipeline.return_value = MagicMock()
-    mock_ort_model.return_value = MagicMock()
-    mock_tokenizer.return_value = MagicMock()
-
-    zero_shot_model = get_zero_shot_model()
-
-    # Assertions
-    assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
-    assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
-    if glb.DEVICE != "cuda":
-        assert mock_ort_model.called_once_with(
-            "katanemo/bart-large-mnli", file_name="onnx/model.onnx"
-        )
-    else:
-        assert mock_pipeline.called_once()
-
-
-# Test for get_prompt_guard function
-@patch("app.loader.AutoTokenizer.from_pretrained")
-@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
-@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
-def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
-    # Mock model based on device
-    if glb.DEVICE == "cpu":
-        mock_ov_model.return_value = MagicMock()
-    else:
-        mock_auto_model.return_value = MagicMock()
-
-    mock_tokenizer.return_value = MagicMock()
-
-    prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
-
-    # Assertions
-    assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
-    assert mock_tokenizer.called_once_with(
-        arch_guard_model_type[glb.DEVICE], trust_remote_code=True
-    )
-    if glb.DEVICE == "cpu":
-        assert mock_ov_model.called_once_with(
-            arch_guard_model_type[glb.DEVICE],
-            device_map=glb.DEVICE,
-            low_cpu_mem_usage=True,
-        )
-    else:
-        assert mock_auto_model.called_once_with(
-            arch_guard_model_type[glb.DEVICE],
-            device_map=glb.DEVICE,
-            low_cpu_mem_usage=True,
-        )