Update guardrail_handler and its associated tests

This commit is contained in:
Shuguang Chen 2024-12-05 11:30:58 -08:00
parent b686cf8b87
commit 09f7e1e604
7 changed files with 115 additions and 1091 deletions

View file

@ -1,43 +1,14 @@
import app.commons.utilities as utils
from openai import OpenAI
from app.commons.constants import *
from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler
from app.model_handler.guardrails import ArchGuardHanlder
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from openai import OpenAI
from app.model_handler.guardrails import get_guardrail_handler
logger = utils.get_model_server_logger()
def get_guardrail_handler():
    """Build an `ArchGuardHanlder` for the device reported by `utils.get_device()`.

    On "cpu" the `OVModelForSequenceClassification` class is paired with the
    "katanemo/Arch-Guard-cpu" checkpoint. Every other device (including
    "cuda") uses `AutoModelForSequenceClassification` with
    "katanemo/Arch-Guard".

    Returns:
        An `ArchGuardHanlder` constructed from a dict holding the device, the
        checkpoint name, the loaded tokenizer and the loaded model.
    """
    target_device = utils.get_device()

    if target_device == "cpu":
        loader_cls = OVModelForSequenceClassification
        checkpoint = "katanemo/Arch-Guard-cpu"
    else:
        # "cuda" and all remaining devices resolve to the same checkpoint.
        loader_cls = AutoModelForSequenceClassification
        checkpoint = "katanemo/Arch-Guard"

    # The tokenizer is loaded first, then the model, as in the dict literal
    # this replaces.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    model = loader_cls.from_pretrained(
        checkpoint, device_map=target_device, low_cpu_mem_usage=True
    )

    return ArchGuardHanlder(
        model_dict={
            "device": target_device,
            "model_name": checkpoint,
            "tokenizer": tokenizer,
            "model": model,
        }
    )
# Define the client
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")

View file

@ -25,7 +25,7 @@ class ArchIntentHandler(ArchBaseHandler):
task_prompt: str,
tool_prompt: str,
format_prompt: str,
intent_instruction: str,
extra_instruction: str,
generation_params: Dict,
):
"""
@ -37,7 +37,7 @@ class ArchIntentHandler(ArchBaseHandler):
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
intent_instruction (str): Instructions specific to intent handling.
extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
"""
@ -50,7 +50,7 @@ class ArchIntentHandler(ArchBaseHandler):
generation_params,
)
self.intent_instruction = intent_instruction
self.extra_instruction = extra_instruction
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
@ -85,7 +85,7 @@ class ArchIntentHandler(ArchBaseHandler):
"""
messages = self._process_messages(
req.messages, req.tools, self.intent_instruction
req.messages, req.tools, self.extra_instruction
)
model_response = self.client.chat.completions.create(

View file

@ -1,8 +1,11 @@
import time
import torch
import numpy as np
import app.commons.utilities as utils
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
class GuardRequest(BaseModel):
@ -93,3 +96,27 @@ class ArchGuardHanlder:
guard_result["latency"] = time.perf_counter() - start_time
return guard_result
def get_guardrail_handler(device: str = None):
    """Create an `ArchGuardHanlder` backed by the Arch-Guard classifier.

    Args:
        device: Device name to load the model on. When `None`, the value
            returned by `utils.get_device()` is used.

    Returns:
        An `ArchGuardHanlder` built from a dict with the keys "device",
        "model_name", "tokenizer" and "model".
    """
    resolved_device = utils.get_device() if device is None else device

    # "cpu" gets the OpenVINO loader and the CPU checkpoint; any other device
    # falls back to the generic transformers loader and the default checkpoint.
    if resolved_device == "cpu":
        loader_cls, checkpoint = (
            OVModelForSequenceClassification,
            "katanemo/Arch-Guard-cpu",
        )
    else:
        loader_cls, checkpoint = (
            AutoModelForSequenceClassification,
            "katanemo/Arch-Guard",
        )

    # Tokenizer first, then model, preserving the original load order.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    model = loader_cls.from_pretrained(
        checkpoint, device_map=resolved_device, low_cpu_mem_usage=True
    )

    handler_config = {
        "device": resolved_device,
        "model_name": checkpoint,
        "tokenizer": tokenizer,
        "model": model,
    }
    return ArchGuardHanlder(model_dict=handler_config)

View file

@ -1,32 +1,25 @@
import pytest
import httpx
from fastapi.testclient import TestClient
from app.main import app # Assuming your FastAPI app is in main.py
from unittest.mock import patch
import app.commons.globals as glb
import logging
logger = logging.getLogger(__name__)
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
# [TODO] Review: update the following code
# [TODO] Review: check the following code
# Unit tests for the health check endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_healthz():
response = client.get("/healthz")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
# [TODO] Review: update the following code
# [TODO] Review: check the following code
# Unit test for the models endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_models():
response = client.get("/models")
assert response.status_code == 200
@ -34,80 +27,27 @@ async def test_models():
assert len(response.json()["data"]) > 0
# [TODO] Review: update the following code
# Unit test for embeddings endpoint
# [TODO] Review: check the following code
# Unit test for the guardrail endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_embedding():
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
response = client.post("/embeddings", json=request_data)
if request_data["model"] == "katanemo/bge-large-en-v1.5":
assert response.status_code == 200
assert response.json()["object"] == "list"
assert "data" in response.json()
else:
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the guard endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_guard():
async def test_guardrail_endpoint():
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
response = client.post("/guard", json=request_data)
response = client.post("/guardrails", json=request_data)
assert response.status_code == 200
assert "jailbreak_verdict" in response.json()
# [TODO] Review: update the following code
# Unit test for the zero-shot endpoint
# [TODO] Review: check the following code
# Unit test for the function calling endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_zeroshot():
request_data = {
"input": "Test input",
"labels": ["label1", "label2"],
"model": "katanemo/bart-large-mnli",
}
response = client.post("/zeroshot", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
assert response.status_code == 200
assert "predicted_class" in response.json()
else:
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the hallucination endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_hallucination():
request_data = {
"prompt": "Test hallucination",
"parameters": {"param1": "value1"},
"model": "katanemo/bart-large-mnli",
}
response = client.post("/hallucination", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
assert response.status_code == 200
assert "params_scores" in response.json()
else:
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the chat completion endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_chat_completion():
async def test_function_calling_endpoint():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
request_data = {
"messages": [{"role": "user", "content": "Hello!"}],
"model": "Arch-Function-1.5B",
"tools": [], # Assuming tools is part of the req as per the function
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed
"model": "Arch-Function",
"tools": [],
"metadata": {"x-arch-state": "[]"},
}
response = await client.post("/v1/chat/completions", json=request_data)
response = await client.post("/function_calling", json=request_data)
assert response.status_code == 200
assert "choices" in response.json()

View file

@ -1,949 +0,0 @@
[
{
"case": "tool_call_halluciation",
"tokens": [
"<tool_call>"
],
"expect": 1,
"logprobs": [
[
-0.3333307206630707,
-1.5310522317886353,
-3.5098977088928223,
-3.9004578590393066,
-5.775152683258057,
-5.814209461212158,
-5.9574151039123535,
-6.0094895362854,
-6.0094895362854,
-6.673445224761963
]
]
},
{
"case": "parameter_value_hallucination",
"expect": 0,
"tokens": [
"<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Sea",
",",
" Australia",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"1",
"'}}\n",
"</tool_call>"
],
"logprobs": [
[
-0.008103232830762863,
-5.085402488708496,
-6.777836799621582,
-7.558959007263184,
-9.850253105163574,
-10.266852378845215,
-10.540244102478027,
-10.722506523132324,
-10.800618171691895,
-10.917786598205566
],
[
0.0,
-23.25142478942871,
-25.139137268066406,
-26.2847843170166,
-28.992677688598633,
-29.070789337158203,
-29.55248260498047,
-29.91700553894043,
-30.20341682434082,
-30.307567596435547
],
[
0.0,
-21.66313934326172,
-23.06916046142578,
-23.32953453063965,
-25.65988540649414,
-25.985353469848633,
-26.519121170043945,
-27.07892417907715,
-27.977216720581055,
-28.458908081054688
],
[
0.0,
-28.094383239746094,
-28.56305694580078,
-29.109844207763672,
-29.44832992553711,
-31.79170036315918,
-32.0,
-32.05207443237305,
-32.31244659423828,
-32.364524841308594
],
[
0.0,
-30.489830017089844,
-31.140766143798828,
-31.81774139404297,
-34.525634765625,
-35.8275032043457,
-36.504478454589844,
-39.05614471435547,
-40.123680114746094,
-40.696502685546875
],
[
0.0,
-25.646865844726562,
-26.66232681274414,
-27.781936645507812,
-28.979660034179688,
-31.140764236450195,
-31.92188835144043,
-31.973962783813477,
-33.04149627685547,
-33.58828353881836
],
[
0.0,
-23.511798858642578,
-24.136695861816406,
-25.230268478393555,
-25.777053833007812,
-25.80309295654297,
-26.45402717590332,
-26.636289596557617,
-26.740440368652344,
-26.896663665771484
],
[
0.0,
-22.366153717041016,
-24.683483123779297,
-26.610252380371094,
-26.610252380371094,
-27.313264846801758,
-27.67778778076172,
-28.510986328125,
-28.615135192871094,
-29.13588523864746
],
[
0.0,
-22.52237319946289,
-24.292919158935547,
-24.344993591308594,
-24.39706802368164,
-24.73555564880371,
-29.943042755126953,
-29.969079971313477,
-30.021154403686523,
-30.0341739654541
],
[
0.0,
-30.17738151550293,
-30.411718368530273,
-30.88039207458496,
-30.984540939331055,
-31.270952224731445,
-31.895851135253906,
-32.46867370605469,
-32.624900817871094,
-33.484134674072266
],
[
0.0,
-28.146459579467773,
-29.396255493164062,
-30.099267959594727,
-31.127744674682617,
-31.179821014404297,
-32.807159423828125,
-33.7445068359375,
-33.770545959472656,
-34.069976806640625
],
[
0.0,
-26.323841094970703,
-26.558177947998047,
-30.515867233276367,
-30.932466506958008,
-31.37510108947754,
-31.531326293945312,
-31.70056915283203,
-32.065093994140625,
-32.364524841308594
],
[
0.0,
-26.922698974609375,
-30.28152847290039,
-31.505287170410156,
-33.30187225341797,
-33.73148727416992,
-34.27827453613281,
-34.33034896850586,
-34.460533142089844,
-34.720909118652344
],
[
0.0,
-21.532955169677734,
-26.94873809814453,
-29.109848022460938,
-30.80228042602539,
-31.55736541748047,
-33.484134674072266,
-34.681854248046875,
-35.384864807128906,
-35.853538513183594
],
[
0.0,
-19.502033233642578,
-20.46541976928711,
-24.110658645629883,
-24.501218795776367,
-25.256305694580078,
-25.82912826538086,
-25.881202697753906,
-26.063465118408203,
-26.063465118408203
],
[
0.0,
-24.37103271484375,
-25.256305694580078,
-25.933277130126953,
-26.714401245117188,
-28.2506103515625,
-31.010576248168945,
-32.07810974121094,
-34.62977981567383,
-35.241661071777344
],
[
-1.1920922133867862e-06,
-14.398697853088379,
-14.424736976623535,
-17.158666610717773,
-17.41904067993164,
-18.200162887573242,
-18.434499740600586,
-18.66883659362793,
-19.71033477783203,
-19.71033477783203
],
[
-0.0001445904199499637,
-8.98305892944336,
-11.35246467590332,
-13.1490478515625,
-13.669795989990234,
-14.073375701904297,
-14.516012191772461,
-14.555068969726562,
-15.622602462768555,
-15.635622024536133
],
[
-0.44747352600097656,
-1.0202960968017578,
-8.467000961303711,
-10.914518356323242,
-11.25300407409668,
-11.435266494750977,
-12.346576690673828,
-13.075624465942383,
-13.12769889831543,
-13.231849670410156
],
[
-3.123767137527466,
-1.1188862323760986,
-1.639634370803833,
-2.0562336444854736,
-2.8633930683135986,
-2.9675419330596924,
-3.4882919788360596,
-3.69659161567688,
-4.217339515686035,
-4.243376731872559
],
[
-7.199982064776123e-05,
-9.76410961151123,
-11.144091606140137,
-16.507802963256836,
-17.132701873779297,
-17.44515037536621,
-17.9138240814209,
-18.33042335510254,
-18.9162654876709,
-19.39795684814453
],
[
0.0,
-22.991050720214844,
-23.824249267578125,
-24.969894409179688,
-25.46460723876953,
-25.829130172729492,
-26.480066299438477,
-26.909683227539062,
-27.33930206298828,
-27.391376495361328
],
[
-0.21928852796554565,
-1.625309705734253,
-9.775025367736816,
-12.977627754211426,
-16.388530731201172,
-17.091541290283203,
-19.044347763061523,
-19.38283348083496,
-19.460947036743164,
-19.59113311767578
],
[
0.0,
-24.006507873535156,
-27.443450927734375,
-27.729862213134766,
-28.12042236328125,
-28.276647567749023,
-28.927583694458008,
-30.099267959594727,
-31.479251861572266,
-32.07810974121094
],
[
0.0,
-18.17412567138672,
-18.772987365722656,
-21.689178466796875,
-21.92351531982422,
-23.7200984954834,
-23.79821014404297,
-23.79821014404297,
-24.032546997070312,
-25.308382034301758
],
[
-0.12947827577590942,
-2.1083219051361084,
-12.419143676757812,
-15.23118782043457,
-15.595710754394531,
-15.830047607421875,
-17.001731872558594,
-17.60059356689453,
-18.121341705322266,
-18.251529693603516
],
[
0.0,
-19.449962615966797,
-24.371034622192383,
-24.917821884155273,
-25.529701232910156,
-25.85516929626465,
-26.037429809570312,
-26.115543365478516,
-26.623271942138672,
-26.649309158325195
],
[
-0.03332124650478363,
-3.4181859493255615,
-15.759925842285156,
-15.812002182006836,
-16.593124389648438,
-17.894996643066406,
-18.09027671813965,
-18.79328727722168,
-19.144792556762695,
-20.147233963012695
],
[
0.0,
-21.142393112182617,
-22.157852172851562,
-23.511798858642578,
-24.657445907592773,
-25.021968841552734,
-25.5427188873291,
-25.59479331970215,
-25.75101661682129,
-25.95931625366211
],
[
0.0,
-23.04312515258789,
-24.94385528564453,
-26.323841094970703,
-27.54759979248047,
-28.563060760498047,
-29.786819458007812,
-30.620018005371094,
-30.69812774658203,
-31.08869171142578
],
[
0.0,
-26.167617797851562,
-28.771360397338867,
-29.55248260498047,
-30.906429290771484,
-31.114728927612305,
-31.414159774780273,
-31.622459411621094,
-31.713590621948242,
-31.726608276367188
],
[
-0.05012698099017143,
-3.018392562866211,
-11.740934371948242,
-13.146955490112305,
-13.797887802124023,
-14.943536758422852,
-16.037107467651367,
-16.375595092773438,
-16.714080810546875,
-17.36501693725586
],
[
-0.9704352021217346,
-0.7360983490943909,
-2.1941938400268555,
-4.225115776062012,
-5.0062360763549805,
-5.2666120529174805,
-5.839434623718262,
-7.2714948654174805,
-8.33902645111084,
-8.495253562927246
],
[
-0.014467108063399792,
-4.258565902709961,
-8.789079666137695,
-10.429437637329102,
-10.793962478637695,
-11.835458755493164,
-11.939607620239258,
-13.31959342956543,
-13.866378784179688,
-15.038063049316406
],
[
0.0,
-20.08787727355957,
-21.350692749023438,
-21.415786743164062,
-21.50691795349121,
-21.50691795349121,
-22.7176570892334,
-24.13669776916504,
-24.188772201538086,
-24.34499740600586
]
]
},
{
"case": "fail_case",
"expect": 0,
"tokens": [
"<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Seattle",
",",
" WA",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"7",
"'}}\n",
"</tool_call>"
],
"logprobs": [
[
-0.00013815402053296566,
-9.113236427307129,
-10.571331977844238,
-14.099404335021973,
-14.28166675567627,
-15.583537101745605,
-15.81787395477295,
-16.143341064453125,
-16.143341064453125,
-16.260509490966797
],
[
0.0,
-26.896663665771484,
-27.32628059387207,
-27.41741180419922,
-32.07810974121094,
-32.07810974121094,
-32.28641128540039,
-32.29943084716797,
-32.44263458251953,
-32.520748138427734
],
[
0.0,
-22.444263458251953,
-24.527257919311523,
-27.15703773498535,
-28.016273498535156,
-28.2506103515625,
-28.693246841430664,
-29.070789337158203,
-29.565500259399414,
-29.812854766845703
],
[
0.0,
-27.860050201416016,
-28.641170501708984,
-29.448333740234375,
-30.932466506958008,
-31.63547706604004,
-32.33848571777344,
-32.85923767089844,
-33.17168426513672,
-33.45809555053711
],
[
0.0,
-31.81774139404297,
-31.895854949951172,
-32.05207824707031,
-35.43694305419922,
-36.3482551574707,
-38.61351013183594,
-39.26444625854492,
-40.61839294433594,
-41.71196365356445
],
[
0.0,
-27.33930206298828,
-27.834014892578125,
-28.849472045898438,
-30.567943572998047,
-32.98942565917969,
-33.067535400390625,
-33.067535400390625,
-35.67127990722656,
-35.69731903076172
],
[
0.0,
-25.33441925048828,
-26.063465118408203,
-26.219690322875977,
-26.2457275390625,
-26.53213882446289,
-27.365337371826172,
-28.354759216308594,
-28.667207717895508,
-28.74532127380371
],
[
0.0,
-24.423107147216797,
-24.579330444335938,
-26.81855010986328,
-28.12042236328125,
-28.32872200012207,
-28.61513328552246,
-29.16191864013672,
-29.187957763671875,
-29.240032196044922
],
[
0.0,
-22.027664184570312,
-23.850284576416016,
-23.980472564697266,
-24.292922973632812,
-24.787633895874023,
-29.279088973999023,
-29.55248260498047,
-29.903987884521484,
-30.190399169921875
],
[
0.0,
-31.609439849853516,
-31.817739486694336,
-32.54678726196289,
-32.676971435546875,
-32.781124114990234,
-32.98942565917969,
-33.106590270996094,
-33.57526397705078,
-34.369407653808594
],
[
0.0,
-29.34418296813965,
-29.63059425354004,
-30.021156311035156,
-30.984540939331055,
-33.21073913574219,
-34.30431365966797,
-34.56468963623047,
-34.70789337158203,
-34.79902648925781
],
[
0.0,
-25.438566207885742,
-25.69894027709961,
-30.190397262573242,
-30.802276611328125,
-31.58340072631836,
-31.609437942504883,
-31.64849281311035,
-31.973960876464844,
-32.29943084716797
],
[
0.0,
-27.157039642333984,
-32.104148864746094,
-32.33848571777344,
-34.04393768310547,
-34.12205505371094,
-34.40846252441406,
-34.42148208618164,
-34.772987365722656,
-34.87713623046875
],
[
0.0,
-24.813671112060547,
-26.974777221679688,
-31.010578155517578,
-31.08869171142578,
-32.1822624206543,
-35.33279037475586,
-35.489013671875,
-36.999183654785156,
-37.88446044921875
],
[
0.0,
-20.46541976928711,
-20.647682189941406,
-23.069164276123047,
-24.136699676513672,
-25.438570022583008,
-25.646869659423828,
-26.193655014038086,
-26.297805786132812,
-26.506103515625
],
[
0.0,
-27.18307113647461,
-28.30268096923828,
-28.56305694580078,
-29.526439666748047,
-32.416595458984375,
-35.202598571777344,
-36.426361083984375,
-39.31651306152344,
-39.38160705566406
],
[
0.0,
-18.7469482421875,
-20.100894927978516,
-21.402767181396484,
-21.428804397583008,
-22.20992660522461,
-22.34011459350586,
-22.730674743652344,
-23.069162368774414,
-23.980472564697266
],
[
-3.576278118089249e-07,
-15.2579345703125,
-16.481693267822266,
-17.991863250732422,
-19.215621948242188,
-20.25712013244629,
-21.350692749023438,
-22.314077377319336,
-22.496337890625,
-22.938974380493164
],
[
-0.08506780862808228,
-2.506549835205078,
-14.848289489746094,
-15.473188400268555,
-16.33242416381836,
-16.358461380004883,
-16.566761016845703,
-17.03543472290039,
-17.686370849609375,
-17.816556930541992
],
[
-0.0194891095161438,
-4.445854187011719,
-5.591499328613281,
-5.956024169921875,
-6.685070037841797,
-13.142353057861328,
-13.558952331542969,
-15.173273086547852,
-15.303461074829102,
-15.85024642944336
],
[
-0.0005990855861455202,
-7.4212646484375,
-15.675132751464844,
-15.72720718383789,
-16.76870346069336,
-16.76870346069336,
-17.706050872802734,
-18.669435501098633,
-19.398483276367188,
-19.658857345581055
],
[
0.0,
-24.110658645629883,
-25.829130172729492,
-26.011390686035156,
-26.011390686035156,
-26.532140731811523,
-26.58421516418457,
-27.651750564575195,
-27.75589942932129,
-28.055330276489258
],
[
-1.1408883333206177,
-0.38580334186553955,
-7.494022369384766,
-12.519245147705078,
-14.576202392578125,
-16.034297943115234,
-16.945608139038086,
-17.908992767333984,
-18.664077758789062,
-19.34105110168457
],
[
0.0,
-26.688365936279297,
-29.83889389038086,
-30.177383422851562,
-30.64605712890625,
-31.244916915893555,
-31.270954132080078,
-32.83319854736328,
-34.655818939208984,
-34.89015579223633
],
[
0.0,
-18.929210662841797,
-19.16354751586914,
-23.589908599853516,
-24.683481216430664,
-24.995929718017578,
-25.516677856445312,
-25.542715072631836,
-25.77705192565918,
-26.063465118408203
],
[
-0.2519786059856415,
-1.5017764568328857,
-12.437495231628418,
-15.457839012145996,
-15.744250297546387,
-16.837820053100586,
-17.41064453125,
-17.56686782836914,
-17.61894416809082,
-18.035541534423828
],
[
0.0,
-20.517494201660156,
-24.683483123779297,
-25.67290496826172,
-26.58421516418457,
-27.651750564575195,
-27.781936645507812,
-27.912124633789062,
-28.09438705444336,
-28.445892333984375
],
[
-3.40932747349143e-05,
-10.284820556640625,
-18.252273559570312,
-20.17904281616211,
-21.663175582885742,
-22.027700424194336,
-22.288074493408203,
-22.704673767089844,
-23.12127113342285,
-23.277496337890625
],
[
0.0,
-22.60049057006836,
-25.46460723876953,
-25.829130172729492,
-26.063467025756836,
-27.287227630615234,
-27.391376495361328,
-27.4694881439209,
-27.67778778076172,
-28.055330276489258
],
[
0.0,
-23.902362823486328,
-28.823436737060547,
-29.240036010742188,
-29.31814956665039,
-29.917007446289062,
-30.021160125732422,
-31.21887969970703,
-32.416603088378906,
-32.416603088378906
],
[
0.0,
-28.641170501708984,
-31.947925567626953,
-32.59886169433594,
-33.848655700683594,
-34.109031677246094,
-34.73393249511719,
-35.02033996582031,
-35.02033996582031,
-36.074859619140625
],
[
-0.013183215633034706,
-4.335395336151123,
-19.619365692138672,
-20.035964965820312,
-20.244266510009766,
-21.311800003051758,
-21.441987991333008,
-22.561595916748047,
-23.108383178710938,
-23.264606475830078
],
[
-8.344646857949556e-07,
-14.190400123596191,
-15.9088716506958,
-18.17412567138672,
-18.46053695678711,
-18.46053695678711,
-18.512611389160156,
-18.90317153930664,
-19.059398651123047,
-19.085433959960938
],
[
0.0,
-17.70545196533203,
-18.903175354003906,
-20.829944610595703,
-22.574451446533203,
-22.860862731933594,
-23.069162368774414,
-23.32953643798828,
-23.694061279296875,
-24.188772201538086
],
[
0.0,
-20.022781372070312,
-21.038240432739258,
-21.220502853393555,
-22.496337890625,
-22.769729614257812,
-23.589908599853516,
-23.65500259399414,
-23.94141387939453,
-24.266881942749023
]
]
}
]

View file

@ -64,6 +64,9 @@ def test_process_messages(mock_hanlder):
}
# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated.
# [TODO] Review: Update the following test
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")

View file

@ -1,10 +1,7 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.model_handler.guardrails import get_guardrail_handler
# Mock constants
glb.DEVICE = "cpu" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
@ -12,36 +9,71 @@ arch_guard_model_type = {
}
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps`
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("app.model_handler.guardrails.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check that `get_guardrail_handler(device="cpu")` loads via the OpenVINO class.

    The patch targets use `app.model_handler.guardrails`, the module that
    `get_guardrail_handler` is imported from, so the loaders are replaced in
    the namespace the function looks them up in.
    """
    device = "cpu"

    mock_ov_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)

    # Assertions
    # NOTE(review): subscripting the handler assumes `ArchGuardHanlder`
    # supports item access -- its class body is not visible here; confirm.
    mock_tokenizer.assert_called_once_with(
        guardrail["model_name"], trust_remote_code=True
    )
    mock_ov_model.assert_called_once_with(
        guardrail["model_name"],
        device_map=device,
        low_cpu_mem_usage=True,
    )
    # On "cpu" the generic transformers loader must not be used.
    mock_auto_model.assert_not_called()
# Test for `get_guardrail_handler()` function on `cuda`
@patch("app.model_handler.guardrails.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check that `get_guardrail_handler(device="cuda")` loads via the transformers class.

    The patch targets use `app.model_handler.guardrails`, the module that
    `get_guardrail_handler` is imported from, so the loaders are replaced in
    the namespace the function looks them up in.
    """
    device = "cuda"

    mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)

    # NOTE(review): subscripting the handler assumes `ArchGuardHanlder`
    # supports item access -- its class body is not visible here; confirm.
    mock_tokenizer.assert_called_once_with(
        guardrail["model_name"], trust_remote_code=True
    )
    mock_auto_model.assert_called_once_with(
        guardrail["model_name"],
        device_map=device,
        low_cpu_mem_usage=True,
    )
    # Only the "cpu" path should touch the OpenVINO loader.
    mock_ov_model.assert_not_called()
# Test for `get_guardrail_handler()` function on `mps`
@patch("app.model_handler.guardrails.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrails.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check that `get_guardrail_handler(device="mps")` loads via the transformers class.

    The patch targets use `app.model_handler.guardrails`, the module that
    `get_guardrail_handler` is imported from, so the loaders are replaced in
    the namespace the function looks them up in.
    """
    device = "mps"

    mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)

    # NOTE(review): subscripting the handler assumes `ArchGuardHanlder`
    # supports item access -- its class body is not visible here; confirm.
    mock_tokenizer.assert_called_once_with(
        guardrail["model_name"], trust_remote_code=True
    )
    mock_auto_model.assert_called_once_with(
        guardrail["model_name"],
        device_map=device,
        low_cpu_mem_usage=True,
    )
    # Only the "cpu" path should touch the OpenVINO loader.
    mock_ov_model.assert_not_called()
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)