diff --git a/model_server/app/commons/globals.py b/model_server/app/commons/globals.py index 5a9fac29..92edec3a 100644 --- a/model_server/app/commons/globals.py +++ b/model_server/app/commons/globals.py @@ -1,43 +1,14 @@ import app.commons.utilities as utils +from openai import OpenAI from app.commons.constants import * from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler -from app.model_handler.guardrails import ArchGuardHanlder - -from transformers import AutoTokenizer, AutoModelForSequenceClassification -from optimum.intel import OVModelForSequenceClassification -from openai import OpenAI +from app.model_handler.guardrails import get_guardrail_handler logger = utils.get_model_server_logger() -def get_guardrail_handler(): - device = utils.get_device() - - model_class, model_name = None, None - if device == "cpu": - model_class = OVModelForSequenceClassification - model_name = "katanemo/Arch-Guard-cpu" - else: - model_class = AutoModelForSequenceClassification - if device == "cuda": - model_name = "katanemo/Arch-Guard" - else: - model_name = "katanemo/Arch-Guard" - - guardrail_dict = { - "device": device, - "model_name": model_name, - "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True), - "model": model_class.from_pretrained( - model_name, device_map=device, low_cpu_mem_usage=True - ), - } - - return ArchGuardHanlder(model_dict=guardrail_dict) - - # Define the client ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY") diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py index da135dec..ea13678c 100644 --- a/model_server/app/model_handler/function_calling.py +++ b/model_server/app/model_handler/function_calling.py @@ -25,7 +25,7 @@ class ArchIntentHandler(ArchBaseHandler): task_prompt: str, tool_prompt: str, format_prompt: str, - intent_instruction: str, + extra_instruction: str, generation_params: Dict, ): """ @@ -37,7 +37,7 @@ class ArchIntentHandler(ArchBaseHandler): task_prompt (str): The main task prompt for the system. tool_prompt (str): A prompt to describe tools. format_prompt (str): A prompt specifying the desired output format. - intent_instruction (str): Instructions specific to intent handling. + extra_instruction (str): Instructions specific to intent handling. generation_params (Dict): Generation parameters for the model. """ @@ -50,7 +50,7 @@ class ArchIntentHandler(ArchBaseHandler): generation_params, ) - self.intent_instruction = intent_instruction + self.extra_instruction = extra_instruction @override def _convert_tools(self, tools: List[Dict[str, Any]]) -> str: @@ -85,7 +85,7 @@ class ArchIntentHandler(ArchBaseHandler): """ messages = self._process_messages( - req.messages, req.tools, self.intent_instruction + req.messages, req.tools, self.extra_instruction ) model_response = self.client.chat.completions.create( diff --git a/model_server/app/model_handler/guardrails.py b/model_server/app/model_handler/guardrails.py index 07aec8fb..da4e6246 100644 --- a/model_server/app/model_handler/guardrails.py +++ b/model_server/app/model_handler/guardrails.py @@ -1,8 +1,11 @@ import time import torch import numpy as np +import app.commons.utilities as utils from pydantic import BaseModel +from transformers import AutoTokenizer, AutoModelForSequenceClassification +from optimum.intel import OVModelForSequenceClassification class GuardRequest(BaseModel): @@ -93,3 +96,27 @@ class ArchGuardHanlder: guard_result["latency"] = time.perf_counter() - start_time return guard_result + + +def get_guardrail_handler(device: str = None): + if device is None: + device = utils.get_device() + + model_class, model_name = None, None + if device == "cpu": + model_class = OVModelForSequenceClassification + model_name = "katanemo/Arch-Guard-cpu" + else: + model_class = AutoModelForSequenceClassification + model_name = "katanemo/Arch-Guard" + + guardrail_dict = { + "device": device, + "model_name": model_name, + "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True), + "model": model_class.from_pretrained( + model_name, device_map=device, low_cpu_mem_usage=True + ), + } + + return ArchGuardHanlder(model_dict=guardrail_dict) diff --git a/model_server/app/tests/test_app.py b/model_server/app/tests/test_app.py index 208bac2a..5784ca55 100644 --- a/model_server/app/tests/test_app.py +++ b/model_server/app/tests/test_app.py @@ -1,32 +1,25 @@ import pytest import httpx -from fastapi.testclient import TestClient -from app.main import app # Assuming your FastAPI app is in main.py -from unittest.mock import patch -import app.commons.globals as glb -import logging -logger = logging.getLogger(__name__) +from fastapi.testclient import TestClient +from app.main import app + client = TestClient(app) -logger.info(f"Model will be loaded on device: {glb.DEVICE}") - -# [TODO] Review: update the following code +# [TODO] Review: check the following code # Unit tests for the health check endpoint @pytest.mark.asyncio -@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' async def test_healthz(): response = client.get("/healthz") assert response.status_code == 200 assert response.json() == {"status": "ok"} -# [TODO] Review: update the following code +# [TODO] Review: check the following code # Unit test for the models endpoint @pytest.mark.asyncio -@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' async def test_models(): response = client.get("/models") assert response.status_code == 200 @@ -34,80 +27,27 @@ async def test_models(): assert len(response.json()["data"]) > 0 -# [TODO] Review: update the following code -# Unit test for embeddings endpoint +# [TODO] Review: check the following code +# Unit test for the guardrail endpoint @pytest.mark.asyncio -@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' -async def test_embedding(): - request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"} - response = client.post("/embeddings", json=request_data) - if request_data["model"] == "katanemo/bge-large-en-v1.5": - assert response.status_code == 200 - assert response.json()["object"] == "list" - assert "data" in response.json() - else: - assert response.status_code == 400 - - -# [TODO] Review: update the following code -# Unit test for the guard endpoint -@pytest.mark.asyncio -@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' -async def test_guard(): +async def test_guardrail_endpoint(): request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"} - response = client.post("/guard", json=request_data) + response = client.post("/guardrails", json=request_data) assert response.status_code == 200 assert "jailbreak_verdict" in response.json() -# [TODO] Review: update the following code -# Unit test for the zero-shot endpoint +# [TODO] Review: check the following code +# Unit test for the function calling endpoint @pytest.mark.asyncio -@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' -async def test_zeroshot(): - request_data = { - "input": "Test input", - "labels": ["label1", "label2"], - "model": "katanemo/bart-large-mnli", - } - response = client.post("/zeroshot", json=request_data) - if request_data["model"] == "katanemo/bart-large-mnli": - assert response.status_code == 200 - assert "predicted_class" in response.json() - else: - assert response.status_code == 400 - - -# [TODO] Review: update the following code -# Unit test for the hallucination endpoint -@pytest.mark.asyncio -@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' -async def test_hallucination(): - request_data = { - "prompt": "Test hallucination", - "parameters": {"param1": "value1"}, - "model": "katanemo/bart-large-mnli", - } - response = client.post("/hallucination", json=request_data) - if request_data["model"] == "katanemo/bart-large-mnli": - assert response.status_code == 200 - assert "params_scores" in response.json() - else: - assert response.status_code == 400 - - -# [TODO] Review: update the following code -# Unit test for the chat completion endpoint -@pytest.mark.asyncio -@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu' -async def test_chat_completion(): +async def test_function_calling_endpoint(): async with httpx.AsyncClient(app=app, base_url="http://test") as client: request_data = { "messages": [{"role": "user", "content": "Hello!"}], - "model": "Arch-Function-1.5B", - "tools": [], # Assuming tools is part of the req as per the function - "metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed + "model": "Arch-Function", + "tools": [], + "metadata": {"x-arch-state": "[]"}, } - response = await client.post("/v1/chat/completions", json=request_data) + response = await client.post("/function_calling", json=request_data) assert response.status_code == 200 assert "choices" in response.json() diff --git a/model_server/app/tests/test_cases.json b/model_server/app/tests/test_cases.json deleted file mode 100644 index d74328ee..00000000 --- a/model_server/app/tests/test_cases.json +++ /dev/null @@ -1,949 +0,0 @@ -[ - { - "case": "tool_call_halluciation", - "tokens": [ - "" - ], - "expect": 1, - "logprobs": [ - [ - -0.3333307206630707, - -1.5310522317886353, - -3.5098977088928223, - -3.9004578590393066, - -5.775152683258057, - -5.814209461212158, - -5.9574151039123535, - -6.0094895362854, - -6.0094895362854, - -6.673445224761963 - ] - ] - }, - { - "case": "parameter_value_hallucination", - "expect": 0, - "tokens": [ - "", - "\n", - "{'", - "name", - "':", - " '", - "get", - "_current", - "_weather", - "',", - " '", - "arguments", - "':", - " {'", - "location", - "':", - " '", - "Sea", - ",", - " Australia", - "',", - " '", - "unit", - "':", - " '", - "c", - "elsius", - "',", - " '", - "days", - "':", - " '", - "1", - "'}}\n", - "" - ], - "logprobs": [ - [ - -0.008103232830762863, - -5.085402488708496, - -6.777836799621582, - -7.558959007263184, - -9.850253105163574, - -10.266852378845215, - -10.540244102478027, - -10.722506523132324, - -10.800618171691895, - -10.917786598205566 - ], - [ - 0.0, - -23.25142478942871, - -25.139137268066406, - -26.2847843170166, - -28.992677688598633, - -29.070789337158203, - -29.55248260498047, - -29.91700553894043, - -30.20341682434082, - -30.307567596435547 - ], - [ - 0.0, - -21.66313934326172, - -23.06916046142578, - -23.32953453063965, - -25.65988540649414, - -25.985353469848633, - -26.519121170043945, - -27.07892417907715, - -27.977216720581055, - -28.458908081054688 - ], - [ - 0.0, - -28.094383239746094, - -28.56305694580078, - -29.109844207763672, - -29.44832992553711, - -31.79170036315918, - -32.0, - -32.05207443237305, - -32.31244659423828, - -32.364524841308594 - ], - [ - 0.0, - -30.489830017089844, - -31.140766143798828, - -31.81774139404297, - -34.525634765625, - -35.8275032043457, - -36.504478454589844, - -39.05614471435547, - -40.123680114746094, - -40.696502685546875 - ], - [ - 0.0, - -25.646865844726562, - -26.66232681274414, - -27.781936645507812, - -28.979660034179688, - -31.140764236450195, - -31.92188835144043, - -31.973962783813477, - -33.04149627685547, - -33.58828353881836 - ], - [ - 0.0, - -23.511798858642578, - -24.136695861816406, - -25.230268478393555, - -25.777053833007812, - -25.80309295654297, - -26.45402717590332, - -26.636289596557617, - -26.740440368652344, - -26.896663665771484 - ], - [ - 0.0, - -22.366153717041016, - -24.683483123779297, - -26.610252380371094, - -26.610252380371094, - -27.313264846801758, - -27.67778778076172, - -28.510986328125, - -28.615135192871094, - -29.13588523864746 - ], - [ - 0.0, - -22.52237319946289, - -24.292919158935547, - -24.344993591308594, - -24.39706802368164, - -24.73555564880371, - -29.943042755126953, - -29.969079971313477, - -30.021154403686523, - -30.0341739654541 - ], - [ - 0.0, - -30.17738151550293, - -30.411718368530273, - -30.88039207458496, - -30.984540939331055, - -31.270952224731445, - -31.895851135253906, - -32.46867370605469, - -32.624900817871094, - -33.484134674072266 - ], - [ - 0.0, - -28.146459579467773, - -29.396255493164062, - -30.099267959594727, - -31.127744674682617, - -31.179821014404297, - -32.807159423828125, - -33.7445068359375, - -33.770545959472656, - -34.069976806640625 - ], - [ - 0.0, - -26.323841094970703, - -26.558177947998047, - -30.515867233276367, - -30.932466506958008, - -31.37510108947754, - -31.531326293945312, - -31.70056915283203, - -32.065093994140625, - -32.364524841308594 - ], - [ - 0.0, - -26.922698974609375, - -30.28152847290039, - -31.505287170410156, - -33.30187225341797, - -33.73148727416992, - -34.27827453613281, - -34.33034896850586, - -34.460533142089844, - -34.720909118652344 - ], - [ - 0.0, - -21.532955169677734, - -26.94873809814453, - -29.109848022460938, - -30.80228042602539, - -31.55736541748047, - -33.484134674072266, - -34.681854248046875, - -35.384864807128906, - -35.853538513183594 - ], - [ - 0.0, - -19.502033233642578, - -20.46541976928711, - -24.110658645629883, - -24.501218795776367, - -25.256305694580078, - -25.82912826538086, - -25.881202697753906, - -26.063465118408203, - -26.063465118408203 - ], - [ - 0.0, - -24.37103271484375, - -25.256305694580078, - -25.933277130126953, - -26.714401245117188, - -28.2506103515625, - -31.010576248168945, - -32.07810974121094, - -34.62977981567383, - -35.241661071777344 - ], - [ - -1.1920922133867862e-06, - -14.398697853088379, - -14.424736976623535, - -17.158666610717773, - -17.41904067993164, - -18.200162887573242, - -18.434499740600586, - -18.66883659362793, - -19.71033477783203, - -19.71033477783203 - ], - [ - -0.0001445904199499637, - -8.98305892944336, - -11.35246467590332, - -13.1490478515625, - -13.669795989990234, - -14.073375701904297, - -14.516012191772461, - -14.555068969726562, - -15.622602462768555, - -15.635622024536133 - ], - [ - -0.44747352600097656, - -1.0202960968017578, - -8.467000961303711, - -10.914518356323242, - -11.25300407409668, - -11.435266494750977, - -12.346576690673828, - -13.075624465942383, - -13.12769889831543, - -13.231849670410156 - ], - [ - -3.123767137527466, - -1.1188862323760986, - -1.639634370803833, - -2.0562336444854736, - -2.8633930683135986, - -2.9675419330596924, - -3.4882919788360596, - -3.69659161567688, - -4.217339515686035, - -4.243376731872559 - ], - [ - -7.199982064776123e-05, - -9.76410961151123, - -11.144091606140137, - -16.507802963256836, - -17.132701873779297, - -17.44515037536621, - -17.9138240814209, - -18.33042335510254, - -18.9162654876709, - -19.39795684814453 - ], - [ - 0.0, - -22.991050720214844, - -23.824249267578125, - -24.969894409179688, - -25.46460723876953, - -25.829130172729492, - -26.480066299438477, - -26.909683227539062, - -27.33930206298828, - -27.391376495361328 - ], - [ - -0.21928852796554565, - -1.625309705734253, - -9.775025367736816, - -12.977627754211426, - -16.388530731201172, - -17.091541290283203, - -19.044347763061523, - -19.38283348083496, - -19.460947036743164, - -19.59113311767578 - ], - [ - 0.0, - -24.006507873535156, - -27.443450927734375, - -27.729862213134766, - -28.12042236328125, - -28.276647567749023, - -28.927583694458008, - -30.099267959594727, - -31.479251861572266, - -32.07810974121094 - ], - [ - 0.0, - -18.17412567138672, - -18.772987365722656, - -21.689178466796875, - -21.92351531982422, - -23.7200984954834, - -23.79821014404297, - -23.79821014404297, - -24.032546997070312, - -25.308382034301758 - ], - [ - -0.12947827577590942, - -2.1083219051361084, - -12.419143676757812, - -15.23118782043457, - -15.595710754394531, - -15.830047607421875, - -17.001731872558594, - -17.60059356689453, - -18.121341705322266, - -18.251529693603516 - ], - [ - 0.0, - -19.449962615966797, - -24.371034622192383, - -24.917821884155273, - -25.529701232910156, - -25.85516929626465, - -26.037429809570312, - -26.115543365478516, - -26.623271942138672, - -26.649309158325195 - ], - [ - -0.03332124650478363, - -3.4181859493255615, - -15.759925842285156, - -15.812002182006836, - -16.593124389648438, - -17.894996643066406, - -18.09027671813965, - -18.79328727722168, - -19.144792556762695, - -20.147233963012695 - ], - [ - 0.0, - -21.142393112182617, - -22.157852172851562, - -23.511798858642578, - -24.657445907592773, - -25.021968841552734, - -25.5427188873291, - -25.59479331970215, - -25.75101661682129, - -25.95931625366211 - ], - [ - 0.0, - -23.04312515258789, - -24.94385528564453, - -26.323841094970703, - -27.54759979248047, - -28.563060760498047, - -29.786819458007812, - -30.620018005371094, - -30.69812774658203, - -31.08869171142578 - ], - [ - 0.0, - -26.167617797851562, - -28.771360397338867, - -29.55248260498047, - -30.906429290771484, - -31.114728927612305, - -31.414159774780273, - -31.622459411621094, - -31.713590621948242, - -31.726608276367188 - ], - [ - -0.05012698099017143, - -3.018392562866211, - -11.740934371948242, - -13.146955490112305, - -13.797887802124023, - -14.943536758422852, - -16.037107467651367, - -16.375595092773438, - -16.714080810546875, - -17.36501693725586 - ], - [ - -0.9704352021217346, - -0.7360983490943909, - -2.1941938400268555, - -4.225115776062012, - -5.0062360763549805, - -5.2666120529174805, - -5.839434623718262, - -7.2714948654174805, - -8.33902645111084, - -8.495253562927246 - ], - [ - -0.014467108063399792, - -4.258565902709961, - -8.789079666137695, - -10.429437637329102, - -10.793962478637695, - -11.835458755493164, - -11.939607620239258, - -13.31959342956543, - -13.866378784179688, - -15.038063049316406 - ], - [ - 0.0, - -20.08787727355957, - -21.350692749023438, - -21.415786743164062, - -21.50691795349121, - -21.50691795349121, - -22.7176570892334, - -24.13669776916504, - -24.188772201538086, - -24.34499740600586 - ] - ] - }, - { - "case": "fail_case", - "expect": 0, - "tokens": [ - "", - "\n", - "{'", - "name", - "':", - " '", - "get", - "_current", - "_weather", - "',", - " '", - "arguments", - "':", - " {'", - "location", - "':", - " '", - "Seattle", - ",", - " WA", - "',", - " '", - "unit", - "':", - " '", - "c", - "elsius", - "',", - " '", - "days", - "':", - " '", - "7", - "'}}\n", - "" - ], - "logprobs": [ - [ - -0.00013815402053296566, - -9.113236427307129, - -10.571331977844238, - -14.099404335021973, - -14.28166675567627, - -15.583537101745605, - -15.81787395477295, - -16.143341064453125, - -16.143341064453125, - -16.260509490966797 - ], - [ - 0.0, - -26.896663665771484, - -27.32628059387207, - -27.41741180419922, - -32.07810974121094, - -32.07810974121094, - -32.28641128540039, - -32.29943084716797, - -32.44263458251953, - -32.520748138427734 - ], - [ - 0.0, - -22.444263458251953, - -24.527257919311523, - -27.15703773498535, - -28.016273498535156, - -28.2506103515625, - -28.693246841430664, - -29.070789337158203, - -29.565500259399414, - -29.812854766845703 - ], - [ - 0.0, - -27.860050201416016, - -28.641170501708984, - -29.448333740234375, - -30.932466506958008, - -31.63547706604004, - -32.33848571777344, - -32.85923767089844, - -33.17168426513672, - -33.45809555053711 - ], - [ - 0.0, - -31.81774139404297, - -31.895854949951172, - -32.05207824707031, - -35.43694305419922, - -36.3482551574707, - -38.61351013183594, - -39.26444625854492, - -40.61839294433594, - -41.71196365356445 - ], - [ - 0.0, - -27.33930206298828, - -27.834014892578125, - -28.849472045898438, - -30.567943572998047, - -32.98942565917969, - -33.067535400390625, - -33.067535400390625, - -35.67127990722656, - -35.69731903076172 - ], - [ - 0.0, - -25.33441925048828, - -26.063465118408203, - -26.219690322875977, - -26.2457275390625, - -26.53213882446289, - -27.365337371826172, - -28.354759216308594, - -28.667207717895508, - -28.74532127380371 - ], - [ - 0.0, - -24.423107147216797, - -24.579330444335938, - -26.81855010986328, - -28.12042236328125, - -28.32872200012207, - -28.61513328552246, - -29.16191864013672, - -29.187957763671875, - -29.240032196044922 - ], - [ - 0.0, - -22.027664184570312, - -23.850284576416016, - -23.980472564697266, - -24.292922973632812, - -24.787633895874023, - -29.279088973999023, - -29.55248260498047, - -29.903987884521484, - -30.190399169921875 - ], - [ - 0.0, - -31.609439849853516, - -31.817739486694336, - -32.54678726196289, - -32.676971435546875, - -32.781124114990234, - -32.98942565917969, - -33.106590270996094, - -33.57526397705078, - -34.369407653808594 - ], - [ - 0.0, - -29.34418296813965, - -29.63059425354004, - -30.021156311035156, - -30.984540939331055, - -33.21073913574219, - -34.30431365966797, - -34.56468963623047, - -34.70789337158203, - -34.79902648925781 - ], - [ - 0.0, - -25.438566207885742, - -25.69894027709961, - -30.190397262573242, - -30.802276611328125, - -31.58340072631836, - -31.609437942504883, - -31.64849281311035, - -31.973960876464844, - -32.29943084716797 - ], - [ - 0.0, - -27.157039642333984, - -32.104148864746094, - -32.33848571777344, - -34.04393768310547, - -34.12205505371094, - -34.40846252441406, - -34.42148208618164, - -34.772987365722656, - -34.87713623046875 - ], - [ - 0.0, - -24.813671112060547, - -26.974777221679688, - -31.010578155517578, - -31.08869171142578, - -32.1822624206543, - -35.33279037475586, - -35.489013671875, - -36.999183654785156, - -37.88446044921875 - ], - [ - 0.0, - -20.46541976928711, - -20.647682189941406, - -23.069164276123047, - -24.136699676513672, - -25.438570022583008, - -25.646869659423828, - -26.193655014038086, - -26.297805786132812, - -26.506103515625 - ], - [ - 0.0, - -27.18307113647461, - -28.30268096923828, - -28.56305694580078, - -29.526439666748047, - -32.416595458984375, - -35.202598571777344, - -36.426361083984375, - -39.31651306152344, - -39.38160705566406 - ], - [ - 0.0, - -18.7469482421875, - -20.100894927978516, - -21.402767181396484, - -21.428804397583008, - -22.20992660522461, - -22.34011459350586, - -22.730674743652344, - -23.069162368774414, - -23.980472564697266 - ], - [ - -3.576278118089249e-07, - -15.2579345703125, - -16.481693267822266, - -17.991863250732422, - -19.215621948242188, - -20.25712013244629, - -21.350692749023438, - -22.314077377319336, - -22.496337890625, - -22.938974380493164 - ], - [ - -0.08506780862808228, - -2.506549835205078, - -14.848289489746094, - -15.473188400268555, - -16.33242416381836, - -16.358461380004883, - -16.566761016845703, - -17.03543472290039, - -17.686370849609375, - -17.816556930541992 - ], - [ - -0.0194891095161438, - -4.445854187011719, - -5.591499328613281, - -5.956024169921875, - -6.685070037841797, - -13.142353057861328, - -13.558952331542969, - -15.173273086547852, - -15.303461074829102, - -15.85024642944336 - ], - [ - -0.0005990855861455202, - -7.4212646484375, - -15.675132751464844, - -15.72720718383789, - -16.76870346069336, - -16.76870346069336, - -17.706050872802734, - -18.669435501098633, - -19.398483276367188, - -19.658857345581055 - ], - [ - 0.0, - -24.110658645629883, - -25.829130172729492, - -26.011390686035156, - -26.011390686035156, - -26.532140731811523, - -26.58421516418457, - -27.651750564575195, - -27.75589942932129, - -28.055330276489258 - ], - [ - -1.1408883333206177, - -0.38580334186553955, - -7.494022369384766, - -12.519245147705078, - -14.576202392578125, - -16.034297943115234, - -16.945608139038086, - -17.908992767333984, - -18.664077758789062, - -19.34105110168457 - ], - [ - 0.0, - -26.688365936279297, - -29.83889389038086, - -30.177383422851562, - -30.64605712890625, - -31.244916915893555, - -31.270954132080078, - -32.83319854736328, - -34.655818939208984, - -34.89015579223633 - ], - [ - 0.0, - -18.929210662841797, - -19.16354751586914, - -23.589908599853516, - -24.683481216430664, - -24.995929718017578, - -25.516677856445312, - -25.542715072631836, - -25.77705192565918, - -26.063465118408203 - ], - [ - -0.2519786059856415, - -1.5017764568328857, - -12.437495231628418, - -15.457839012145996, - -15.744250297546387, - -16.837820053100586, - -17.41064453125, - -17.56686782836914, - -17.61894416809082, - -18.035541534423828 - ], - [ - 0.0, - -20.517494201660156, - -24.683483123779297, - -25.67290496826172, - -26.58421516418457, - -27.651750564575195, - -27.781936645507812, - -27.912124633789062, - -28.09438705444336, - -28.445892333984375 - ], - [ - -3.40932747349143e-05, - -10.284820556640625, - -18.252273559570312, - -20.17904281616211, - -21.663175582885742, - -22.027700424194336, - -22.288074493408203, - -22.704673767089844, - -23.12127113342285, - -23.277496337890625 - ], - [ - 0.0, - -22.60049057006836, - -25.46460723876953, - -25.829130172729492, - -26.063467025756836, - -27.287227630615234, - -27.391376495361328, - -27.4694881439209, - -27.67778778076172, - -28.055330276489258 - ], - [ - 0.0, - -23.902362823486328, - -28.823436737060547, - -29.240036010742188, - -29.31814956665039, - -29.917007446289062, - -30.021160125732422, - -31.21887969970703, - -32.416603088378906, - -32.416603088378906 - ], - [ - 0.0, - -28.641170501708984, - -31.947925567626953, - -32.59886169433594, - -33.848655700683594, - -34.109031677246094, - -34.73393249511719, - -35.02033996582031, - -35.02033996582031, - -36.074859619140625 - ], - [ - -0.013183215633034706, - -4.335395336151123, - -19.619365692138672, - -20.035964965820312, - -20.244266510009766, - -21.311800003051758, - -21.441987991333008, - -22.561595916748047, - -23.108383178710938, - -23.264606475830078 - ], - [ - -8.344646857949556e-07, - -14.190400123596191, - -15.9088716506958, - -18.17412567138672, - -18.46053695678711, - -18.46053695678711, - -18.512611389160156, - -18.90317153930664, - -19.059398651123047, - -19.085433959960938 - ], - [ - 0.0, - -17.70545196533203, - -18.903175354003906, - -20.829944610595703, - -22.574451446533203, - -22.860862731933594, - -23.069162368774414, - -23.32953643798828, - -23.694061279296875, - -24.188772201538086 - ], - [ - 0.0, - -20.022781372070312, - -21.038240432739258, - -21.220502853393555, - -22.496337890625, - -22.769729614257812, - -23.589908599853516, - -23.65500259399414, - -23.94141387939453, - -24.266881942749023 - ] - ] - } -] diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py index 2df2a1f4..9f15507e 100644 --- a/model_server/app/tests/test_function_calling.py +++ b/model_server/app/tests/test_function_calling.py @@ -64,6 +64,9 @@ def test_process_messages(mock_hanlder): } +# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated. + + # [TODO] Review: Update the following test @patch("app.commons.constants.arch_function_client") @patch("app.commons.constants.arch_function_hanlder") diff --git a/model_server/app/tests/test_guardrails.py b/model_server/app/tests/test_guardrails.py index c490a662..de9176b3 100644 --- a/model_server/app/tests/test_guardrails.py +++ b/model_server/app/tests/test_guardrails.py @@ -1,10 +1,7 @@ -import os -import pytest from unittest.mock import patch, MagicMock -import app.commons.globals as glb +from app.model_handler.guardrails import get_guardrail_handler # Mock constants -glb.DEVICE = "cpu" # Adjust as needed for your test case arch_guard_model_type = { "cpu": "katanemo/Arch-Guard-cpu", "cuda": "katanemo/Arch-Guard", @@ -12,36 +9,71 @@ arch_guard_model_type = { } -# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps` -# Test for get_prompt_guard function -@patch("app.loader.AutoTokenizer.from_pretrained") -@patch("app.loader.OVModelForSequenceClassification.from_pretrained") -@patch("app.loader.AutoModelForSequenceClassification.from_pretrained") -def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer): - # Mock model based on device - if glb.DEVICE == "cpu": - mock_ov_model.return_value = MagicMock() - else: - mock_auto_model.return_value = MagicMock() +# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps` +# Test for `get_guardrail_handler()` function on `cpu` +@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained") +@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained") +@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained") +def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer): + device = "cpu" + mock_ov_model.return_value = MagicMock() mock_tokenizer.return_value = MagicMock() - prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE]) + guardrail = get_guardrail_handler(device=device) - # Assertions - assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE] mock_tokenizer.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], trust_remote_code=True + guardrail["model_name"], trust_remote_code=True + ) + + mock_ov_model.assert_called_once_with( + guardrail["model_name"], + device_map=device, + low_cpu_mem_usage=True, + ) + + +# Test for `get_guardrail_handler()` function on `cuda` +@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained") +@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained") +@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained") +def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer): + device = "cuda" + + mock_auto_model.return_value = MagicMock() + mock_tokenizer.return_value = MagicMock() + + guardrail = get_guardrail_handler(device=device) + + mock_tokenizer.assert_called_once_with( + guardrail["model_name"], trust_remote_code=True + ) + + mock_auto_model.assert_called_once_with( + guardrail["model_name"], + device_map=device, + low_cpu_mem_usage=True, + ) + + +# Test for `get_guardrail_handler()` function on `mps` +@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained") +@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained") +@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained") +def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer): + device = "mps" + + mock_auto_model.return_value = MagicMock() + mock_tokenizer.return_value = MagicMock() + + guardrail = get_guardrail_handler(device=device) + + mock_tokenizer.assert_called_once_with( + guardrail["model_name"], trust_remote_code=True + ) + + mock_auto_model.assert_called_once_with( + guardrail["model_name"], + device_map=device, + low_cpu_mem_usage=True, ) - if glb.DEVICE == "cpu": - mock_ov_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - device_map=glb.DEVICE, - low_cpu_mem_usage=True, - ) - else: - mock_auto_model.assert_called_once_with( - arch_guard_model_type[glb.DEVICE], - device_map=glb.DEVICE, - low_cpu_mem_usage=True, - )