Refine model_server

This commit is contained in:
Shuguang Chen 2024-12-05 15:19:41 -08:00
parent a5bd005411
commit 4fcfd83639
6 changed files with 149 additions and 64 deletions

View file

@ -38,7 +38,7 @@ class ArchBaseHandler:
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
):
@ -59,7 +59,7 @@ class ArchBaseHandler:
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt = tool_prompt
self.tool_prompt_template = tool_prompt_template
self.format_prompt = format_prompt
self.generation_params = generation_params
@ -78,7 +78,7 @@ class ArchBaseHandler:
raise NotImplementedError()
@final
def _format_system(self, tools: List[Dict[str, Any]]) -> str:
def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
"""
Formats the system prompt using provided tools.
@ -94,7 +94,7 @@ class ArchBaseHandler:
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt.format(tool_text=tool_text)
+ self.tool_prompt_template.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)

View file

@ -23,7 +23,7 @@ class ArchIntentHandler(ArchBaseHandler):
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
tool_prompt_template: str,
format_prompt: str,
extra_instruction: str,
generation_params: Dict,
@ -35,7 +35,7 @@ class ArchIntentHandler(ArchBaseHandler):
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
@ -45,7 +45,7 @@ class ArchIntentHandler(ArchBaseHandler):
client,
model_name,
task_prompt,
tool_prompt,
tool_prompt_template,
format_prompt,
generation_params,
)
@ -69,6 +69,19 @@ class ArchIntentHandler(ArchBaseHandler):
]
return "\n".join(converted)
def detect_intent(self, content: str) -> bool:
"""
Detect if any intent match with prompts
Args:
content: str: Model response that contains intent detection results
Returns:
bool: A boolean value to indicate if any intent match with prompts or not
"""
return content.choices[0].message.content == "Yes"
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
@ -110,7 +123,7 @@ class ArchFunctionHandler(ArchBaseHandler):
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
prefill_params: Dict,
@ -123,7 +136,7 @@ class ArchFunctionHandler(ArchBaseHandler):
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
prefill_params (Dict): Additional parameters for prefilling responses.
@ -134,7 +147,7 @@ class ArchFunctionHandler(ArchBaseHandler):
client,
model_name,
task_prompt,
tool_prompt,
tool_prompt_template,
format_prompt,
generation_params,
)
@ -392,15 +405,24 @@ class ArchFunctionHandler(ArchBaseHandler):
else:
model_response = response.choices[0].message.content
tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)
(
tool_calls,
extraction_status,
extraction_error_message,
) = self._extract_tool_calls(model_response)
if tool_calls:
is_valid, error_tool_call, error_message = self._verify_tool_calls(
tools=req.tools, tool_calls=tool_calls
)
# [TODO] Review: define the behavior in the case that tool call extraction fails
# if not extraction_status:
(
verification_status,
invalid_tool_call,
verification_error_message,
) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls)
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if is_valid:
if verification_status:
model_response = Message(content="", tool_calls=tool_calls)
# else:

View file

@ -6,6 +6,7 @@ import app.commons.utilities as utils
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from typing import List
class GuardRequest(BaseModel):
@ -13,8 +14,22 @@ class GuardRequest(BaseModel):
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
class ArchGuardHanlder:
def __init__(self, model_dict):
"""
Initializes the ArchGuardHanlder with the given model dictionary.
Args:
model_dict (dict): A dictionary containing the model, tokenizer, and device information.
"""
self.model = model_dict["model"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
@ -23,9 +38,17 @@ class ArchGuardHanlder:
def _split_text_into_chunks(self, text, max_num_words=300):
"""
Split the text into chunks of `max_num_words` words
Splits the input text into chunks of up to `max_num_words` words.
Args:
text (str): The input text to be split.
max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.
Returns:
List[str]: A list of text chunks.
"""
words = text.split() # Split text into words
words = text.split()
chunks = [
" ".join(words[i : i + max_num_words])
@ -36,19 +59,44 @@ class ArchGuardHanlder:
@staticmethod
def softmax(x):
"""
Computes the softmax of the input array.
Args:
x (np.ndarray): The input array.
Returns:
np.ndarray: The softmax of the input.
"""
return np.exp(x) / np.exp(x).sum(axis=0)
def _predict_text(self, task, text, max_length=512):
def _predict_text(self, task, text, max_length=512) -> GuardResponse:
"""
Predicts the result for the provided text for a specific task.
Args:
task (str): The task to perform (e.g., "jailbreak").
text (str): The input text to classify.
max_length (int, optional): The maximum length for tokenization. Defaults to 512.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
"""
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
start_time = time.perf_counter()
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
latency = time.perf_counter() - start_time
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
@ -56,49 +104,61 @@ class ArchGuardHanlder:
verdict = False
sentence = None
result_dict = {
"prob": prob.item(),
"verdict": verdict,
"sentence": sentence,
}
return GuardResponse(
prob=prob.item(), verdict=verdict, sentence=sentence, latency=latency
)
return result_dict
def predict(self, req: GuardRequest, max_num_words=300):
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
"""
Note: currently only support jailbreak check
Makes a prediction based on the GuardRequest input.
Args:
req (GuardRequest): The GuardRequest object containing the input text and task.
max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
Note:
currently only support jailbreak check
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
guard_result = {
"prob": [],
"verdict": False,
"sentence": [],
}
start_time = time.perf_counter()
if len(req.input.split()) < max_num_words:
guard_result = self._predict_text(req.task, req.input)
return self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
prob, verdict, sentence, latency = [], False, [], 0
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result["verdict"]:
guard_result["verdict"] = True
guard_result["sentence"].append(chunk_result["sentence"])
guard_result["prob"].append(chunk_result["prob"].item())
guard_result["latency"] = time.perf_counter() - start_time
if chunk_result.verdict:
prob.append(chunk_result.prob)
verdict = True
sentence.append(chunk_result.sentence)
latency += chunk_result.latency
return guard_result
return GuardResponse(
prob=prob, verdict=verdict, sentence=sentence, latency=latency
)
def get_guardrail_handler(device: str = None):
"""
Initializes and returns an instance of ArchGuardHanlder based on the specified device.
Args:
device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None.
Returns:
ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device.
"""
if device is None:
device = utils.get_device()