Init update on model_server

This commit is contained in:
Shuguang Chen 2024-12-04 16:41:30 -08:00
parent 1d9de28086
commit afe1410b37
25 changed files with 1758 additions and 1922 deletions

View file

@ -0,0 +1,415 @@
import json
import random
import builtins
from openai import OpenAI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from overrides import override, final
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class Message(BaseModel):
role: Optional[str] = ""
content: Optional[str] = ""
tool_call_id: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
class Choice(BaseModel):
id: Optional[int] = 0
message: Message
finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
id: Optional[int] = 0
object: Optional[str] = "chat_completion"
created: Optional[str] = ""
model: str
choices: List[Choice]
class ArchBaseHandler:
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
format_prompt: str,
generation_params: Dict,
):
self.client = client
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt = tool_prompt
self.format_prompt = format_prompt
self.generation_params = generation_params
def _convert_tools(self, tools: List[Dict[str, Any]]):
raise NotImplementedError()
@final
def _format_system(self, tools: List[Dict[str, Any]]):
tool_text = self._convert_tools(tools)
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)
return system_prompt
@final
def _process_messages(
self,
messages: List[Message],
tools: List[Dict[str, Any]] = None,
extra_instructions: str = None,
):
processed_messages = []
if tools:
processed_messages.append(
{"role": "system", "content": self._format_system(tools)}
)
for message in messages:
role, content, tool_calls = (
message.role,
message.content,
message.tool_calls,
)
if tool_calls:
# [TODO] Extend to support multiple function calls
role = "assistant"
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
elif message.role == "tool":
role = "user"
content = (
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
)
processed_messages.append({"role": role, "content": content})
assert processed_messages[-1]["role"] == "user"
if extra_instructions:
processed_messages[-1]["content"] += extra_instructions
return processed_messages
async def chat_completion(self, req: ChatMessage):
raise NotImplementedError()
class ArchIntentHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
format_prompt: str,
intent_instruction: str,
generation_params: Dict,
):
super().__init__(
client,
model_name,
task_prompt,
tool_prompt,
format_prompt,
generation_params,
)
self.intent_instruction = intent_instruction
@override
def _convert_tools(self, tools: List[Dict[str, Any]]):
converted = [
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
]
return "\n".join(converted)
@override
async def chat_completion(self, req: ChatMessage):
"""
Note: Currently only support vllm inference
"""
messages = self._process_messages(
req.messages, req.tools, self.intent_instruction
)
model_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
model_response = Message(content=model_response, tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
return chat_completion_response
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
format_prompt: str,
generation_params: Dict,
prefill_params: Dict,
prefill_prefix: List,
):
super().__init__(
client,
model_name,
task_prompt,
tool_prompt,
format_prompt,
generation_params,
)
self.prefill_params = prefill_params
self.prefill_prefix = prefill_prefix
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
}
@override
def _convert_tools(self, tools: List[Dict[str, Any]]):
converted = [json.dumps(tool) for tool in tools]
return "\n".join(converted)
def _fix_json_string(self, json_str: str):
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')
def _extract_tool_calls(self, content: str):
tool_calls, is_valid, error_message = [], True, ""
flag = False
for line in content.split("\n"):
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception as e:
fixed_content = self._fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
return tool_calls, is_valid, error_message
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
flag = False
return tool_calls, is_valid, error_message
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
):
is_valid, error_tool_call, error_message = True, None, ""
functions = {}
for tool in tools:
if tool["type"] == "function":
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
func_name, func_args = (
tool_call["function"]["name"],
tool_call["function"]["arguments"],
)
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
error_message = f"{func_name} is not defined!"
return is_valid, error_message
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
if required_param not in func_args:
is_valid = False
error_tool_call = tool_call
error_message = f"`{required_param}` is requried by the function `{func_name}` but not found in the tool call!"
return is_valid, error_tool_call, error_message
# Verify the data type of each parameter in the tool calls
for param_name, param_value in func_args:
data_type = functions[func_name]["properties"][param_name]["type"]
if data_type in self.support_data_types:
if not isinstance(
param_value, self.support_data_types[data_type]
):
is_valid = False
error_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
return is_valid, error_tool_call, error_message
return is_valid, error_tool_call, error_message
@override
async def chat_completion(self, req: ChatMessage, enable_prefilling=True):
"""
Note: Currently only support vllm inference
"""
messages = self._process_messages(req.messages, req.tools)
# Retrieve the first token, handling the Stream object carefully
response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=enable_prefilling,
extra_body=self.generation_params,
)
model_response = ""
if enable_prefilling:
has_tool_call = None
model_response = ""
for token in response:
token_content = token.choices[0].delta.content.strip()
if has_tool_call is None and token_content != "<tool_call>":
has_tool_call = False
response.close()
break
else:
has_tool_call = True
if has_tool_call is True:
model_response += token_content
# start parameter gathering if the model is not generating a tool call
if has_tool_call is False:
messages.append(
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
)
prefill_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
model_response = prefill_response.choices[0].message.content
else:
model_response = response.choices[0].message.content
tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)
if tool_calls:
is_valid, error_tool_call, error_message = self._verify_tool_calls(
tools=req.tools, tool_calls=tool_calls
)
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if is_valid:
model_response = Message(content="", tool_calls=tool_calls)
# else:
else:
model_response = Message(content=model_response, tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
# [TODO] Review: define the protocol to collect debugging output
# logger.info(
# f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
# )
# logger.info(
# f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
# )
return chat_completion_response

View file

@ -0,0 +1,95 @@
import time
import torch
import numpy as np
from pydantic import BaseModel
class GuardRequest(BaseModel):
input: str
task: str
class ArchGuardHanlder:
def __init__(self, model_dict):
self.model = model_dict["model"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}
def _split_text_into_chunks(self, text, max_num_words=300):
"""
Split the text into chunks of `max_num_words` words
"""
words = text.split() # Split text into words
chunks = [
" ".join(words[i : i + max_num_words])
for i in range(0, len(words), max_num_words)
]
return chunks
@staticmethod
def softmax(x):
return np.exp(x) / np.exp(x).sum(axis=0)
def _predict_text(self, task, text, max_length=512):
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
else:
verdict = False
sentence = None
result_dict = {
"prob": prob.item(),
"verdict": verdict,
"sentence": sentence,
}
return result_dict
def predict(self, req: GuardRequest, max_num_words=300):
"""
Note: currently only support jailbreak check
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
guard_result = {
"prob": [],
"verdict": False,
"sentence": [],
}
start_time = time.perf_counter()
if len(req.input.split()) < max_num_words:
guard_result = self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result["verdict"]:
guard_result["verdict"] = True
guard_result["sentence"].append(chunk_result["sentence"])
guard_result["prob"].append(chunk_result["prob"].item())
guard_result["latency"] = time.perf_counter() - start_time
return guard_result

View file

@ -0,0 +1,268 @@
import math
import torch
from typing import Dict, List, Tuple
import itertools
from enum import Enum
# constants
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
# Thresholds
class MaskToken(Enum):
FUNCTION_NAME = "f"
PARAMETER_VALUE = "v"
PARAMETER_NAME = "p"
NOT_USED = "e"
TOOL_CALL = "t"
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.5,
"varentropy": 2.5,
},
}
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
"""
Check if the given entropy or variance of entropy exceeds the specified thresholds.
Args:
entropy (float): The entropy value to check.
varentropy (float): The variance of entropy value to check.
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
Returns:
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
"""
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
"""
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
Args:
log_probs (list of float): A list of log probabilities.
Returns:
tuple: A tuple containing:
- log_probs (list of float): The input log probabilities as a list.
- entropy (float): The calculated entropy.
- varentropy (float): The calculated variance of entropy.
"""
log_probs = torch.tensor(log_probs)
token_probs = torch.exp(log_probs)
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
varentropy = torch.sum(
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
dim=-1,
)
return entropy.item(), varentropy.item()
class HallucinationStateHandler:
"""
A class to handle the state of hallucination detection in token processing.
Attributes:
tokens (list): List of tokens processed.
logprobs (list): List of log probabilities for each token.
state (str): Current state of the handler.
mask (list): List of masks indicating the type of each token.
parameter_name_done (bool): Flag indicating if parameter name extraction is done.
hallucination (bool): Flag indicating if a hallucination is detected.
hallucination_message (str): Message describing the hallucination.
parameter_name (list): List of extracted parameter names.
function_description (dict): Description of functions and their parameters.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
"""
def __init__(self, response_iterator=None):
"""
Initializes the HallucinationStateHandler with default values.
"""
self.tokens: List[str] = []
self.logprobs: List[float] = []
self.state: str = None
self.mask: List[str] = []
self.parameter_name_done: bool = False
self.hallucination: bool = False
self.error_message: str = ""
self.error_type: str = ""
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
def append_and_check_token_hallucination(self, token, logprob):
"""
Check if the given token is hallucinated based on the log probability.
Args:
token (str): The token to check.
logprob (float): The log probability of the token.
Returns:
bool: True if the token is hallucinated, False otherwise.
"""
self.tokens.append(token)
self.logprobs.append(logprob)
self._process_token()
return self.hallucination
def __iter__(self):
return self
def __next__(self):
if self.response_iterator is not None:
try:
r = next(self.response_iterator)
if hasattr(r.choices[0].delta, "content"):
token_content = r.choices[0].delta.content
if token_content:
try:
logprobs = [
p.logprob
for p in r.choices[0].logprobs.content[0].top_logprobs
]
except Exception as e:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
self.append_and_check_token_hallucination(
token_content, logprobs
)
return token_content
except StopIteration:
raise StopIteration
def _process_token(self):
"""
Processes the current token and updates the state and mask accordingly.
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.tokens[-1] == TOOL_CALL_TOKEN:
self.mask.append(MaskToken.TOOL_CALL)
self._check_logprob()
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name":
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
self.mask.append(MaskToken.FUNCTION_NAME)
else:
self.state = None
self._is_function_name_hallucinated()
# Check if the token is a function name start token, change the state
if content.endswith(FUNC_NAME_START_PATTERN):
self.state = "function_name"
# Parameter name extraction logic
# if the state is parameter name and the token is not an end token, add to the mask
if self.state == "parameter_name" and not content.endswith(
PARAMETER_NAME_END_TOKENS
):
self.mask.append(MaskToken.PARAMETER_NAME)
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
# The need for parameter name done is to allow the check of parameter value pattern
elif self.state == "parameter_name" and content.endswith(
PARAMETER_NAME_END_TOKENS
):
self.state = None
self._is_parameter_name_hallucinated()
self.parameter_name_done = True
# if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_NAME_START_PATTERN
):
self.state = "parameter_name"
# if token is a first parameter value start token, change the state
if content.endswith(FIRST_PARAM_NAME_START_PATTERN):
self.state = "parameter_name"
# Parameter value extraction logic
# if the state is parameter value and the token is not an end token, add to the mask
if self.state == "parameter_value" and not content.endswith(
PARAMETER_VALUE_END_TOKEN
):
# checking if the token is a value token and is not empty
if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(MaskToken.PARAMETER_VALUE)
# [TODO] Review: update the following code: `is_parameter_property` should not be here
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"default",
)
):
self._check_logprob()
else:
self.mask.append(MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state
elif self.state == "parameter_value" and content.endswith(
PARAMETER_VALUE_END_TOKEN
):
self.state = None
# if the parameter name is done and the token is a parameter value start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_VALUE_START_PATTERN
):
self.state = "parameter_value"
# Maintain consistency between stack and mask
# If the mask length is less than tokens, add an not used (e) token to the mask
if len(self.mask) != len(self.tokens):
self.mask.append(MaskToken.NOT_USED)
def _check_logprob(self):
"""
Checks the log probability of the current token and updates the token probability map.
Detects hallucinations based on entropy and variance of entropy.
"""
probs = self.logprobs[-1]
entropy, varentropy = calculate_entropy(probs)
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
if check_threshold(
entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
):
self.hallucination = True
self.error_type = "Hallucination"
self.error_message = (
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
)
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
"""
Counts the number of consecutive occurrences of a given token in the mask.
Args:
token (str): The token to count in the mask.
Returns:
int: The number of consecutive occurrences of the token.
"""
return (
len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask))))
if self.mask and self.mask[-1] == token
else 0
)