mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Refine model_server
This commit is contained in:
parent
a5bd005411
commit
4fcfd83639
6 changed files with 149 additions and 64 deletions
|
|
@ -4,23 +4,23 @@ ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
|
|||
|
||||
ARCH_INTENT_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
"""
|
||||
|
||||
|
||||
ARCH_INTENT_TOOL_PROMPT = """
|
||||
ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """
|
||||
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
|
||||
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
"""
|
||||
|
||||
|
||||
ARCH_INTENT_FORMAT_PROMPT = """
|
||||
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
|
||||
- First line must read 'Yes' or 'No'.
|
||||
- If yes, a second line must include a comma-separated list of tool indexes.
|
||||
""".strip()
|
||||
"""
|
||||
|
||||
|
||||
ARCH_INTENT_GENERATION_CONFIG = {
|
||||
|
|
@ -37,10 +37,10 @@ ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
|
|||
|
||||
ARCH_FUNCTION_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
"""
|
||||
|
||||
|
||||
ARCH_FUNCTION_TOOL_PROMPT = """
|
||||
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
|
@ -49,7 +49,7 @@ You are provided with function signatures within <tools></tools> XML tags:
|
|||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
"""
|
||||
|
||||
|
||||
ARCH_FUNCTION_FORMAT_PROMPT = """
|
||||
|
|
@ -57,7 +57,7 @@ For each function call, return a json object with function name and arguments wi
|
|||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
"""
|
||||
|
||||
ARCH_FUNCTION_GENERATION_CONFIG = {
|
||||
"generation_params": {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ logger = utils.get_model_server_logger()
|
|||
|
||||
|
||||
# Define the client
|
||||
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
|
||||
ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"
|
||||
ARCH_API_KEY = "EMPTY"
|
||||
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
||||
|
||||
|
||||
# Define model handlers
|
||||
|
|
@ -19,7 +21,7 @@ handler_map = {
|
|||
ARCH_CLIENT,
|
||||
ARCH_INTENT_MODEL_ALIAS,
|
||||
ARCH_INTENT_TASK_PROMPT,
|
||||
ARCH_INTENT_TOOL_PROMPT,
|
||||
ARCH_INTENT_TOOL_PROMPT_TEMPLATE,
|
||||
ARCH_INTENT_FORMAT_PROMPT,
|
||||
ARCH_INTENT_INSTRUCTION,
|
||||
**ARCH_INTENT_GENERATION_CONFIG,
|
||||
|
|
@ -28,7 +30,7 @@ handler_map = {
|
|||
ARCH_CLIENT,
|
||||
ARCH_FUNCTION_MODEL_ALIAS,
|
||||
ARCH_FUNCTION_TASK_PROMPT,
|
||||
ARCH_FUNCTION_TOOL_PROMPT,
|
||||
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE,
|
||||
ARCH_FUNCTION_FORMAT_PROMPT,
|
||||
**ARCH_FUNCTION_GENERATION_CONFIG,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from app.commons.globals import handler_map
|
|||
from app.model_handler.base_handler import ChatMessage
|
||||
from app.model_handler.guardrails import GuardRequest
|
||||
|
||||
from fastapi import FastAPI, Response, Request
|
||||
from fastapi import FastAPI, Response
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
|
|
@ -53,24 +53,25 @@ async def models():
|
|||
|
||||
|
||||
@app.post("/function_calling")
|
||||
async def function_calling(req: ChatMessage, res: Response, request: Request):
|
||||
async def function_calling(req: ChatMessage, res: Response):
|
||||
try:
|
||||
intent_result = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
|
||||
if intent_result.choices[0].message.content == "Yes":
|
||||
if handler_map["Arch-Intent"].detect_intent(intent_response):
|
||||
# [TODO] measure agreement between intent detection and function calling
|
||||
try:
|
||||
function_result = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
return function_result
|
||||
function_calling_response = await handler_map[
|
||||
"Arch-Function"
|
||||
].chat_completion(req)
|
||||
return function_calling_response
|
||||
except Exception as e:
|
||||
# [TODO]
|
||||
# [TODO] Review: update how to collect debugging outputs
|
||||
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Function] - {e}"}
|
||||
|
||||
except Exception as e:
|
||||
# [TODO]
|
||||
# [TODO] Review: update how to collect debugging outputs
|
||||
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Intent] - {e}"}
|
||||
|
|
@ -82,6 +83,6 @@ async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
|
|||
guard_result = handler_map["Arch-Guard"].predict(req)
|
||||
return guard_result
|
||||
except Exception as e:
|
||||
# [TODO]
|
||||
# [TODO] Review: update how to collect debugging outputs
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Guard] - {e}"}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ class ArchBaseHandler:
|
|||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
|
|
@ -59,7 +59,7 @@ class ArchBaseHandler:
|
|||
self.model_name = model_name
|
||||
|
||||
self.task_prompt = task_prompt
|
||||
self.tool_prompt = tool_prompt
|
||||
self.tool_prompt_template = tool_prompt_template
|
||||
self.format_prompt = format_prompt
|
||||
|
||||
self.generation_params = generation_params
|
||||
|
|
@ -78,7 +78,7 @@ class ArchBaseHandler:
|
|||
raise NotImplementedError()
|
||||
|
||||
@final
|
||||
def _format_system(self, tools: List[Dict[str, Any]]) -> str:
|
||||
def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Formats the system prompt using provided tools.
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ class ArchBaseHandler:
|
|||
system_prompt = (
|
||||
self.task_prompt
|
||||
+ "\n\n"
|
||||
+ self.tool_prompt.format(tool_text=tool_text)
|
||||
+ self.tool_prompt_template.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ self.format_prompt
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
extra_instruction: str,
|
||||
generation_params: Dict,
|
||||
|
|
@ -35,7 +35,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
tool_prompt_template (str): A prompt template with a `{tool_text}` placeholder that is filled in with the tool descriptions.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
extra_instruction (str): Instructions specific to intent handling.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
|
|
@ -45,7 +45,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt,
|
||||
tool_prompt_template,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
)
|
||||
|
|
@ -69,6 +69,19 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
]
|
||||
return "\n".join(converted)
|
||||
|
||||
def detect_intent(self, content: "ChatCompletionResponse") -> bool:
    """
    Decide whether the intent model reported that a tool can help.

    Args:
        content: The response object returned by `chat_completion`. The verdict is
            read from `content.choices[0].message.content`.

    Returns:
        bool: True if the first line of the model output reads "Yes" (ignoring case
            and surrounding whitespace), False otherwise, including when the
            message content is empty or None.
    """

    text = content.choices[0].message.content or ""

    # The format prompt asks for 'Yes' or 'No' on the first line, optionally
    # followed by a line of tool indexes, so only the first line carries the verdict.
    first_line = text.strip().partition("\n")[0]

    return first_line.strip().lower() == "yes"
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
|
|
@ -110,7 +123,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
prefill_params: Dict,
|
||||
|
|
@ -123,7 +136,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
tool_prompt_template (str): A prompt template with a `{tool_text}` placeholder that is filled in with the tool descriptions.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
prefill_params (Dict): Additional parameters for prefilling responses.
|
||||
|
|
@ -134,7 +147,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt,
|
||||
tool_prompt_template,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
)
|
||||
|
|
@ -392,15 +405,24 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
else:
|
||||
model_response = response.choices[0].message.content
|
||||
|
||||
tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)
|
||||
(
|
||||
tool_calls,
|
||||
extraction_status,
|
||||
extraction_error_message,
|
||||
) = self._extract_tool_calls(model_response)
|
||||
|
||||
if tool_calls:
|
||||
is_valid, error_tool_call, error_message = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=tool_calls
|
||||
)
|
||||
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
||||
# if not extraction_status:
|
||||
|
||||
(
|
||||
verification_status,
|
||||
invalid_tool_call,
|
||||
verification_error_message,
|
||||
) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls)
|
||||
|
||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||
if is_valid:
|
||||
if verification_status:
|
||||
model_response = Message(content="", tool_calls=tool_calls)
|
||||
# else:
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import app.commons.utilities as utils
|
|||
from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
|
|
@ -13,8 +14,22 @@ class GuardRequest(BaseModel):
|
|||
task: str
|
||||
|
||||
|
||||
class GuardResponse(BaseModel):
    # Result of a guard check. `_predict_text` fills it with scalar values for a
    # single text, while `predict` fills it with lists collected across chunks,
    # so both shapes must validate.

    # Positive-class probability: one float for a single text, or a list of
    # per-chunk probabilities when the input was split.
    prob: Union[float, List[float]]
    # True when a probability exceeded the task's threshold.
    verdict: bool
    # The flagged text (None when nothing was flagged), or a list of flagged chunks.
    sentence: Optional[Union[str, List[str]]]
    # Model inference time in seconds (from time.perf_counter), accumulated from
    # per-chunk latencies when the input was split.
    latency: float = 0
|
||||
|
||||
|
||||
class ArchGuardHanlder:
|
||||
def __init__(self, model_dict):
|
||||
"""
|
||||
Initializes the ArchGuardHanlder with the given model dictionary.
|
||||
|
||||
Args:
|
||||
model_dict (dict): A dictionary containing the model, tokenizer, and device information.
|
||||
"""
|
||||
|
||||
self.model = model_dict["model"]
|
||||
self.tokenizer = model_dict["tokenizer"]
|
||||
self.device = model_dict["device"]
|
||||
|
|
@ -23,9 +38,17 @@ class ArchGuardHanlder:
|
|||
|
||||
def _split_text_into_chunks(self, text, max_num_words=300):
|
||||
"""
|
||||
Split the text into chunks of `max_num_words` words
|
||||
Splits the input text into chunks of up to `max_num_words` words.
|
||||
|
||||
Args:
|
||||
text (str): The input text to be split.
|
||||
max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of text chunks.
|
||||
"""
|
||||
words = text.split() # Split text into words
|
||||
|
||||
words = text.split()
|
||||
|
||||
chunks = [
|
||||
" ".join(words[i : i + max_num_words])
|
||||
|
|
@ -36,19 +59,44 @@ class ArchGuardHanlder:
|
|||
|
||||
@staticmethod
|
||||
def softmax(x):
|
||||
"""
|
||||
Computes the softmax of the input array.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): The input array.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The softmax of the input.
|
||||
"""
|
||||
return np.exp(x) / np.exp(x).sum(axis=0)
|
||||
|
||||
def _predict_text(self, task, text, max_length=512):
|
||||
def _predict_text(self, task, text, max_length=512) -> GuardResponse:
|
||||
"""
|
||||
Predicts the result for the provided text for a specific task.
|
||||
|
||||
Args:
|
||||
task (str): The task to perform (e.g., "jailbreak").
|
||||
text (str): The input text to classify.
|
||||
max_length (int, optional): The maximum length for tokenization. Defaults to 512.
|
||||
|
||||
Returns:
|
||||
GuardResponse: A GuardResponse object containing the prediction.
|
||||
"""
|
||||
|
||||
inputs = self.tokenizer(
|
||||
text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
|
||||
prob = ArchGuardHanlder.softmax(logits)[
|
||||
self.support_tasks[task]["positive_class"]
|
||||
]
|
||||
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
if prob > self.support_tasks[task]["threshold"]:
|
||||
verdict = True
|
||||
sentence = text
|
||||
|
|
@ -56,49 +104,61 @@ class ArchGuardHanlder:
|
|||
verdict = False
|
||||
sentence = None
|
||||
|
||||
result_dict = {
|
||||
"prob": prob.item(),
|
||||
"verdict": verdict,
|
||||
"sentence": sentence,
|
||||
}
|
||||
return GuardResponse(
|
||||
prob=prob.item(), verdict=verdict, sentence=sentence, latency=latency
|
||||
)
|
||||
|
||||
return result_dict
|
||||
|
||||
def predict(self, req: GuardRequest, max_num_words=300):
|
||||
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
|
||||
"""
|
||||
Note: currently only support jailbreak check
|
||||
Makes a prediction based on the GuardRequest input.
|
||||
|
||||
Args:
|
||||
req (GuardRequest): The GuardRequest object containing the input text and task.
|
||||
max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.
|
||||
|
||||
Returns:
|
||||
GuardResponse: A GuardResponse object containing the prediction.
|
||||
|
||||
Note:
|
||||
Currently only the jailbreak check is supported.
|
||||
"""
|
||||
|
||||
if req.task not in self.support_tasks:
|
||||
raise NotImplementedError(f"{req.task} is not supported!")
|
||||
|
||||
guard_result = {
|
||||
"prob": [],
|
||||
"verdict": False,
|
||||
"sentence": [],
|
||||
}
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if len(req.input.split()) < max_num_words:
|
||||
guard_result = self._predict_text(req.task, req.input)
|
||||
return self._predict_text(req.task, req.input)
|
||||
else:
|
||||
# split into chunks if text is long
|
||||
text_chunks = self._split_text_into_chunks(req.input)
|
||||
|
||||
prob, verdict, sentence, latency = [], False, [], 0
|
||||
|
||||
for chunk in text_chunks:
|
||||
chunk_result = self._predict_text(req.task, chunk)
|
||||
if chunk_result["verdict"]:
|
||||
guard_result["verdict"] = True
|
||||
guard_result["sentence"].append(chunk_result["sentence"])
|
||||
guard_result["prob"].append(chunk_result["prob"].item())
|
||||
|
||||
guard_result["latency"] = time.perf_counter() - start_time
|
||||
if chunk_result.verdict:
|
||||
prob.append(chunk_result.prob)
|
||||
verdict = True
|
||||
sentence.append(chunk_result.sentence)
|
||||
latency += chunk_result.latency
|
||||
|
||||
return guard_result
|
||||
return GuardResponse(
|
||||
prob=prob, verdict=verdict, sentence=sentence, latency=latency
|
||||
)
|
||||
|
||||
|
||||
def get_guardrail_handler(device: str = None):
|
||||
"""
|
||||
Initializes and returns an instance of ArchGuardHanlder based on the specified device.
|
||||
|
||||
Args:
|
||||
device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None.
|
||||
|
||||
Returns:
|
||||
ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = utils.get_device()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue