From 4fcfd8363905b068a41a0b43fc726faa230af699 Mon Sep 17 00:00:00 2001
From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com>
Date: Thu, 5 Dec 2024 15:19:41 -0800
Subject: [PATCH] Refine model_server
---
model_server/app/commons/constants.py | 16 +--
model_server/app/commons/globals.py | 8 +-
model_server/app/main.py | 23 ++--
.../app/model_handler/base_handler.py | 8 +-
.../app/model_handler/function_calling.py | 44 +++++--
model_server/app/model_handler/guardrails.py | 114 +++++++++++++-----
6 files changed, 149 insertions(+), 64 deletions(-)
diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py
index f4c062f8..4dcd24e2 100644
--- a/model_server/app/commons/constants.py
+++ b/model_server/app/commons/constants.py
@@ -4,23 +4,23 @@ ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
ARCH_INTENT_TASK_PROMPT = """
You are a helpful assistant.
-""".strip()
+"""
-ARCH_INTENT_TOOL_PROMPT = """
+ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
{tool_text}
-""".strip()
+"""
ARCH_INTENT_FORMAT_PROMPT = """
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
-""".strip()
+"""
ARCH_INTENT_GENERATION_CONFIG = {
@@ -37,10 +37,10 @@ ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
ARCH_FUNCTION_TASK_PROMPT = """
You are a helpful assistant.
-""".strip()
+"""
-ARCH_FUNCTION_TOOL_PROMPT = """
+ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """
# Tools
You may call one or more functions to assist with the user query.
@@ -49,7 +49,7 @@ You are provided with function signatures within XML tags:
{tool_text}
-""".strip()
+"""
ARCH_FUNCTION_FORMAT_PROMPT = """
@@ -57,7 +57,7 @@ For each function call, return a json object with function name and arguments wi
{"name": , "arguments": }
-""".strip()
+"""
ARCH_FUNCTION_GENERATION_CONFIG = {
"generation_params": {
diff --git a/model_server/app/commons/globals.py b/model_server/app/commons/globals.py
index 92edec3a..e62286b3 100644
--- a/model_server/app/commons/globals.py
+++ b/model_server/app/commons/globals.py
@@ -10,7 +10,9 @@ logger = utils.get_model_server_logger()
# Define the client
-ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
+ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"
+ARCH_API_KEY = "EMPTY"
+ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
# Define model handlers
@@ -19,7 +21,7 @@ handler_map = {
ARCH_CLIENT,
ARCH_INTENT_MODEL_ALIAS,
ARCH_INTENT_TASK_PROMPT,
- ARCH_INTENT_TOOL_PROMPT,
+ ARCH_INTENT_TOOL_PROMPT_TEMPLATE,
ARCH_INTENT_FORMAT_PROMPT,
ARCH_INTENT_INSTRUCTION,
**ARCH_INTENT_GENERATION_CONFIG,
@@ -28,7 +30,7 @@ handler_map = {
ARCH_CLIENT,
ARCH_FUNCTION_MODEL_ALIAS,
ARCH_FUNCTION_TASK_PROMPT,
- ARCH_FUNCTION_TOOL_PROMPT,
+ ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE,
ARCH_FUNCTION_FORMAT_PROMPT,
**ARCH_FUNCTION_GENERATION_CONFIG,
),
diff --git a/model_server/app/main.py b/model_server/app/main.py
index 6d1de4e7..d615c086 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -4,7 +4,7 @@ from app.commons.globals import handler_map
from app.model_handler.base_handler import ChatMessage
from app.model_handler.guardrails import GuardRequest
-from fastapi import FastAPI, Response, Request
+from fastapi import FastAPI, Response
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
@@ -53,24 +53,25 @@ async def models():
@app.post("/function_calling")
-async def function_calling(req: ChatMessage, res: Response, request: Request):
+async def function_calling(req: ChatMessage, res: Response):
try:
- intent_result = await handler_map["Arch-Intent"].chat_completion(req)
+ intent_response = await handler_map["Arch-Intent"].chat_completion(req)
- if intent_result.choices[0].message.content == "Yes":
+ if handler_map["Arch-Intent"].detect_intent(intent_response):
+ # [TODO] measure agreement between intent detection and function calling
try:
- function_result = await handler_map["Arch-Function"].chat_completion(
- req
- )
- return function_result
+ function_calling_response = await handler_map[
+ "Arch-Function"
+ ].chat_completion(req)
+ return function_calling_response
except Exception as e:
- # [TODO]
+ # [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
res.status_code = 500
return {"error": f"[Arch-Function] - {e}"}
except Exception as e:
- # [TODO]
+ # [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
res.status_code = 500
return {"error": f"[Arch-Intent] - {e}"}
@@ -82,6 +83,6 @@ async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
guard_result = handler_map["Arch-Guard"].predict(req)
return guard_result
except Exception as e:
- # [TODO]
+ # [TODO] Review: update how to collect debugging outputs
res.status_code = 500
return {"error": f"[Arch-Guard] - {e}"}
diff --git a/model_server/app/model_handler/base_handler.py b/model_server/app/model_handler/base_handler.py
index 6b88a2bb..cfc7a0f8 100644
--- a/model_server/app/model_handler/base_handler.py
+++ b/model_server/app/model_handler/base_handler.py
@@ -38,7 +38,7 @@ class ArchBaseHandler:
client: OpenAI,
model_name: str,
task_prompt: str,
- tool_prompt: str,
+ tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
):
@@ -59,7 +59,7 @@ class ArchBaseHandler:
self.model_name = model_name
self.task_prompt = task_prompt
- self.tool_prompt = tool_prompt
+ self.tool_prompt_template = tool_prompt_template
self.format_prompt = format_prompt
self.generation_params = generation_params
@@ -78,7 +78,7 @@ class ArchBaseHandler:
raise NotImplementedError()
@final
- def _format_system(self, tools: List[Dict[str, Any]]) -> str:
+ def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
"""
Formats the system prompt using provided tools.
@@ -94,7 +94,7 @@ class ArchBaseHandler:
system_prompt = (
self.task_prompt
+ "\n\n"
- + self.tool_prompt.format(tool_text=tool_text)
+ + self.tool_prompt_template.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)
diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py
index 3630e8bf..a3bb2dc3 100644
--- a/model_server/app/model_handler/function_calling.py
+++ b/model_server/app/model_handler/function_calling.py
@@ -23,7 +23,7 @@ class ArchIntentHandler(ArchBaseHandler):
client: OpenAI,
model_name: str,
task_prompt: str,
- tool_prompt: str,
+ tool_prompt_template: str,
format_prompt: str,
extra_instruction: str,
generation_params: Dict,
@@ -35,7 +35,7 @@ class ArchIntentHandler(ArchBaseHandler):
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
- tool_prompt (str): A prompt to describe tools.
+            tool_prompt_template (str): A prompt template with a `{tool_text}` placeholder that is filled with the tool descriptions.
format_prompt (str): A prompt specifying the desired output format.
extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
@@ -45,7 +45,7 @@ class ArchIntentHandler(ArchBaseHandler):
client,
model_name,
task_prompt,
- tool_prompt,
+ tool_prompt_template,
format_prompt,
generation_params,
)
@@ -69,6 +69,19 @@ class ArchIntentHandler(ArchBaseHandler):
]
return "\n".join(converted)
+    def detect_intent(self, content: ChatCompletionResponse) -> bool:
+        """
+        Detect whether the intent model reported that a tool matches.
+
+        Args:
+            content (ChatCompletionResponse): Intent model response; only the first line of the first choice is read.
+
+        Returns:
+            bool: True if that first line reads "Yes" (a tool-index line may follow), otherwise False.
+        """
+        text = content.choices[0].message.content or ""
+        return text.strip().split("\n", 1)[0].strip() == "Yes"
+
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
@@ -110,7 +123,7 @@ class ArchFunctionHandler(ArchBaseHandler):
client: OpenAI,
model_name: str,
task_prompt: str,
- tool_prompt: str,
+ tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
prefill_params: Dict,
@@ -123,7 +136,7 @@ class ArchFunctionHandler(ArchBaseHandler):
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
- tool_prompt (str): A prompt to describe tools.
+            tool_prompt_template (str): A prompt template with a `{tool_text}` placeholder that is filled with the tool descriptions.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
prefill_params (Dict): Additional parameters for prefilling responses.
@@ -134,7 +147,7 @@ class ArchFunctionHandler(ArchBaseHandler):
client,
model_name,
task_prompt,
- tool_prompt,
+ tool_prompt_template,
format_prompt,
generation_params,
)
@@ -392,15 +405,24 @@ class ArchFunctionHandler(ArchBaseHandler):
else:
model_response = response.choices[0].message.content
- tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)
+ (
+ tool_calls,
+ extraction_status,
+ extraction_error_message,
+ ) = self._extract_tool_calls(model_response)
if tool_calls:
- is_valid, error_tool_call, error_message = self._verify_tool_calls(
- tools=req.tools, tool_calls=tool_calls
- )
+ # [TODO] Review: define the behavior in the case that tool call extraction fails
+ # if not extraction_status:
+
+ (
+ verification_status,
+ invalid_tool_call,
+ verification_error_message,
+ ) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls)
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
- if is_valid:
+ if verification_status:
model_response = Message(content="", tool_calls=tool_calls)
# else:
diff --git a/model_server/app/model_handler/guardrails.py b/model_server/app/model_handler/guardrails.py
index da4e6246..f733552e 100644
--- a/model_server/app/model_handler/guardrails.py
+++ b/model_server/app/model_handler/guardrails.py
@@ -6,6 +6,7 @@ import app.commons.utilities as utils
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
+from typing import List, Optional, Union
class GuardRequest(BaseModel):
@@ -13,8 +14,22 @@ class GuardRequest(BaseModel):
task: str
+class GuardResponse(BaseModel):  # result of a guard check; fields are scalars for one text, lists for chunked input
+    prob: Union[float, List[float]]  # positive-class probability (single text) or a list of them (chunked input)
+    verdict: bool  # True when a checked text scored above the task threshold
+    sentence: Optional[Union[str, List[str]]]  # flagged text or list of flagged chunks; None when a single text is clean
+    latency: float = 0  # model inference time in seconds, summed over chunks
+
+
class ArchGuardHanlder:
def __init__(self, model_dict):
+ """
+ Initializes the ArchGuardHanlder with the given model dictionary.
+
+ Args:
+ model_dict (dict): A dictionary containing the model, tokenizer, and device information.
+ """
+
self.model = model_dict["model"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
@@ -23,9 +38,17 @@ class ArchGuardHanlder:
def _split_text_into_chunks(self, text, max_num_words=300):
"""
- Split the text into chunks of `max_num_words` words
+ Splits the input text into chunks of up to `max_num_words` words.
+
+ Args:
+ text (str): The input text to be split.
+ max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.
+
+ Returns:
+ List[str]: A list of text chunks.
"""
- words = text.split() # Split text into words
+
+ words = text.split()
chunks = [
" ".join(words[i : i + max_num_words])
@@ -36,19 +59,44 @@ class ArchGuardHanlder:
@staticmethod
def softmax(x):
+ """
+ Computes the softmax of the input array.
+
+ Args:
+ x (np.ndarray): The input array.
+
+ Returns:
+ np.ndarray: The softmax of the input.
+ """
return np.exp(x) / np.exp(x).sum(axis=0)
- def _predict_text(self, task, text, max_length=512):
+ def _predict_text(self, task, text, max_length=512) -> GuardResponse:
+ """
+ Predicts the result for the provided text for a specific task.
+
+ Args:
+ task (str): The task to perform (e.g., "jailbreak").
+ text (str): The input text to classify.
+ max_length (int, optional): The maximum length for tokenization. Defaults to 512.
+
+ Returns:
+ GuardResponse: A GuardResponse object containing the prediction.
+ """
+
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
+ start_time = time.perf_counter()
+
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
+ latency = time.perf_counter() - start_time
+
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
@@ -56,49 +104,61 @@ class ArchGuardHanlder:
verdict = False
sentence = None
- result_dict = {
- "prob": prob.item(),
- "verdict": verdict,
- "sentence": sentence,
- }
+ return GuardResponse(
+ prob=prob.item(), verdict=verdict, sentence=sentence, latency=latency
+ )
- return result_dict
-
- def predict(self, req: GuardRequest, max_num_words=300):
+ def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
"""
- Note: currently only support jailbreak check
+ Makes a prediction based on the GuardRequest input.
+
+ Args:
+ req (GuardRequest): The GuardRequest object containing the input text and task.
+ max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.
+
+ Returns:
+ GuardResponse: A GuardResponse object containing the prediction.
+
+ Note:
+            Currently only the jailbreak check is supported.
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
- guard_result = {
- "prob": [],
- "verdict": False,
- "sentence": [],
- }
-
- start_time = time.perf_counter()
-
if len(req.input.split()) < max_num_words:
- guard_result = self._predict_text(req.task, req.input)
+ return self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
+ prob, verdict, sentence, latency = [], False, [], 0
+
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
- if chunk_result["verdict"]:
- guard_result["verdict"] = True
- guard_result["sentence"].append(chunk_result["sentence"])
- guard_result["prob"].append(chunk_result["prob"].item())
- guard_result["latency"] = time.perf_counter() - start_time
+ if chunk_result.verdict:
+ prob.append(chunk_result.prob)
+ verdict = True
+ sentence.append(chunk_result.sentence)
+ latency += chunk_result.latency
- return guard_result
+ return GuardResponse(
+ prob=prob, verdict=verdict, sentence=sentence, latency=latency
+ )
def get_guardrail_handler(device: str = None):
+ """
+ Initializes and returns an instance of ArchGuardHanlder based on the specified device.
+
+ Args:
+ device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None.
+
+ Returns:
+ ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device.
+ """
+
if device is None:
device = utils.get_device()