From b5d2f19d9b746d4e888c5b13f41b7e193f402daf Mon Sep 17 00:00:00 2001
From: cotran
Date: Wed, 16 Oct 2024 16:49:44 -0700
Subject: [PATCH] fix precommit

---
 model_server/app/main.py                   |   2 +
 model_server/app/tests/test_app.py         |  16 ++++
 model_server/app/tests/test_loaders_gpu.py |   8 +-
 model_server/app/tests/test_loaders_mps.py | 104 ++++++++++++++++++++++
 4 files changed, 123 insertions(+), 7 deletions(-)
 create mode 100644 model_server/app/tests/test_loaders_mps.py

diff --git a/model_server/app/main.py b/model_server/app/main.py
index c6f5752a..93d6217b 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -13,6 +13,7 @@ from app.commons.constants import embedding_model, zero_shot_model, arch_guard_h
 from app.function_calling.model_utils import (
     chat_completion as arch_function_chat_completion,
 )
+from unittest.mock import patch
 
 logger = utils.get_model_server_logger()
 
@@ -173,6 +174,7 @@ async def zeroshot(req: ZeroShotRequest, res: Response):
 
 
 @app.post("/hallucination")
+@patch("app.loader.glb.DEVICE", "cpu")  # Force the loader's device to "cpu" for this endpoint
 async def hallucination(req: HallucinationRequest, res: Response):
     """
     Take input as text and return the prediction of hallucination for each parameter
diff --git a/model_server/app/tests/test_app.py b/model_server/app/tests/test_app.py
index 77ab14d2..c91fc153 100644
--- a/model_server/app/tests/test_app.py
+++ b/model_server/app/tests/test_app.py
@@ -2,12 +2,20 @@ import pytest
 import httpx
 from fastapi.testclient import TestClient
 from app.main import app  # Assuming your FastAPI app is in main.py
+from unittest.mock import patch
+import app.commons.globals as glb
+import logging
+
+logger = logging.getLogger(__name__)
 
 client = TestClient(app)
 
+logger.info(f"Model will be loaded on device: {glb.DEVICE}")
+
 # Unit tests for the health check endpoint
 @pytest.mark.asyncio
+@patch("app.loader.glb.DEVICE", glb.DEVICE)  # Patch the loader with the detected device
 async def test_healthz():
     response = client.get("/healthz")
     assert response.status_code == 200
 
@@ -16,6 +24,7 @@ async def test_healthz():
 
 # Unit test for the models endpoint
 @pytest.mark.asyncio
+@patch("app.loader.glb.DEVICE", glb.DEVICE)  # Patch the loader with the detected device
 async def test_models():
     response = client.get("/models")
     assert response.status_code == 200
@@ -25,6 +34,7 @@ async def test_models():
 
 # Unit test for embeddings endpoint
 @pytest.mark.asyncio
+@patch("app.loader.glb.DEVICE", glb.DEVICE)  # Patch the loader with the detected device
 async def test_embedding():
     request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
     response = client.post("/embeddings", json=request_data)
@@ -38,6 +48,7 @@ async def test_embedding():
 
 # Unit test for the guard endpoint
 @pytest.mark.asyncio
+@patch("app.loader.glb.DEVICE", glb.DEVICE)  # Patch the loader with the detected device
 async def test_guard():
     request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
     response = client.post("/guard", json=request_data)
@@ -47,6 +58,7 @@ async def test_guard():
 
 # Unit test for the zero-shot endpoint
 @pytest.mark.asyncio
+@patch("app.loader.glb.DEVICE", glb.DEVICE)  # Patch the loader with the detected device
 async def test_zeroshot():
     request_data = {
         "input": "Test input",
@@ -63,6 +75,7 @@ async def test_zeroshot():
 
 # Unit test for the hallucination endpoint
 @pytest.mark.asyncio
+@patch("app.loader.glb.DEVICE", glb.DEVICE)  # Patch the loader with the detected device
 async def test_hallucination():
     request_data = {
         "prompt": "Test hallucination",
@@ -79,6 +92,9 @@ async def test_hallucination():
 
 # Unit test for the chat completion endpoint
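+# As in the tests above, @patch pins app.loader's glb.DEVICE to the value
+# detected at import time, making the device under test explicit.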
 @pytest.mark.asyncio
+@patch("app.loader.glb.DEVICE", glb.DEVICE)  # Patch the loader with the detected device
 async def test_chat_completion():
     async with httpx.AsyncClient(app=app, base_url="http://test") as client:
         request_data = {
diff --git a/model_server/app/tests/test_loaders_gpu.py b/model_server/app/tests/test_loaders_gpu.py
index 9244c85c..46f73b49 100644
--- a/model_server/app/tests/test_loaders_gpu.py
+++ b/model_server/app/tests/test_loaders_gpu.py
@@ -5,13 +5,7 @@ import app.commons.globals as glb
 from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
 
 # Mock constants
-if torch.cuda.is_available():
-    DEVICE = "cuda"
-elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    DEVICE = "mps"
-else:
-    DEVICE = "cpu"
-glb.DEVICE = DEVICE  # Adjust as needed for your test case
+glb.DEVICE = "cuda"  # This suite exercises the CUDA loader paths
 arch_guard_model_type = {
     "cpu": "katanemo/Arch-Guard-cpu",
     "cuda": "katanemo/Arch-Guard",
diff --git a/model_server/app/tests/test_loaders_mps.py b/model_server/app/tests/test_loaders_mps.py
new file mode 100644
index 00000000..3bc76eb5
--- /dev/null
+++ b/model_server/app/tests/test_loaders_mps.py
@@ -0,0 +1,104 @@
+import os
+import pytest
+from unittest.mock import patch, MagicMock
+import app.commons.globals as glb
+from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
+
+# Mock constants
+glb.DEVICE = "mps"  # This suite exercises the MPS loader paths
+arch_guard_model_type = {
+    "cpu": "katanemo/Arch-Guard-cpu",
+    "cuda": "katanemo/Arch-Guard",
+    "mps": "katanemo/Arch-Guard",
+}
+
+
+@pytest.fixture
+def mock_env():
+    # Mock environment variables
+    os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
+    os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
+
+
+# Test for get_embedding_model function
+@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
+@patch("app.loader.AutoModel.from_pretrained")
+@patch("app.loader.AutoTokenizer.from_pretrained")
+def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
+    mock_automodel.return_value = MagicMock()
+    mock_ort_model.return_value = MagicMock()
+    mock_tokenizer.return_value = MagicMock()
+
+    embedding_model = get_embedding_model()
+
+    # Assertions
+    assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
+    mock_tokenizer.assert_called_once_with(
+        "katanemo/bge-large-en-v1.5", trust_remote_code=True
+    )
+    if glb.DEVICE != "cuda":
+        mock_ort_model.assert_called_once_with(
+            "katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
+        )
+    else:
+        mock_automodel.assert_called_once_with(
+            "katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
+        )
+
+
+# Test for get_zero_shot_model function
+@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
+@patch("app.loader.pipeline")
+@patch("app.loader.AutoTokenizer.from_pretrained")
+def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
+    mock_pipeline.return_value = MagicMock()
+    mock_ort_model.return_value = MagicMock()
+    mock_tokenizer.return_value = MagicMock()
+
+    zero_shot_model = get_zero_shot_model()
+
+    # Assertions
+    assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
+    mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
+    if glb.DEVICE != "cuda":
+        mock_ort_model.assert_called_once_with(
+            "katanemo/bart-large-mnli", file_name="onnx/model.onnx"
+        )
+    else:
+        mock_pipeline.assert_called_once()
+
+
+# Test for get_prompt_guard function
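+# Stacked @patch decorators apply bottom-up, so the first mock argument below
+# is the AutoModelForSequenceClassification patch, and the tokenizer comes last.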
+@patch("app.loader.AutoTokenizer.from_pretrained") +@patch("app.loader.OVModelForSequenceClassification.from_pretrained") +@patch("app.loader.AutoModelForSequenceClassification.from_pretrained") +def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer): + # Mock model based on device + if glb.DEVICE == "cpu": + mock_ov_model.return_value = MagicMock() + else: + mock_auto_model.return_value = MagicMock() + + mock_tokenizer.return_value = MagicMock() + + prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE]) + + # Assertions + assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE] + assert mock_tokenizer.called_once_with( + arch_guard_model_type[glb.DEVICE], trust_remote_code=True + ) + if glb.DEVICE == "cpu": + assert mock_ov_model.called_once_with( + arch_guard_model_type[glb.DEVICE], + device_map=glb.DEVICE, + low_cpu_mem_usage=True, + ) + else: + assert mock_auto_model.called_once_with( + arch_guard_model_type[glb.DEVICE], + device_map=glb.DEVICE, + low_cpu_mem_usage=True, + )