Update guardrail_handler and its associated tests

This commit is contained in:
Shuguang Chen 2024-12-05 11:30:58 -08:00
parent b686cf8b87
commit 09f7e1e604
7 changed files with 115 additions and 1091 deletions

View file

@ -25,7 +25,7 @@ class ArchIntentHandler(ArchBaseHandler):
task_prompt: str,
tool_prompt: str,
format_prompt: str,
intent_instruction: str,
extra_instruction: str,
generation_params: Dict,
):
"""
@ -37,7 +37,7 @@ class ArchIntentHandler(ArchBaseHandler):
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
intent_instruction (str): Instructions specific to intent handling.
extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
"""
@ -50,7 +50,7 @@ class ArchIntentHandler(ArchBaseHandler):
generation_params,
)
self.intent_instruction = intent_instruction
self.extra_instruction = extra_instruction
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
@ -85,7 +85,7 @@ class ArchIntentHandler(ArchBaseHandler):
"""
messages = self._process_messages(
req.messages, req.tools, self.intent_instruction
req.messages, req.tools, self.extra_instruction
)
model_response = self.client.chat.completions.create(

View file

@ -1,8 +1,11 @@
import time
import torch
import numpy as np
import app.commons.utilities as utils
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
class GuardRequest(BaseModel):
@ -93,3 +96,27 @@ class ArchGuardHanlder:
guard_result["latency"] = time.perf_counter() - start_time
return guard_result
def get_guardrail_handler(device: str = None):
if device is None:
device = utils.get_device()
model_class, model_name = None, None
if device == "cpu":
model_class = OVModelForSequenceClassification
model_name = "katanemo/Arch-Guard-cpu"
else:
model_class = AutoModelForSequenceClassification
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)