mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Refactor model server hardware config + add unit tests to load/request to the server (#189)
* remove mode/hardware * add test and pre commit hook * add pytest dependieces * fix format * fix lint * fix precommit * fix pre commit * fix pre commit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit
This commit is contained in:
parent
3bd2ffe9fb
commit
8e54ac20d8
13 changed files with 480 additions and 43 deletions
38
.github/workflows/model-server-tests.yml
vendored
Normal file
38
.github/workflows/model-server-tests.yml
vendored
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
name: Run Model Server tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main # Run tests on pushes to the main branch
|
||||
pull_request:
|
||||
branches:
|
||||
- main # Run tests on pull requests to the main branch
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# Step 1: Check out the code from your repository
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
# Step 2: Set up Python (specify the version)
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10" # Adjust to your Python version
|
||||
|
||||
# Step 3: Install dependencies (from requirements.txt or Pipfile)
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd model_server
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt # Or use pipenv install
|
||||
pip install pytest
|
||||
|
||||
# Step 4: Set PYTHONPATH and run tests
|
||||
- name: Run model server tests with pytest
|
||||
run: |
|
||||
cd model_server
|
||||
PYTHONPATH=. pytest --maxfail=5 --disable-warnings
|
||||
|
|
@ -18,14 +18,16 @@ arch_function_generation_params = {
|
|||
"stop_token_ids": [151645],
|
||||
}
|
||||
|
||||
arch_guard_model_type = {"cpu": "katanemo/Arch-Guard-cpu", "gpu": "katanemo/Arch-Guard"}
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
# Model definition
|
||||
embedding_model = loader.get_embedding_model()
|
||||
zero_shot_model = loader.get_zero_shot_model()
|
||||
|
||||
prompt_guard_dict = loader.get_prompt_guard(
|
||||
arch_guard_model_type[glb.HARDWARE], glb.HARDWARE
|
||||
)
|
||||
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||
|
|
|
|||
|
|
@ -2,5 +2,3 @@ import app.commons.utilities as utils
|
|||
|
||||
|
||||
DEVICE = utils.get_device()
|
||||
MODE = utils.get_serving_mode()
|
||||
HARDWARE = utils.get_hardware(MODE)
|
||||
|
|
|
|||
|
|
@ -22,9 +22,11 @@ def get_device():
|
|||
available_device = {
|
||||
"cpu": True,
|
||||
"cuda": torch.cuda.is_available(),
|
||||
"mps": torch.backends.mps.is_available()
|
||||
if hasattr(torch.backends, "mps")
|
||||
else False,
|
||||
"mps": (
|
||||
torch.backends.mps.is_available()
|
||||
if hasattr(torch.backends, "mps")
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
||||
if available_device["cuda"]:
|
||||
|
|
@ -37,24 +39,6 @@ def get_device():
|
|||
return device
|
||||
|
||||
|
||||
def get_serving_mode():
|
||||
mode = os.getenv("MODE", "cloud")
|
||||
|
||||
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
||||
raise ValueError(f"Invalid serving mode: {mode}")
|
||||
|
||||
return mode
|
||||
|
||||
|
||||
def get_hardware(mode):
|
||||
if mode == "local-cpu":
|
||||
hardware = "cpu"
|
||||
else:
|
||||
hardware = "gpu" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
return hardware
|
||||
|
||||
|
||||
def get_client(endpoint):
|
||||
client = OpenAI(base_url=endpoint, api_key="EMPTY")
|
||||
return client
|
||||
|
|
|
|||
|
|
@ -7,6 +7,10 @@ from optimum.onnxruntime import (
|
|||
ORTModelForSequenceClassification,
|
||||
)
|
||||
import app.commons.utilities as utils
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
|
@ -60,28 +64,20 @@ def get_zero_shot_model(
|
|||
return zero_shot_model
|
||||
|
||||
|
||||
def get_prompt_guard(model_name, hardware_config="cpu"):
|
||||
def get_prompt_guard(model_name):
|
||||
logger.info("Loading Guard Model...")
|
||||
|
||||
if hardware_config == "cpu":
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
device = "cpu"
|
||||
if glb.DEVICE == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
elif hardware_config == "gpu":
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
|
||||
prompt_guard = {
|
||||
"hardware_config": hardware_config,
|
||||
"device": device,
|
||||
"device": glb.DEVICE,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from app.commons.constants import embedding_model, zero_shot_model, arch_guard_h
|
|||
from app.function_calling.model_utils import (
|
||||
chat_completion as arch_function_chat_completion,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
|
@ -173,6 +174,7 @@ async def zeroshot(req: ZeroShotRequest, res: Response):
|
|||
|
||||
|
||||
@app.post("/hallucination")
|
||||
@patch("app.loader.glb.DEVICE", "cpu") # Mock the device to 'cpu'
|
||||
async def hallucination(req: HallucinationRequest, res: Response):
|
||||
"""
|
||||
Take input as text and return the prediction of hallucination for each parameter
|
||||
|
|
|
|||
|
|
@ -11,15 +11,14 @@ class ArchGuardHanlder:
|
|||
self.model = model_dict["model"]
|
||||
self.tokenizer = model_dict["tokenizer"]
|
||||
self.device = model_dict["device"]
|
||||
self.hardware_config = model_dict["hardware_config"]
|
||||
|
||||
self.threshold = threshold
|
||||
|
||||
def guard_predict(self, input_text):
|
||||
def guard_predict(self, input_text, max_length=512):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
inputs = self.tokenizer(
|
||||
input_text, truncation=True, max_length=512, return_tensors="pt"
|
||||
input_text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
|
|
|||
106
model_server/app/tests/test_app.py
Normal file
106
model_server/app/tests/test_app.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
import pytest
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app # Assuming your FastAPI app is in main.py
|
||||
from unittest.mock import patch
|
||||
import app.commons.globals as glb
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
|
||||
|
||||
|
||||
# Unit tests for the health check endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_healthz():
|
||||
response = client.get("/healthz")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
# Unit test for the models endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_models():
|
||||
response = client.get("/models")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["object"] == "list"
|
||||
assert len(response.json()["data"]) > 0
|
||||
|
||||
|
||||
# Unit test for embeddings endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_embedding():
|
||||
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
|
||||
response = client.post("/embeddings", json=request_data)
|
||||
if request_data["model"] == "katanemo/bge-large-en-v1.5":
|
||||
assert response.status_code == 200
|
||||
assert response.json()["object"] == "list"
|
||||
assert "data" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the guard endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_guard():
|
||||
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
|
||||
response = client.post("/guard", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "jailbreak_verdict" in response.json()
|
||||
|
||||
|
||||
# Unit test for the zero-shot endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_zeroshot():
|
||||
request_data = {
|
||||
"input": "Test input",
|
||||
"labels": ["label1", "label2"],
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/zeroshot", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
assert response.status_code == 200
|
||||
assert "predicted_class" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the hallucination endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_hallucination():
|
||||
request_data = {
|
||||
"prompt": "Test hallucination",
|
||||
"parameters": {"param1": "value1"},
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/hallucination", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
assert response.status_code == 200
|
||||
assert "params_scores" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# Unit test for the chat completion endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_chat_completion():
|
||||
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
|
||||
request_data = {
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tools": [], # Assuming tools is part of the req as per the function
|
||||
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed
|
||||
}
|
||||
response = await client.post("/v1/chat/completions", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "choices" in response.json()
|
||||
102
model_server/app/tests/test_loaders_cpu.py
Normal file
102
model_server/app/tests/test_loaders_cpu.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cpu" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
|
||||
mock_automodel.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
assert mock_tokenizer.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_automodel.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.pipeline")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
|
||||
mock_pipeline.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
zero_shot_model = get_zero_shot_model()
|
||||
|
||||
# Assertions
|
||||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
assert mock_tokenizer.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
assert mock_ov_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
assert mock_auto_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
102
model_server/app/tests/test_loaders_gpu.py
Normal file
102
model_server/app/tests/test_loaders_gpu.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cuda" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
|
||||
mock_automodel.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
assert mock_tokenizer.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_automodel.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.pipeline")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
|
||||
mock_pipeline.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
zero_shot_model = get_zero_shot_model()
|
||||
|
||||
# Assertions
|
||||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
assert mock_tokenizer.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
assert mock_ov_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
assert mock_auto_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
102
model_server/app/tests/test_loaders_mps.py
Normal file
102
model_server/app/tests/test_loaders_mps.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "mps" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
|
||||
mock_automodel.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
assert mock_tokenizer.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_automodel.called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.pipeline")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
|
||||
mock_pipeline.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
zero_shot_model = get_zero_shot_model()
|
||||
|
||||
# Assertions
|
||||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
assert mock_tokenizer.called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
assert mock_ort_model.called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_ov_model, mock_auto_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
assert mock_tokenizer.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
assert mock_ov_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
assert mock_auto_model.called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
@ -29,6 +29,9 @@ openai = "1.50.2"
|
|||
tf-keras = "*"
|
||||
onnx = "1.17.0"
|
||||
onnxruntime = "1.19.2"
|
||||
httpx = "*"
|
||||
pytest-asyncio = "*"
|
||||
pytest = "*"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
archgw_modelserver = "app.cli:run_server"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ openvino==2024.4.0
|
|||
psutil
|
||||
dateparser
|
||||
openai==1.50.2
|
||||
httpx
|
||||
pytest-asyncio
|
||||
pytest
|
||||
pandas
|
||||
tf-keras
|
||||
onnx==1.17.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue