From 4fcfd8363905b068a41a0b43fc726faa230af699 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Thu, 5 Dec 2024 15:19:41 -0800 Subject: [PATCH] Refine model_server --- model_server/app/commons/constants.py | 16 +-- model_server/app/commons/globals.py | 8 +- model_server/app/main.py | 23 ++-- .../app/model_handler/base_handler.py | 8 +- .../app/model_handler/function_calling.py | 44 +++++-- model_server/app/model_handler/guardrails.py | 114 +++++++++++++----- 6 files changed, 149 insertions(+), 64 deletions(-) diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index f4c062f8..4dcd24e2 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -4,23 +4,23 @@ ARCH_INTENT_INSTRUCTION = "Are there any tools can help?" ARCH_INTENT_TASK_PROMPT = """ You are a helpful assistant. -""".strip() +""" -ARCH_INTENT_TOOL_PROMPT = """ +ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """ You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below. {tool_text} -""".strip() +""" ARCH_INTENT_FORMAT_PROMPT = """ Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation: - First line must read 'Yes' or 'No'. - If yes, a second line must include a comma-separated list of tool indexes. -""".strip() +""" ARCH_INTENT_GENERATION_CONFIG = { @@ -37,10 +37,10 @@ ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function" ARCH_FUNCTION_TASK_PROMPT = """ You are a helpful assistant. -""".strip() +""" -ARCH_FUNCTION_TOOL_PROMPT = """ +ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """ # Tools You may call one or more functions to assist with the user query. @@ -49,7 +49,7 @@ You are provided with function signatures within XML tags: {tool_text} -""".strip() +""" ARCH_FUNCTION_FORMAT_PROMPT = """ @@ -57,7 +57,7 @@ For each function call, return a json object with function name and arguments wi {"name": , "arguments": } -""".strip() +""" ARCH_FUNCTION_GENERATION_CONFIG = { "generation_params": { diff --git a/model_server/app/commons/globals.py b/model_server/app/commons/globals.py index 92edec3a..e62286b3 100644 --- a/model_server/app/commons/globals.py +++ b/model_server/app/commons/globals.py @@ -10,7 +10,9 @@ logger = utils.get_model_server_logger() # Define the client -ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY") +ARCH_ENDPOINT = "https://api.fc.archgw.com/v1" +ARCH_API_KEY = "EMPTY" +ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY) # Define model handlers @@ -19,7 +21,7 @@ handler_map = { ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ARCH_INTENT_TASK_PROMPT, - ARCH_INTENT_TOOL_PROMPT, + ARCH_INTENT_TOOL_PROMPT_TEMPLATE, ARCH_INTENT_FORMAT_PROMPT, ARCH_INTENT_INSTRUCTION, **ARCH_INTENT_GENERATION_CONFIG, @@ -28,7 +30,7 @@ handler_map = { ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ARCH_FUNCTION_TASK_PROMPT, - ARCH_FUNCTION_TOOL_PROMPT, + ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE, ARCH_FUNCTION_FORMAT_PROMPT, **ARCH_FUNCTION_GENERATION_CONFIG, ), diff --git a/model_server/app/main.py b/model_server/app/main.py index 6d1de4e7..d615c086 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -4,7 +4,7 @@ from app.commons.globals import handler_map from app.model_handler.base_handler import ChatMessage from app.model_handler.guardrails import GuardRequest -from fastapi import FastAPI, Response, Request +from fastapi import FastAPI, Response from opentelemetry import trace from opentelemetry.sdk.trace import TracerProvider from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor @@ -53,24 +53,25 @@ async def models(): @app.post("/function_calling") -async def function_calling(req: ChatMessage, res: Response, request: Request): +async def function_calling(req: ChatMessage, res: Response): try: - intent_result = await handler_map["Arch-Intent"].chat_completion(req) + intent_response = await handler_map["Arch-Intent"].chat_completion(req) - if intent_result.choices[0].message.content == "Yes": + if handler_map["Arch-Intent"].detect_intent(intent_response): + # [TODO] measure agreement between intent detection and function calling try: - function_result = await handler_map["Arch-Function"].chat_completion( - req - ) - return function_result + function_calling_response = await handler_map[ + "Arch-Function" + ].chat_completion(req) + return function_calling_response except Exception as e: - # [TODO] + # [TODO] Review: update how to collect debugging outputs # logger.error(f"Error in chat_completion from `Arch-Function`: {e}") res.status_code = 500 return {"error": f"[Arch-Function] - {e}"} except Exception as e: - # [TODO] + # [TODO] Review: update how to collect debugging outputs # logger.error(f"Error in chat_completion from `Arch-Intent`: {e}") res.status_code = 500 return {"error": f"[Arch-Intent] - {e}"} @@ -82,6 +83,6 @@ async def guardrails(req: GuardRequest, res: Response, max_num_words=300): guard_result = handler_map["Arch-Guard"].predict(req) return guard_result except Exception as e: - # [TODO] + # [TODO] Review: update how to collect debugging outputs res.status_code = 500 return {"error": f"[Arch-Guard] - {e}"} diff --git a/model_server/app/model_handler/base_handler.py b/model_server/app/model_handler/base_handler.py index 6b88a2bb..cfc7a0f8 100644 --- a/model_server/app/model_handler/base_handler.py +++ b/model_server/app/model_handler/base_handler.py @@ -38,7 +38,7 @@ class ArchBaseHandler: client: OpenAI, model_name: str, task_prompt: str, - tool_prompt: str, + tool_prompt_template: str, format_prompt: str, generation_params: Dict, ): @@ -59,7 +59,7 @@ class ArchBaseHandler: self.model_name = model_name self.task_prompt = task_prompt - self.tool_prompt = tool_prompt + self.tool_prompt_template = tool_prompt_template self.format_prompt = format_prompt self.generation_params = generation_params @@ -78,7 +78,7 @@ class ArchBaseHandler: raise NotImplementedError() @final - def _format_system(self, tools: List[Dict[str, Any]]) -> str: + def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str: """ Formats the system prompt using provided tools. @@ -94,7 +94,7 @@ class ArchBaseHandler: system_prompt = ( self.task_prompt + "\n\n" - + self.tool_prompt.format(tool_text=tool_text) + + self.tool_prompt_template.format(tool_text=tool_text) + "\n\n" + self.format_prompt ) diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py index 3630e8bf..a3bb2dc3 100644 --- a/model_server/app/model_handler/function_calling.py +++ b/model_server/app/model_handler/function_calling.py @@ -23,7 +23,7 @@ class ArchIntentHandler(ArchBaseHandler): client: OpenAI, model_name: str, task_prompt: str, - tool_prompt: str, + tool_prompt_template: str, format_prompt: str, extra_instruction: str, generation_params: Dict, @@ -35,7 +35,7 @@ class ArchIntentHandler(ArchBaseHandler): client (OpenAI): An OpenAI client instance. model_name (str): Name of the model to use. task_prompt (str): The main task prompt for the system. - tool_prompt (str): A prompt to describe tools. + tool_prompt_template (str): A prompt to describe tools. format_prompt (str): A prompt specifying the desired output format. extra_instruction (str): Instructions specific to intent handling. generation_params (Dict): Generation parameters for the model. @@ -45,7 +45,7 @@ class ArchIntentHandler(ArchBaseHandler): client, model_name, task_prompt, - tool_prompt, + tool_prompt_template, format_prompt, generation_params, ) @@ -69,6 +69,19 @@ class ArchIntentHandler(ArchBaseHandler): ] return "\n".join(converted) + def detect_intent(self, content: str) -> bool: + """ + Detect if any intent match with prompts + + Args: + content: str: Model response that contains intent detection results + + Returns: + bool: A boolean value to indicate if any intent match with prompts or not + """ + + return content.choices[0].message.content == "Yes" + @override async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: """ @@ -110,7 +123,7 @@ class ArchFunctionHandler(ArchBaseHandler): client: OpenAI, model_name: str, task_prompt: str, - tool_prompt: str, + tool_prompt_template: str, format_prompt: str, generation_params: Dict, prefill_params: Dict, @@ -123,7 +136,7 @@ class ArchFunctionHandler(ArchBaseHandler): client (OpenAI): An OpenAI client instance. model_name (str): Name of the model to use. task_prompt (str): The main task prompt for the system. - tool_prompt (str): A prompt to describe tools. + tool_prompt_template (str): A prompt to describe tools. format_prompt (str): A prompt specifying the desired output format. generation_params (Dict): Generation parameters for the model. prefill_params (Dict): Additional parameters for prefilling responses. @@ -134,7 +147,7 @@ class ArchFunctionHandler(ArchBaseHandler): client, model_name, task_prompt, - tool_prompt, + tool_prompt_template, format_prompt, generation_params, ) @@ -392,15 +405,24 @@ class ArchFunctionHandler(ArchBaseHandler): else: model_response = response.choices[0].message.content - tool_calls, is_valid, error_message = self._extract_tool_calls(model_response) + ( + tool_calls, + extraction_status, + extraction_error_message, + ) = self._extract_tool_calls(model_response) if tool_calls: - is_valid, error_tool_call, error_message = self._verify_tool_calls( - tools=req.tools, tool_calls=tool_calls - ) + # [TODO] Review: define the behavior in the case that tool call extraction fails + # if not extraction_status: + + ( + verification_status, + invalid_tool_call, + verification_error_message, + ) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls) # [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately - if is_valid: + if verification_status: model_response = Message(content="", tool_calls=tool_calls) # else: diff --git a/model_server/app/model_handler/guardrails.py b/model_server/app/model_handler/guardrails.py index da4e6246..f733552e 100644 --- a/model_server/app/model_handler/guardrails.py +++ b/model_server/app/model_handler/guardrails.py @@ -6,6 +6,7 @@ import app.commons.utilities as utils from pydantic import BaseModel from transformers import AutoTokenizer, AutoModelForSequenceClassification from optimum.intel import OVModelForSequenceClassification +from typing import List class GuardRequest(BaseModel): @@ -13,8 +14,22 @@ class GuardRequest(BaseModel): task: str +class GuardResponse(BaseModel): + prob: List + verdict: bool + sentence: List + latency: float = 0 + + class ArchGuardHanlder: def __init__(self, model_dict): + """ + Initializes the ArchGuardHanlder with the given model dictionary. + + Args: + model_dict (dict): A dictionary containing the model, tokenizer, and device information. + """ + self.model = model_dict["model"] self.tokenizer = model_dict["tokenizer"] self.device = model_dict["device"] @@ -23,9 +38,17 @@ class ArchGuardHanlder: def _split_text_into_chunks(self, text, max_num_words=300): """ - Split the text into chunks of `max_num_words` words + Splits the input text into chunks of up to `max_num_words` words. + + Args: + text (str): The input text to be split. + max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300. + + Returns: + List[str]: A list of text chunks. """ - words = text.split() # Split text into words + + words = text.split() chunks = [ " ".join(words[i : i + max_num_words]) @@ -36,19 +59,44 @@ class ArchGuardHanlder: @staticmethod def softmax(x): + """ + Computes the softmax of the input array. + + Args: + x (np.ndarray): The input array. + + Returns: + np.ndarray: The softmax of the input. + """ return np.exp(x) / np.exp(x).sum(axis=0) - def _predict_text(self, task, text, max_length=512): + def _predict_text(self, task, text, max_length=512) -> GuardResponse: + """ + Predicts the result for the provided text for a specific task. + + Args: + task (str): The task to perform (e.g., "jailbreak"). + text (str): The input text to classify. + max_length (int, optional): The maximum length for tokenization. Defaults to 512. + + Returns: + GuardResponse: A GuardResponse object containing the prediction. + """ + inputs = self.tokenizer( text, truncation=True, max_length=max_length, return_tensors="pt" ).to(self.device) + start_time = time.perf_counter() + with torch.no_grad(): logits = self.model(**inputs).logits.cpu().detach().numpy()[0] prob = ArchGuardHanlder.softmax(logits)[ self.support_tasks[task]["positive_class"] ] + latency = time.perf_counter() - start_time + if prob > self.support_tasks[task]["threshold"]: verdict = True sentence = text @@ -56,49 +104,61 @@ class ArchGuardHanlder: verdict = False sentence = None - result_dict = { - "prob": prob.item(), - "verdict": verdict, - "sentence": sentence, - } + return GuardResponse( + prob=prob.item(), verdict=verdict, sentence=sentence, latency=latency + ) - return result_dict - - def predict(self, req: GuardRequest, max_num_words=300): + def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse: """ - Note: currently only support jailbreak check + Makes a prediction based on the GuardRequest input. + + Args: + req (GuardRequest): The GuardRequest object containing the input text and task. + max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300. + + Returns: + GuardResponse: A GuardResponse object containing the prediction. + + Note: + currently only support jailbreak check """ if req.task not in self.support_tasks: raise NotImplementedError(f"{req.task} is not supported!") - guard_result = { - "prob": [], - "verdict": False, - "sentence": [], - } - - start_time = time.perf_counter() - if len(req.input.split()) < max_num_words: - guard_result = self._predict_text(req.task, req.input) + return self._predict_text(req.task, req.input) else: # split into chunks if text is long text_chunks = self._split_text_into_chunks(req.input) + prob, verdict, sentence, latency = [], False, [], 0 + for chunk in text_chunks: chunk_result = self._predict_text(req.task, chunk) - if chunk_result["verdict"]: - guard_result["verdict"] = True - guard_result["sentence"].append(chunk_result["sentence"]) - guard_result["prob"].append(chunk_result["prob"].item()) - guard_result["latency"] = time.perf_counter() - start_time + if chunk_result.verdict: + prob.append(chunk_result.prob) + verdict = True + sentence.append(chunk_result.sentence) + latency += chunk_result.latency - return guard_result + return GuardResponse( + prob=prob, verdict=verdict, sentence=sentence, latency=latency + ) def get_guardrail_handler(device: str = None): + """ + Initializes and returns an instance of ArchGuardHanlder based on the specified device. + + Args: + device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None. + + Returns: + ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device. + """ + if device is None: device = utils.get_device()