Update guardrail_handler and its associated tests

This commit is contained in:
Shuguang Chen 2024-12-05 11:30:58 -08:00
parent b686cf8b87
commit 09f7e1e604
7 changed files with 115 additions and 1091 deletions

View file

@ -1,43 +1,14 @@
import app.commons.utilities as utils import app.commons.utilities as utils
from openai import OpenAI
from app.commons.constants import * from app.commons.constants import *
from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler
from app.model_handler.guardrails import ArchGuardHanlder from app.model_handler.guardrails import get_guardrail_handler
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from openai import OpenAI
logger = utils.get_model_server_logger() logger = utils.get_model_server_logger()
def get_guardrail_handler():
device = utils.get_device()
model_class, model_name = None, None
if device == "cpu":
model_class = OVModelForSequenceClassification
model_name = "katanemo/Arch-Guard-cpu"
else:
model_class = AutoModelForSequenceClassification
if device == "cuda":
model_name = "katanemo/Arch-Guard"
else:
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)
# Define the client # Define the client
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY") ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")

View file

@ -25,7 +25,7 @@ class ArchIntentHandler(ArchBaseHandler):
task_prompt: str, task_prompt: str,
tool_prompt: str, tool_prompt: str,
format_prompt: str, format_prompt: str,
intent_instruction: str, extra_instruction: str,
generation_params: Dict, generation_params: Dict,
): ):
""" """
@ -37,7 +37,7 @@ class ArchIntentHandler(ArchBaseHandler):
task_prompt (str): The main task prompt for the system. task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools. tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format. format_prompt (str): A prompt specifying the desired output format.
intent_instruction (str): Instructions specific to intent handling. extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model. generation_params (Dict): Generation parameters for the model.
""" """
@ -50,7 +50,7 @@ class ArchIntentHandler(ArchBaseHandler):
generation_params, generation_params,
) )
self.intent_instruction = intent_instruction self.extra_instruction = extra_instruction
@override @override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str: def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
@ -85,7 +85,7 @@ class ArchIntentHandler(ArchBaseHandler):
""" """
messages = self._process_messages( messages = self._process_messages(
req.messages, req.tools, self.intent_instruction req.messages, req.tools, self.extra_instruction
) )
model_response = self.client.chat.completions.create( model_response = self.client.chat.completions.create(

View file

@ -1,8 +1,11 @@
import time import time
import torch import torch
import numpy as np import numpy as np
import app.commons.utilities as utils
from pydantic import BaseModel from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
class GuardRequest(BaseModel): class GuardRequest(BaseModel):
@ -93,3 +96,27 @@ class ArchGuardHanlder:
guard_result["latency"] = time.perf_counter() - start_time guard_result["latency"] = time.perf_counter() - start_time
return guard_result return guard_result
def get_guardrail_handler(device: str = None):
if device is None:
device = utils.get_device()
model_class, model_name = None, None
if device == "cpu":
model_class = OVModelForSequenceClassification
model_name = "katanemo/Arch-Guard-cpu"
else:
model_class = AutoModelForSequenceClassification
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)

View file

@ -1,32 +1,25 @@
import pytest import pytest
import httpx import httpx
from fastapi.testclient import TestClient
from app.main import app # Assuming your FastAPI app is in main.py
from unittest.mock import patch
import app.commons.globals as glb
import logging
logger = logging.getLogger(__name__) from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app) client = TestClient(app)
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
# [TODO] Review: check the following code
# [TODO] Review: update the following code
# Unit tests for the health check endpoint # Unit tests for the health check endpoint
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_healthz(): async def test_healthz():
response = client.get("/healthz") response = client.get("/healthz")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"status": "ok"} assert response.json() == {"status": "ok"}
# [TODO] Review: update the following code # [TODO] Review: check the following code
# Unit test for the models endpoint # Unit test for the models endpoint
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_models(): async def test_models():
response = client.get("/models") response = client.get("/models")
assert response.status_code == 200 assert response.status_code == 200
@ -34,80 +27,27 @@ async def test_models():
assert len(response.json()["data"]) > 0 assert len(response.json()["data"]) > 0
# [TODO] Review: update the following code # [TODO] Review: check the following code
# Unit test for embeddings endpoint # Unit test for the guardrail endpoint
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' async def test_guardrail_endpoint():
async def test_embedding():
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
response = client.post("/embeddings", json=request_data)
if request_data["model"] == "katanemo/bge-large-en-v1.5":
assert response.status_code == 200
assert response.json()["object"] == "list"
assert "data" in response.json()
else:
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the guard endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_guard():
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"} request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
response = client.post("/guard", json=request_data) response = client.post("/guardrails", json=request_data)
assert response.status_code == 200 assert response.status_code == 200
assert "jailbreak_verdict" in response.json() assert "jailbreak_verdict" in response.json()
# [TODO] Review: update the following code # [TODO] Review: check the following code
# Unit test for the zero-shot endpoint # Unit test for the function calling endpoint
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' async def test_function_calling_endpoint():
async def test_zeroshot():
request_data = {
"input": "Test input",
"labels": ["label1", "label2"],
"model": "katanemo/bart-large-mnli",
}
response = client.post("/zeroshot", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
assert response.status_code == 200
assert "predicted_class" in response.json()
else:
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the hallucination endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_hallucination():
request_data = {
"prompt": "Test hallucination",
"parameters": {"param1": "value1"},
"model": "katanemo/bart-large-mnli",
}
response = client.post("/hallucination", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
assert response.status_code == 200
assert "params_scores" in response.json()
else:
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the chat completion endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_chat_completion():
async with httpx.AsyncClient(app=app, base_url="http://test") as client: async with httpx.AsyncClient(app=app, base_url="http://test") as client:
request_data = { request_data = {
"messages": [{"role": "user", "content": "Hello!"}], "messages": [{"role": "user", "content": "Hello!"}],
"model": "Arch-Function-1.5B", "model": "Arch-Function",
"tools": [], # Assuming tools is part of the req as per the function "tools": [],
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed "metadata": {"x-arch-state": "[]"},
} }
response = await client.post("/v1/chat/completions", json=request_data) response = await client.post("/function_calling", json=request_data)
assert response.status_code == 200 assert response.status_code == 200
assert "choices" in response.json() assert "choices" in response.json()

View file

@ -1,949 +0,0 @@
[
{
"case": "tool_call_halluciation",
"tokens": [
"<tool_call>"
],
"expect": 1,
"logprobs": [
[
-0.3333307206630707,
-1.5310522317886353,
-3.5098977088928223,
-3.9004578590393066,
-5.775152683258057,
-5.814209461212158,
-5.9574151039123535,
-6.0094895362854,
-6.0094895362854,
-6.673445224761963
]
]
},
{
"case": "parameter_value_hallucination",
"expect": 0,
"tokens": [
"<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Sea",
",",
" Australia",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"1",
"'}}\n",
"</tool_call>"
],
"logprobs": [
[
-0.008103232830762863,
-5.085402488708496,
-6.777836799621582,
-7.558959007263184,
-9.850253105163574,
-10.266852378845215,
-10.540244102478027,
-10.722506523132324,
-10.800618171691895,
-10.917786598205566
],
[
0.0,
-23.25142478942871,
-25.139137268066406,
-26.2847843170166,
-28.992677688598633,
-29.070789337158203,
-29.55248260498047,
-29.91700553894043,
-30.20341682434082,
-30.307567596435547
],
[
0.0,
-21.66313934326172,
-23.06916046142578,
-23.32953453063965,
-25.65988540649414,
-25.985353469848633,
-26.519121170043945,
-27.07892417907715,
-27.977216720581055,
-28.458908081054688
],
[
0.0,
-28.094383239746094,
-28.56305694580078,
-29.109844207763672,
-29.44832992553711,
-31.79170036315918,
-32.0,
-32.05207443237305,
-32.31244659423828,
-32.364524841308594
],
[
0.0,
-30.489830017089844,
-31.140766143798828,
-31.81774139404297,
-34.525634765625,
-35.8275032043457,
-36.504478454589844,
-39.05614471435547,
-40.123680114746094,
-40.696502685546875
],
[
0.0,
-25.646865844726562,
-26.66232681274414,
-27.781936645507812,
-28.979660034179688,
-31.140764236450195,
-31.92188835144043,
-31.973962783813477,
-33.04149627685547,
-33.58828353881836
],
[
0.0,
-23.511798858642578,
-24.136695861816406,
-25.230268478393555,
-25.777053833007812,
-25.80309295654297,
-26.45402717590332,
-26.636289596557617,
-26.740440368652344,
-26.896663665771484
],
[
0.0,
-22.366153717041016,
-24.683483123779297,
-26.610252380371094,
-26.610252380371094,
-27.313264846801758,
-27.67778778076172,
-28.510986328125,
-28.615135192871094,
-29.13588523864746
],
[
0.0,
-22.52237319946289,
-24.292919158935547,
-24.344993591308594,
-24.39706802368164,
-24.73555564880371,
-29.943042755126953,
-29.969079971313477,
-30.021154403686523,
-30.0341739654541
],
[
0.0,
-30.17738151550293,
-30.411718368530273,
-30.88039207458496,
-30.984540939331055,
-31.270952224731445,
-31.895851135253906,
-32.46867370605469,
-32.624900817871094,
-33.484134674072266
],
[
0.0,
-28.146459579467773,
-29.396255493164062,
-30.099267959594727,
-31.127744674682617,
-31.179821014404297,
-32.807159423828125,
-33.7445068359375,
-33.770545959472656,
-34.069976806640625
],
[
0.0,
-26.323841094970703,
-26.558177947998047,
-30.515867233276367,
-30.932466506958008,
-31.37510108947754,
-31.531326293945312,
-31.70056915283203,
-32.065093994140625,
-32.364524841308594
],
[
0.0,
-26.922698974609375,
-30.28152847290039,
-31.505287170410156,
-33.30187225341797,
-33.73148727416992,
-34.27827453613281,
-34.33034896850586,
-34.460533142089844,
-34.720909118652344
],
[
0.0,
-21.532955169677734,
-26.94873809814453,
-29.109848022460938,
-30.80228042602539,
-31.55736541748047,
-33.484134674072266,
-34.681854248046875,
-35.384864807128906,
-35.853538513183594
],
[
0.0,
-19.502033233642578,
-20.46541976928711,
-24.110658645629883,
-24.501218795776367,
-25.256305694580078,
-25.82912826538086,
-25.881202697753906,
-26.063465118408203,
-26.063465118408203
],
[
0.0,
-24.37103271484375,
-25.256305694580078,
-25.933277130126953,
-26.714401245117188,
-28.2506103515625,
-31.010576248168945,
-32.07810974121094,
-34.62977981567383,
-35.241661071777344
],
[
-1.1920922133867862e-06,
-14.398697853088379,
-14.424736976623535,
-17.158666610717773,
-17.41904067993164,
-18.200162887573242,
-18.434499740600586,
-18.66883659362793,
-19.71033477783203,
-19.71033477783203
],
[
-0.0001445904199499637,
-8.98305892944336,
-11.35246467590332,
-13.1490478515625,
-13.669795989990234,
-14.073375701904297,
-14.516012191772461,
-14.555068969726562,
-15.622602462768555,
-15.635622024536133
],
[
-0.44747352600097656,
-1.0202960968017578,
-8.467000961303711,
-10.914518356323242,
-11.25300407409668,
-11.435266494750977,
-12.346576690673828,
-13.075624465942383,
-13.12769889831543,
-13.231849670410156
],
[
-3.123767137527466,
-1.1188862323760986,
-1.639634370803833,
-2.0562336444854736,
-2.8633930683135986,
-2.9675419330596924,
-3.4882919788360596,
-3.69659161567688,
-4.217339515686035,
-4.243376731872559
],
[
-7.199982064776123e-05,
-9.76410961151123,
-11.144091606140137,
-16.507802963256836,
-17.132701873779297,
-17.44515037536621,
-17.9138240814209,
-18.33042335510254,
-18.9162654876709,
-19.39795684814453
],
[
0.0,
-22.991050720214844,
-23.824249267578125,
-24.969894409179688,
-25.46460723876953,
-25.829130172729492,
-26.480066299438477,
-26.909683227539062,
-27.33930206298828,
-27.391376495361328
],
[
-0.21928852796554565,
-1.625309705734253,
-9.775025367736816,
-12.977627754211426,
-16.388530731201172,
-17.091541290283203,
-19.044347763061523,
-19.38283348083496,
-19.460947036743164,
-19.59113311767578
],
[
0.0,
-24.006507873535156,
-27.443450927734375,
-27.729862213134766,
-28.12042236328125,
-28.276647567749023,
-28.927583694458008,
-30.099267959594727,
-31.479251861572266,
-32.07810974121094
],
[
0.0,
-18.17412567138672,
-18.772987365722656,
-21.689178466796875,
-21.92351531982422,
-23.7200984954834,
-23.79821014404297,
-23.79821014404297,
-24.032546997070312,
-25.308382034301758
],
[
-0.12947827577590942,
-2.1083219051361084,
-12.419143676757812,
-15.23118782043457,
-15.595710754394531,
-15.830047607421875,
-17.001731872558594,
-17.60059356689453,
-18.121341705322266,
-18.251529693603516
],
[
0.0,
-19.449962615966797,
-24.371034622192383,
-24.917821884155273,
-25.529701232910156,
-25.85516929626465,
-26.037429809570312,
-26.115543365478516,
-26.623271942138672,
-26.649309158325195
],
[
-0.03332124650478363,
-3.4181859493255615,
-15.759925842285156,
-15.812002182006836,
-16.593124389648438,
-17.894996643066406,
-18.09027671813965,
-18.79328727722168,
-19.144792556762695,
-20.147233963012695
],
[
0.0,
-21.142393112182617,
-22.157852172851562,
-23.511798858642578,
-24.657445907592773,
-25.021968841552734,
-25.5427188873291,
-25.59479331970215,
-25.75101661682129,
-25.95931625366211
],
[
0.0,
-23.04312515258789,
-24.94385528564453,
-26.323841094970703,
-27.54759979248047,
-28.563060760498047,
-29.786819458007812,
-30.620018005371094,
-30.69812774658203,
-31.08869171142578
],
[
0.0,
-26.167617797851562,
-28.771360397338867,
-29.55248260498047,
-30.906429290771484,
-31.114728927612305,
-31.414159774780273,
-31.622459411621094,
-31.713590621948242,
-31.726608276367188
],
[
-0.05012698099017143,
-3.018392562866211,
-11.740934371948242,
-13.146955490112305,
-13.797887802124023,
-14.943536758422852,
-16.037107467651367,
-16.375595092773438,
-16.714080810546875,
-17.36501693725586
],
[
-0.9704352021217346,
-0.7360983490943909,
-2.1941938400268555,
-4.225115776062012,
-5.0062360763549805,
-5.2666120529174805,
-5.839434623718262,
-7.2714948654174805,
-8.33902645111084,
-8.495253562927246
],
[
-0.014467108063399792,
-4.258565902709961,
-8.789079666137695,
-10.429437637329102,
-10.793962478637695,
-11.835458755493164,
-11.939607620239258,
-13.31959342956543,
-13.866378784179688,
-15.038063049316406
],
[
0.0,
-20.08787727355957,
-21.350692749023438,
-21.415786743164062,
-21.50691795349121,
-21.50691795349121,
-22.7176570892334,
-24.13669776916504,
-24.188772201538086,
-24.34499740600586
]
]
},
{
"case": "fail_case",
"expect": 0,
"tokens": [
"<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Seattle",
",",
" WA",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"7",
"'}}\n",
"</tool_call>"
],
"logprobs": [
[
-0.00013815402053296566,
-9.113236427307129,
-10.571331977844238,
-14.099404335021973,
-14.28166675567627,
-15.583537101745605,
-15.81787395477295,
-16.143341064453125,
-16.143341064453125,
-16.260509490966797
],
[
0.0,
-26.896663665771484,
-27.32628059387207,
-27.41741180419922,
-32.07810974121094,
-32.07810974121094,
-32.28641128540039,
-32.29943084716797,
-32.44263458251953,
-32.520748138427734
],
[
0.0,
-22.444263458251953,
-24.527257919311523,
-27.15703773498535,
-28.016273498535156,
-28.2506103515625,
-28.693246841430664,
-29.070789337158203,
-29.565500259399414,
-29.812854766845703
],
[
0.0,
-27.860050201416016,
-28.641170501708984,
-29.448333740234375,
-30.932466506958008,
-31.63547706604004,
-32.33848571777344,
-32.85923767089844,
-33.17168426513672,
-33.45809555053711
],
[
0.0,
-31.81774139404297,
-31.895854949951172,
-32.05207824707031,
-35.43694305419922,
-36.3482551574707,
-38.61351013183594,
-39.26444625854492,
-40.61839294433594,
-41.71196365356445
],
[
0.0,
-27.33930206298828,
-27.834014892578125,
-28.849472045898438,
-30.567943572998047,
-32.98942565917969,
-33.067535400390625,
-33.067535400390625,
-35.67127990722656,
-35.69731903076172
],
[
0.0,
-25.33441925048828,
-26.063465118408203,
-26.219690322875977,
-26.2457275390625,
-26.53213882446289,
-27.365337371826172,
-28.354759216308594,
-28.667207717895508,
-28.74532127380371
],
[
0.0,
-24.423107147216797,
-24.579330444335938,
-26.81855010986328,
-28.12042236328125,
-28.32872200012207,
-28.61513328552246,
-29.16191864013672,
-29.187957763671875,
-29.240032196044922
],
[
0.0,
-22.027664184570312,
-23.850284576416016,
-23.980472564697266,
-24.292922973632812,
-24.787633895874023,
-29.279088973999023,
-29.55248260498047,
-29.903987884521484,
-30.190399169921875
],
[
0.0,
-31.609439849853516,
-31.817739486694336,
-32.54678726196289,
-32.676971435546875,
-32.781124114990234,
-32.98942565917969,
-33.106590270996094,
-33.57526397705078,
-34.369407653808594
],
[
0.0,
-29.34418296813965,
-29.63059425354004,
-30.021156311035156,
-30.984540939331055,
-33.21073913574219,
-34.30431365966797,
-34.56468963623047,
-34.70789337158203,
-34.79902648925781
],
[
0.0,
-25.438566207885742,
-25.69894027709961,
-30.190397262573242,
-30.802276611328125,
-31.58340072631836,
-31.609437942504883,
-31.64849281311035,
-31.973960876464844,
-32.29943084716797
],
[
0.0,
-27.157039642333984,
-32.104148864746094,
-32.33848571777344,
-34.04393768310547,
-34.12205505371094,
-34.40846252441406,
-34.42148208618164,
-34.772987365722656,
-34.87713623046875
],
[
0.0,
-24.813671112060547,
-26.974777221679688,
-31.010578155517578,
-31.08869171142578,
-32.1822624206543,
-35.33279037475586,
-35.489013671875,
-36.999183654785156,
-37.88446044921875
],
[
0.0,
-20.46541976928711,
-20.647682189941406,
-23.069164276123047,
-24.136699676513672,
-25.438570022583008,
-25.646869659423828,
-26.193655014038086,
-26.297805786132812,
-26.506103515625
],
[
0.0,
-27.18307113647461,
-28.30268096923828,
-28.56305694580078,
-29.526439666748047,
-32.416595458984375,
-35.202598571777344,
-36.426361083984375,
-39.31651306152344,
-39.38160705566406
],
[
0.0,
-18.7469482421875,
-20.100894927978516,
-21.402767181396484,
-21.428804397583008,
-22.20992660522461,
-22.34011459350586,
-22.730674743652344,
-23.069162368774414,
-23.980472564697266
],
[
-3.576278118089249e-07,
-15.2579345703125,
-16.481693267822266,
-17.991863250732422,
-19.215621948242188,
-20.25712013244629,
-21.350692749023438,
-22.314077377319336,
-22.496337890625,
-22.938974380493164
],
[
-0.08506780862808228,
-2.506549835205078,
-14.848289489746094,
-15.473188400268555,
-16.33242416381836,
-16.358461380004883,
-16.566761016845703,
-17.03543472290039,
-17.686370849609375,
-17.816556930541992
],
[
-0.0194891095161438,
-4.445854187011719,
-5.591499328613281,
-5.956024169921875,
-6.685070037841797,
-13.142353057861328,
-13.558952331542969,
-15.173273086547852,
-15.303461074829102,
-15.85024642944336
],
[
-0.0005990855861455202,
-7.4212646484375,
-15.675132751464844,
-15.72720718383789,
-16.76870346069336,
-16.76870346069336,
-17.706050872802734,
-18.669435501098633,
-19.398483276367188,
-19.658857345581055
],
[
0.0,
-24.110658645629883,
-25.829130172729492,
-26.011390686035156,
-26.011390686035156,
-26.532140731811523,
-26.58421516418457,
-27.651750564575195,
-27.75589942932129,
-28.055330276489258
],
[
-1.1408883333206177,
-0.38580334186553955,
-7.494022369384766,
-12.519245147705078,
-14.576202392578125,
-16.034297943115234,
-16.945608139038086,
-17.908992767333984,
-18.664077758789062,
-19.34105110168457
],
[
0.0,
-26.688365936279297,
-29.83889389038086,
-30.177383422851562,
-30.64605712890625,
-31.244916915893555,
-31.270954132080078,
-32.83319854736328,
-34.655818939208984,
-34.89015579223633
],
[
0.0,
-18.929210662841797,
-19.16354751586914,
-23.589908599853516,
-24.683481216430664,
-24.995929718017578,
-25.516677856445312,
-25.542715072631836,
-25.77705192565918,
-26.063465118408203
],
[
-0.2519786059856415,
-1.5017764568328857,
-12.437495231628418,
-15.457839012145996,
-15.744250297546387,
-16.837820053100586,
-17.41064453125,
-17.56686782836914,
-17.61894416809082,
-18.035541534423828
],
[
0.0,
-20.517494201660156,
-24.683483123779297,
-25.67290496826172,
-26.58421516418457,
-27.651750564575195,
-27.781936645507812,
-27.912124633789062,
-28.09438705444336,
-28.445892333984375
],
[
-3.40932747349143e-05,
-10.284820556640625,
-18.252273559570312,
-20.17904281616211,
-21.663175582885742,
-22.027700424194336,
-22.288074493408203,
-22.704673767089844,
-23.12127113342285,
-23.277496337890625
],
[
0.0,
-22.60049057006836,
-25.46460723876953,
-25.829130172729492,
-26.063467025756836,
-27.287227630615234,
-27.391376495361328,
-27.4694881439209,
-27.67778778076172,
-28.055330276489258
],
[
0.0,
-23.902362823486328,
-28.823436737060547,
-29.240036010742188,
-29.31814956665039,
-29.917007446289062,
-30.021160125732422,
-31.21887969970703,
-32.416603088378906,
-32.416603088378906
],
[
0.0,
-28.641170501708984,
-31.947925567626953,
-32.59886169433594,
-33.848655700683594,
-34.109031677246094,
-34.73393249511719,
-35.02033996582031,
-35.02033996582031,
-36.074859619140625
],
[
-0.013183215633034706,
-4.335395336151123,
-19.619365692138672,
-20.035964965820312,
-20.244266510009766,
-21.311800003051758,
-21.441987991333008,
-22.561595916748047,
-23.108383178710938,
-23.264606475830078
],
[
-8.344646857949556e-07,
-14.190400123596191,
-15.9088716506958,
-18.17412567138672,
-18.46053695678711,
-18.46053695678711,
-18.512611389160156,
-18.90317153930664,
-19.059398651123047,
-19.085433959960938
],
[
0.0,
-17.70545196533203,
-18.903175354003906,
-20.829944610595703,
-22.574451446533203,
-22.860862731933594,
-23.069162368774414,
-23.32953643798828,
-23.694061279296875,
-24.188772201538086
],
[
0.0,
-20.022781372070312,
-21.038240432739258,
-21.220502853393555,
-22.496337890625,
-22.769729614257812,
-23.589908599853516,
-23.65500259399414,
-23.94141387939453,
-24.266881942749023
]
]
}
]

View file

@ -64,6 +64,9 @@ def test_process_messages(mock_hanlder):
} }
# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated.
# [TODO] Review: Update the following test # [TODO] Review: Update the following test
@patch("app.commons.constants.arch_function_client") @patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder") @patch("app.commons.constants.arch_function_hanlder")

View file

@ -1,10 +1,7 @@
import os
import pytest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import app.commons.globals as glb from app.model_handler.guardrails import get_guardrail_handler
# Mock constants # Mock constants
glb.DEVICE = "cpu" # Adjust as needed for your test case
arch_guard_model_type = { arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu", "cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard", "cuda": "katanemo/Arch-Guard",
@ -12,36 +9,71 @@ arch_guard_model_type = {
} }
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps` # [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for get_prompt_guard function # Test for `get_guardrail_handler()` function on `cpu`
@patch("app.loader.AutoTokenizer.from_pretrained") @patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained") @patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained") @patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer): def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device device = "cpu"
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_ov_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock() mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE]) guardrail = get_guardrail_handler(device=device)
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with( mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True guardrail["model_name"], trust_remote_code=True
)
mock_ov_model.assert_called_once_with(
guardrail["model_name"],
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `cuda`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "cuda"
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(
guardrail["model_name"], trust_remote_code=True
)
mock_auto_model.assert_called_once_with(
guardrail["model_name"],
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `mps`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
device = "mps"
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(
guardrail["model_name"], trust_remote_code=True
)
mock_auto_model.assert_called_once_with(
guardrail["model_name"],
device_map=device,
low_cpu_mem_usage=True,
) )
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)