mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Update guardrail_handler and its associated tests
This commit is contained in:
parent
b686cf8b87
commit
09f7e1e604
7 changed files with 115 additions and 1091 deletions
|
|
@ -1,43 +1,14 @@
|
|||
import app.commons.utilities as utils
|
||||
|
||||
from openai import OpenAI
|
||||
from app.commons.constants import *
|
||||
from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler
|
||||
from app.model_handler.guardrails import ArchGuardHanlder
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
from openai import OpenAI
|
||||
from app.model_handler.guardrails import get_guardrail_handler
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_guardrail_handler():
|
||||
device = utils.get_device()
|
||||
|
||||
model_class, model_name = None, None
|
||||
if device == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard-cpu"
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
if device == "cuda":
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
else:
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
|
||||
guardrail_dict = {
|
||||
"device": device,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
return ArchGuardHanlder(model_dict=guardrail_dict)
|
||||
|
||||
|
||||
# Define the client
|
||||
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
format_prompt: str,
|
||||
intent_instruction: str,
|
||||
extra_instruction: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
"""
|
||||
|
|
@ -37,7 +37,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
intent_instruction (str): Instructions specific to intent handling.
|
||||
extra_instruction (str): Instructions specific to intent handling.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
"""
|
||||
|
||||
|
|
@ -50,7 +50,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
generation_params,
|
||||
)
|
||||
|
||||
self.intent_instruction = intent_instruction
|
||||
self.extra_instruction = extra_instruction
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
|
|
@ -85,7 +85,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
"""
|
||||
|
||||
messages = self._process_messages(
|
||||
req.messages, req.tools, self.intent_instruction
|
||||
req.messages, req.tools, self.extra_instruction
|
||||
)
|
||||
|
||||
model_response = self.client.chat.completions.create(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import app.commons.utilities as utils
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
|
|
@ -93,3 +96,27 @@ class ArchGuardHanlder:
|
|||
guard_result["latency"] = time.perf_counter() - start_time
|
||||
|
||||
return guard_result
|
||||
|
||||
|
||||
def get_guardrail_handler(device: str = None):
|
||||
if device is None:
|
||||
device = utils.get_device()
|
||||
|
||||
model_class, model_name = None, None
|
||||
if device == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard-cpu"
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
|
||||
guardrail_dict = {
|
||||
"device": device,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
return ArchGuardHanlder(model_dict=guardrail_dict)
|
||||
|
|
|
|||
|
|
@ -1,32 +1,25 @@
|
|||
import pytest
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app # Assuming your FastAPI app is in main.py
|
||||
from unittest.mock import patch
|
||||
import app.commons.globals as glb
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# [TODO] Review: check the following code
|
||||
# Unit tests for the health check endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_healthz():
|
||||
response = client.get("/healthz")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# [TODO] Review: check the following code
|
||||
# Unit test for the models endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_models():
|
||||
response = client.get("/models")
|
||||
assert response.status_code == 200
|
||||
|
|
@ -34,80 +27,27 @@ async def test_models():
|
|||
assert len(response.json()["data"]) > 0
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for embeddings endpoint
|
||||
# [TODO] Review: check the following code
|
||||
# Unit test for the guardrail endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_embedding():
|
||||
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
|
||||
response = client.post("/embeddings", json=request_data)
|
||||
if request_data["model"] == "katanemo/bge-large-en-v1.5":
|
||||
assert response.status_code == 200
|
||||
assert response.json()["object"] == "list"
|
||||
assert "data" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the guard endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_guard():
|
||||
async def test_guardrail_endpoint():
|
||||
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
|
||||
response = client.post("/guard", json=request_data)
|
||||
response = client.post("/guardrails", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "jailbreak_verdict" in response.json()
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the zero-shot endpoint
|
||||
# [TODO] Review: check the following code
|
||||
# Unit test for the function calling endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_zeroshot():
|
||||
request_data = {
|
||||
"input": "Test input",
|
||||
"labels": ["label1", "label2"],
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/zeroshot", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
assert response.status_code == 200
|
||||
assert "predicted_class" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the hallucination endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_hallucination():
|
||||
request_data = {
|
||||
"prompt": "Test hallucination",
|
||||
"parameters": {"param1": "value1"},
|
||||
"model": "katanemo/bart-large-mnli",
|
||||
}
|
||||
response = client.post("/hallucination", json=request_data)
|
||||
if request_data["model"] == "katanemo/bart-large-mnli":
|
||||
assert response.status_code == 200
|
||||
assert "params_scores" in response.json()
|
||||
else:
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the chat completion endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
async def test_chat_completion():
|
||||
async def test_function_calling_endpoint():
|
||||
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
|
||||
request_data = {
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tools": [], # Assuming tools is part of the req as per the function
|
||||
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed
|
||||
"model": "Arch-Function",
|
||||
"tools": [],
|
||||
"metadata": {"x-arch-state": "[]"},
|
||||
}
|
||||
response = await client.post("/v1/chat/completions", json=request_data)
|
||||
response = await client.post("/function_calling", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "choices" in response.json()
|
||||
|
|
|
|||
|
|
@ -1,949 +0,0 @@
|
|||
[
|
||||
{
|
||||
"case": "tool_call_halluciation",
|
||||
"tokens": [
|
||||
"<tool_call>"
|
||||
],
|
||||
"expect": 1,
|
||||
"logprobs": [
|
||||
[
|
||||
-0.3333307206630707,
|
||||
-1.5310522317886353,
|
||||
-3.5098977088928223,
|
||||
-3.9004578590393066,
|
||||
-5.775152683258057,
|
||||
-5.814209461212158,
|
||||
-5.9574151039123535,
|
||||
-6.0094895362854,
|
||||
-6.0094895362854,
|
||||
-6.673445224761963
|
||||
]
|
||||
]
|
||||
},
|
||||
{
|
||||
"case": "parameter_value_hallucination",
|
||||
"expect": 0,
|
||||
"tokens": [
|
||||
"<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Sea",
|
||||
",",
|
||||
" Australia",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"1",
|
||||
"'}}\n",
|
||||
"</tool_call>"
|
||||
],
|
||||
"logprobs": [
|
||||
[
|
||||
-0.008103232830762863,
|
||||
-5.085402488708496,
|
||||
-6.777836799621582,
|
||||
-7.558959007263184,
|
||||
-9.850253105163574,
|
||||
-10.266852378845215,
|
||||
-10.540244102478027,
|
||||
-10.722506523132324,
|
||||
-10.800618171691895,
|
||||
-10.917786598205566
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-23.25142478942871,
|
||||
-25.139137268066406,
|
||||
-26.2847843170166,
|
||||
-28.992677688598633,
|
||||
-29.070789337158203,
|
||||
-29.55248260498047,
|
||||
-29.91700553894043,
|
||||
-30.20341682434082,
|
||||
-30.307567596435547
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-21.66313934326172,
|
||||
-23.06916046142578,
|
||||
-23.32953453063965,
|
||||
-25.65988540649414,
|
||||
-25.985353469848633,
|
||||
-26.519121170043945,
|
||||
-27.07892417907715,
|
||||
-27.977216720581055,
|
||||
-28.458908081054688
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-28.094383239746094,
|
||||
-28.56305694580078,
|
||||
-29.109844207763672,
|
||||
-29.44832992553711,
|
||||
-31.79170036315918,
|
||||
-32.0,
|
||||
-32.05207443237305,
|
||||
-32.31244659423828,
|
||||
-32.364524841308594
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-30.489830017089844,
|
||||
-31.140766143798828,
|
||||
-31.81774139404297,
|
||||
-34.525634765625,
|
||||
-35.8275032043457,
|
||||
-36.504478454589844,
|
||||
-39.05614471435547,
|
||||
-40.123680114746094,
|
||||
-40.696502685546875
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-25.646865844726562,
|
||||
-26.66232681274414,
|
||||
-27.781936645507812,
|
||||
-28.979660034179688,
|
||||
-31.140764236450195,
|
||||
-31.92188835144043,
|
||||
-31.973962783813477,
|
||||
-33.04149627685547,
|
||||
-33.58828353881836
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-23.511798858642578,
|
||||
-24.136695861816406,
|
||||
-25.230268478393555,
|
||||
-25.777053833007812,
|
||||
-25.80309295654297,
|
||||
-26.45402717590332,
|
||||
-26.636289596557617,
|
||||
-26.740440368652344,
|
||||
-26.896663665771484
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-22.366153717041016,
|
||||
-24.683483123779297,
|
||||
-26.610252380371094,
|
||||
-26.610252380371094,
|
||||
-27.313264846801758,
|
||||
-27.67778778076172,
|
||||
-28.510986328125,
|
||||
-28.615135192871094,
|
||||
-29.13588523864746
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-22.52237319946289,
|
||||
-24.292919158935547,
|
||||
-24.344993591308594,
|
||||
-24.39706802368164,
|
||||
-24.73555564880371,
|
||||
-29.943042755126953,
|
||||
-29.969079971313477,
|
||||
-30.021154403686523,
|
||||
-30.0341739654541
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-30.17738151550293,
|
||||
-30.411718368530273,
|
||||
-30.88039207458496,
|
||||
-30.984540939331055,
|
||||
-31.270952224731445,
|
||||
-31.895851135253906,
|
||||
-32.46867370605469,
|
||||
-32.624900817871094,
|
||||
-33.484134674072266
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-28.146459579467773,
|
||||
-29.396255493164062,
|
||||
-30.099267959594727,
|
||||
-31.127744674682617,
|
||||
-31.179821014404297,
|
||||
-32.807159423828125,
|
||||
-33.7445068359375,
|
||||
-33.770545959472656,
|
||||
-34.069976806640625
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-26.323841094970703,
|
||||
-26.558177947998047,
|
||||
-30.515867233276367,
|
||||
-30.932466506958008,
|
||||
-31.37510108947754,
|
||||
-31.531326293945312,
|
||||
-31.70056915283203,
|
||||
-32.065093994140625,
|
||||
-32.364524841308594
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-26.922698974609375,
|
||||
-30.28152847290039,
|
||||
-31.505287170410156,
|
||||
-33.30187225341797,
|
||||
-33.73148727416992,
|
||||
-34.27827453613281,
|
||||
-34.33034896850586,
|
||||
-34.460533142089844,
|
||||
-34.720909118652344
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-21.532955169677734,
|
||||
-26.94873809814453,
|
||||
-29.109848022460938,
|
||||
-30.80228042602539,
|
||||
-31.55736541748047,
|
||||
-33.484134674072266,
|
||||
-34.681854248046875,
|
||||
-35.384864807128906,
|
||||
-35.853538513183594
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-19.502033233642578,
|
||||
-20.46541976928711,
|
||||
-24.110658645629883,
|
||||
-24.501218795776367,
|
||||
-25.256305694580078,
|
||||
-25.82912826538086,
|
||||
-25.881202697753906,
|
||||
-26.063465118408203,
|
||||
-26.063465118408203
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-24.37103271484375,
|
||||
-25.256305694580078,
|
||||
-25.933277130126953,
|
||||
-26.714401245117188,
|
||||
-28.2506103515625,
|
||||
-31.010576248168945,
|
||||
-32.07810974121094,
|
||||
-34.62977981567383,
|
||||
-35.241661071777344
|
||||
],
|
||||
[
|
||||
-1.1920922133867862e-06,
|
||||
-14.398697853088379,
|
||||
-14.424736976623535,
|
||||
-17.158666610717773,
|
||||
-17.41904067993164,
|
||||
-18.200162887573242,
|
||||
-18.434499740600586,
|
||||
-18.66883659362793,
|
||||
-19.71033477783203,
|
||||
-19.71033477783203
|
||||
],
|
||||
[
|
||||
-0.0001445904199499637,
|
||||
-8.98305892944336,
|
||||
-11.35246467590332,
|
||||
-13.1490478515625,
|
||||
-13.669795989990234,
|
||||
-14.073375701904297,
|
||||
-14.516012191772461,
|
||||
-14.555068969726562,
|
||||
-15.622602462768555,
|
||||
-15.635622024536133
|
||||
],
|
||||
[
|
||||
-0.44747352600097656,
|
||||
-1.0202960968017578,
|
||||
-8.467000961303711,
|
||||
-10.914518356323242,
|
||||
-11.25300407409668,
|
||||
-11.435266494750977,
|
||||
-12.346576690673828,
|
||||
-13.075624465942383,
|
||||
-13.12769889831543,
|
||||
-13.231849670410156
|
||||
],
|
||||
[
|
||||
-3.123767137527466,
|
||||
-1.1188862323760986,
|
||||
-1.639634370803833,
|
||||
-2.0562336444854736,
|
||||
-2.8633930683135986,
|
||||
-2.9675419330596924,
|
||||
-3.4882919788360596,
|
||||
-3.69659161567688,
|
||||
-4.217339515686035,
|
||||
-4.243376731872559
|
||||
],
|
||||
[
|
||||
-7.199982064776123e-05,
|
||||
-9.76410961151123,
|
||||
-11.144091606140137,
|
||||
-16.507802963256836,
|
||||
-17.132701873779297,
|
||||
-17.44515037536621,
|
||||
-17.9138240814209,
|
||||
-18.33042335510254,
|
||||
-18.9162654876709,
|
||||
-19.39795684814453
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-22.991050720214844,
|
||||
-23.824249267578125,
|
||||
-24.969894409179688,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.480066299438477,
|
||||
-26.909683227539062,
|
||||
-27.33930206298828,
|
||||
-27.391376495361328
|
||||
],
|
||||
[
|
||||
-0.21928852796554565,
|
||||
-1.625309705734253,
|
||||
-9.775025367736816,
|
||||
-12.977627754211426,
|
||||
-16.388530731201172,
|
||||
-17.091541290283203,
|
||||
-19.044347763061523,
|
||||
-19.38283348083496,
|
||||
-19.460947036743164,
|
||||
-19.59113311767578
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-24.006507873535156,
|
||||
-27.443450927734375,
|
||||
-27.729862213134766,
|
||||
-28.12042236328125,
|
||||
-28.276647567749023,
|
||||
-28.927583694458008,
|
||||
-30.099267959594727,
|
||||
-31.479251861572266,
|
||||
-32.07810974121094
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-18.17412567138672,
|
||||
-18.772987365722656,
|
||||
-21.689178466796875,
|
||||
-21.92351531982422,
|
||||
-23.7200984954834,
|
||||
-23.79821014404297,
|
||||
-23.79821014404297,
|
||||
-24.032546997070312,
|
||||
-25.308382034301758
|
||||
],
|
||||
[
|
||||
-0.12947827577590942,
|
||||
-2.1083219051361084,
|
||||
-12.419143676757812,
|
||||
-15.23118782043457,
|
||||
-15.595710754394531,
|
||||
-15.830047607421875,
|
||||
-17.001731872558594,
|
||||
-17.60059356689453,
|
||||
-18.121341705322266,
|
||||
-18.251529693603516
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-19.449962615966797,
|
||||
-24.371034622192383,
|
||||
-24.917821884155273,
|
||||
-25.529701232910156,
|
||||
-25.85516929626465,
|
||||
-26.037429809570312,
|
||||
-26.115543365478516,
|
||||
-26.623271942138672,
|
||||
-26.649309158325195
|
||||
],
|
||||
[
|
||||
-0.03332124650478363,
|
||||
-3.4181859493255615,
|
||||
-15.759925842285156,
|
||||
-15.812002182006836,
|
||||
-16.593124389648438,
|
||||
-17.894996643066406,
|
||||
-18.09027671813965,
|
||||
-18.79328727722168,
|
||||
-19.144792556762695,
|
||||
-20.147233963012695
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-21.142393112182617,
|
||||
-22.157852172851562,
|
||||
-23.511798858642578,
|
||||
-24.657445907592773,
|
||||
-25.021968841552734,
|
||||
-25.5427188873291,
|
||||
-25.59479331970215,
|
||||
-25.75101661682129,
|
||||
-25.95931625366211
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-23.04312515258789,
|
||||
-24.94385528564453,
|
||||
-26.323841094970703,
|
||||
-27.54759979248047,
|
||||
-28.563060760498047,
|
||||
-29.786819458007812,
|
||||
-30.620018005371094,
|
||||
-30.69812774658203,
|
||||
-31.08869171142578
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-26.167617797851562,
|
||||
-28.771360397338867,
|
||||
-29.55248260498047,
|
||||
-30.906429290771484,
|
||||
-31.114728927612305,
|
||||
-31.414159774780273,
|
||||
-31.622459411621094,
|
||||
-31.713590621948242,
|
||||
-31.726608276367188
|
||||
],
|
||||
[
|
||||
-0.05012698099017143,
|
||||
-3.018392562866211,
|
||||
-11.740934371948242,
|
||||
-13.146955490112305,
|
||||
-13.797887802124023,
|
||||
-14.943536758422852,
|
||||
-16.037107467651367,
|
||||
-16.375595092773438,
|
||||
-16.714080810546875,
|
||||
-17.36501693725586
|
||||
],
|
||||
[
|
||||
-0.9704352021217346,
|
||||
-0.7360983490943909,
|
||||
-2.1941938400268555,
|
||||
-4.225115776062012,
|
||||
-5.0062360763549805,
|
||||
-5.2666120529174805,
|
||||
-5.839434623718262,
|
||||
-7.2714948654174805,
|
||||
-8.33902645111084,
|
||||
-8.495253562927246
|
||||
],
|
||||
[
|
||||
-0.014467108063399792,
|
||||
-4.258565902709961,
|
||||
-8.789079666137695,
|
||||
-10.429437637329102,
|
||||
-10.793962478637695,
|
||||
-11.835458755493164,
|
||||
-11.939607620239258,
|
||||
-13.31959342956543,
|
||||
-13.866378784179688,
|
||||
-15.038063049316406
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-20.08787727355957,
|
||||
-21.350692749023438,
|
||||
-21.415786743164062,
|
||||
-21.50691795349121,
|
||||
-21.50691795349121,
|
||||
-22.7176570892334,
|
||||
-24.13669776916504,
|
||||
-24.188772201538086,
|
||||
-24.34499740600586
|
||||
]
|
||||
]
|
||||
},
|
||||
{
|
||||
"case": "fail_case",
|
||||
"expect": 0,
|
||||
"tokens": [
|
||||
"<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Seattle",
|
||||
",",
|
||||
" WA",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"7",
|
||||
"'}}\n",
|
||||
"</tool_call>"
|
||||
],
|
||||
"logprobs": [
|
||||
[
|
||||
-0.00013815402053296566,
|
||||
-9.113236427307129,
|
||||
-10.571331977844238,
|
||||
-14.099404335021973,
|
||||
-14.28166675567627,
|
||||
-15.583537101745605,
|
||||
-15.81787395477295,
|
||||
-16.143341064453125,
|
||||
-16.143341064453125,
|
||||
-16.260509490966797
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-26.896663665771484,
|
||||
-27.32628059387207,
|
||||
-27.41741180419922,
|
||||
-32.07810974121094,
|
||||
-32.07810974121094,
|
||||
-32.28641128540039,
|
||||
-32.29943084716797,
|
||||
-32.44263458251953,
|
||||
-32.520748138427734
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-22.444263458251953,
|
||||
-24.527257919311523,
|
||||
-27.15703773498535,
|
||||
-28.016273498535156,
|
||||
-28.2506103515625,
|
||||
-28.693246841430664,
|
||||
-29.070789337158203,
|
||||
-29.565500259399414,
|
||||
-29.812854766845703
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-27.860050201416016,
|
||||
-28.641170501708984,
|
||||
-29.448333740234375,
|
||||
-30.932466506958008,
|
||||
-31.63547706604004,
|
||||
-32.33848571777344,
|
||||
-32.85923767089844,
|
||||
-33.17168426513672,
|
||||
-33.45809555053711
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-31.81774139404297,
|
||||
-31.895854949951172,
|
||||
-32.05207824707031,
|
||||
-35.43694305419922,
|
||||
-36.3482551574707,
|
||||
-38.61351013183594,
|
||||
-39.26444625854492,
|
||||
-40.61839294433594,
|
||||
-41.71196365356445
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-27.33930206298828,
|
||||
-27.834014892578125,
|
||||
-28.849472045898438,
|
||||
-30.567943572998047,
|
||||
-32.98942565917969,
|
||||
-33.067535400390625,
|
||||
-33.067535400390625,
|
||||
-35.67127990722656,
|
||||
-35.69731903076172
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-25.33441925048828,
|
||||
-26.063465118408203,
|
||||
-26.219690322875977,
|
||||
-26.2457275390625,
|
||||
-26.53213882446289,
|
||||
-27.365337371826172,
|
||||
-28.354759216308594,
|
||||
-28.667207717895508,
|
||||
-28.74532127380371
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-24.423107147216797,
|
||||
-24.579330444335938,
|
||||
-26.81855010986328,
|
||||
-28.12042236328125,
|
||||
-28.32872200012207,
|
||||
-28.61513328552246,
|
||||
-29.16191864013672,
|
||||
-29.187957763671875,
|
||||
-29.240032196044922
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-22.027664184570312,
|
||||
-23.850284576416016,
|
||||
-23.980472564697266,
|
||||
-24.292922973632812,
|
||||
-24.787633895874023,
|
||||
-29.279088973999023,
|
||||
-29.55248260498047,
|
||||
-29.903987884521484,
|
||||
-30.190399169921875
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-31.609439849853516,
|
||||
-31.817739486694336,
|
||||
-32.54678726196289,
|
||||
-32.676971435546875,
|
||||
-32.781124114990234,
|
||||
-32.98942565917969,
|
||||
-33.106590270996094,
|
||||
-33.57526397705078,
|
||||
-34.369407653808594
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-29.34418296813965,
|
||||
-29.63059425354004,
|
||||
-30.021156311035156,
|
||||
-30.984540939331055,
|
||||
-33.21073913574219,
|
||||
-34.30431365966797,
|
||||
-34.56468963623047,
|
||||
-34.70789337158203,
|
||||
-34.79902648925781
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-25.438566207885742,
|
||||
-25.69894027709961,
|
||||
-30.190397262573242,
|
||||
-30.802276611328125,
|
||||
-31.58340072631836,
|
||||
-31.609437942504883,
|
||||
-31.64849281311035,
|
||||
-31.973960876464844,
|
||||
-32.29943084716797
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-27.157039642333984,
|
||||
-32.104148864746094,
|
||||
-32.33848571777344,
|
||||
-34.04393768310547,
|
||||
-34.12205505371094,
|
||||
-34.40846252441406,
|
||||
-34.42148208618164,
|
||||
-34.772987365722656,
|
||||
-34.87713623046875
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-24.813671112060547,
|
||||
-26.974777221679688,
|
||||
-31.010578155517578,
|
||||
-31.08869171142578,
|
||||
-32.1822624206543,
|
||||
-35.33279037475586,
|
||||
-35.489013671875,
|
||||
-36.999183654785156,
|
||||
-37.88446044921875
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-20.46541976928711,
|
||||
-20.647682189941406,
|
||||
-23.069164276123047,
|
||||
-24.136699676513672,
|
||||
-25.438570022583008,
|
||||
-25.646869659423828,
|
||||
-26.193655014038086,
|
||||
-26.297805786132812,
|
||||
-26.506103515625
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-27.18307113647461,
|
||||
-28.30268096923828,
|
||||
-28.56305694580078,
|
||||
-29.526439666748047,
|
||||
-32.416595458984375,
|
||||
-35.202598571777344,
|
||||
-36.426361083984375,
|
||||
-39.31651306152344,
|
||||
-39.38160705566406
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-18.7469482421875,
|
||||
-20.100894927978516,
|
||||
-21.402767181396484,
|
||||
-21.428804397583008,
|
||||
-22.20992660522461,
|
||||
-22.34011459350586,
|
||||
-22.730674743652344,
|
||||
-23.069162368774414,
|
||||
-23.980472564697266
|
||||
],
|
||||
[
|
||||
-3.576278118089249e-07,
|
||||
-15.2579345703125,
|
||||
-16.481693267822266,
|
||||
-17.991863250732422,
|
||||
-19.215621948242188,
|
||||
-20.25712013244629,
|
||||
-21.350692749023438,
|
||||
-22.314077377319336,
|
||||
-22.496337890625,
|
||||
-22.938974380493164
|
||||
],
|
||||
[
|
||||
-0.08506780862808228,
|
||||
-2.506549835205078,
|
||||
-14.848289489746094,
|
||||
-15.473188400268555,
|
||||
-16.33242416381836,
|
||||
-16.358461380004883,
|
||||
-16.566761016845703,
|
||||
-17.03543472290039,
|
||||
-17.686370849609375,
|
||||
-17.816556930541992
|
||||
],
|
||||
[
|
||||
-0.0194891095161438,
|
||||
-4.445854187011719,
|
||||
-5.591499328613281,
|
||||
-5.956024169921875,
|
||||
-6.685070037841797,
|
||||
-13.142353057861328,
|
||||
-13.558952331542969,
|
||||
-15.173273086547852,
|
||||
-15.303461074829102,
|
||||
-15.85024642944336
|
||||
],
|
||||
[
|
||||
-0.0005990855861455202,
|
||||
-7.4212646484375,
|
||||
-15.675132751464844,
|
||||
-15.72720718383789,
|
||||
-16.76870346069336,
|
||||
-16.76870346069336,
|
||||
-17.706050872802734,
|
||||
-18.669435501098633,
|
||||
-19.398483276367188,
|
||||
-19.658857345581055
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-24.110658645629883,
|
||||
-25.829130172729492,
|
||||
-26.011390686035156,
|
||||
-26.011390686035156,
|
||||
-26.532140731811523,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.75589942932129,
|
||||
-28.055330276489258
|
||||
],
|
||||
[
|
||||
-1.1408883333206177,
|
||||
-0.38580334186553955,
|
||||
-7.494022369384766,
|
||||
-12.519245147705078,
|
||||
-14.576202392578125,
|
||||
-16.034297943115234,
|
||||
-16.945608139038086,
|
||||
-17.908992767333984,
|
||||
-18.664077758789062,
|
||||
-19.34105110168457
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-26.688365936279297,
|
||||
-29.83889389038086,
|
||||
-30.177383422851562,
|
||||
-30.64605712890625,
|
||||
-31.244916915893555,
|
||||
-31.270954132080078,
|
||||
-32.83319854736328,
|
||||
-34.655818939208984,
|
||||
-34.89015579223633
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-18.929210662841797,
|
||||
-19.16354751586914,
|
||||
-23.589908599853516,
|
||||
-24.683481216430664,
|
||||
-24.995929718017578,
|
||||
-25.516677856445312,
|
||||
-25.542715072631836,
|
||||
-25.77705192565918,
|
||||
-26.063465118408203
|
||||
],
|
||||
[
|
||||
-0.2519786059856415,
|
||||
-1.5017764568328857,
|
||||
-12.437495231628418,
|
||||
-15.457839012145996,
|
||||
-15.744250297546387,
|
||||
-16.837820053100586,
|
||||
-17.41064453125,
|
||||
-17.56686782836914,
|
||||
-17.61894416809082,
|
||||
-18.035541534423828
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-20.517494201660156,
|
||||
-24.683483123779297,
|
||||
-25.67290496826172,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.781936645507812,
|
||||
-27.912124633789062,
|
||||
-28.09438705444336,
|
||||
-28.445892333984375
|
||||
],
|
||||
[
|
||||
-3.40932747349143e-05,
|
||||
-10.284820556640625,
|
||||
-18.252273559570312,
|
||||
-20.17904281616211,
|
||||
-21.663175582885742,
|
||||
-22.027700424194336,
|
||||
-22.288074493408203,
|
||||
-22.704673767089844,
|
||||
-23.12127113342285,
|
||||
-23.277496337890625
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-22.60049057006836,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.063467025756836,
|
||||
-27.287227630615234,
|
||||
-27.391376495361328,
|
||||
-27.4694881439209,
|
||||
-27.67778778076172,
|
||||
-28.055330276489258
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-23.902362823486328,
|
||||
-28.823436737060547,
|
||||
-29.240036010742188,
|
||||
-29.31814956665039,
|
||||
-29.917007446289062,
|
||||
-30.021160125732422,
|
||||
-31.21887969970703,
|
||||
-32.416603088378906,
|
||||
-32.416603088378906
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-28.641170501708984,
|
||||
-31.947925567626953,
|
||||
-32.59886169433594,
|
||||
-33.848655700683594,
|
||||
-34.109031677246094,
|
||||
-34.73393249511719,
|
||||
-35.02033996582031,
|
||||
-35.02033996582031,
|
||||
-36.074859619140625
|
||||
],
|
||||
[
|
||||
-0.013183215633034706,
|
||||
-4.335395336151123,
|
||||
-19.619365692138672,
|
||||
-20.035964965820312,
|
||||
-20.244266510009766,
|
||||
-21.311800003051758,
|
||||
-21.441987991333008,
|
||||
-22.561595916748047,
|
||||
-23.108383178710938,
|
||||
-23.264606475830078
|
||||
],
|
||||
[
|
||||
-8.344646857949556e-07,
|
||||
-14.190400123596191,
|
||||
-15.9088716506958,
|
||||
-18.17412567138672,
|
||||
-18.46053695678711,
|
||||
-18.46053695678711,
|
||||
-18.512611389160156,
|
||||
-18.90317153930664,
|
||||
-19.059398651123047,
|
||||
-19.085433959960938
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-17.70545196533203,
|
||||
-18.903175354003906,
|
||||
-20.829944610595703,
|
||||
-22.574451446533203,
|
||||
-22.860862731933594,
|
||||
-23.069162368774414,
|
||||
-23.32953643798828,
|
||||
-23.694061279296875,
|
||||
-24.188772201538086
|
||||
],
|
||||
[
|
||||
0.0,
|
||||
-20.022781372070312,
|
||||
-21.038240432739258,
|
||||
-21.220502853393555,
|
||||
-22.496337890625,
|
||||
-22.769729614257812,
|
||||
-23.589908599853516,
|
||||
-23.65500259399414,
|
||||
-23.94141387939453,
|
||||
-24.266881942749023
|
||||
]
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -64,6 +64,9 @@ def test_process_messages(mock_hanlder):
|
|||
}
|
||||
|
||||
|
||||
# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated.
|
||||
|
||||
|
||||
# [TODO] Review: Update the following test
|
||||
@patch("app.commons.constants.arch_function_client")
|
||||
@patch("app.commons.constants.arch_function_hanlder")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,7 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.model_handler.guardrails import get_guardrail_handler
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cpu" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
|
|
@ -12,36 +9,71 @@ arch_guard_model_type = {
|
|||
}
|
||||
|
||||
|
||||
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for `get_guardrail_handler()` function on `cpu`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cpu"
|
||||
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_ov_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `cuda`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cuda"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `mps`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "mps"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
mock_auto_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue