new implementation

This commit is contained in:
cotran 2024-11-22 11:11:26 -08:00
parent 665dbc2d4e
commit abfc81b0e7
3 changed files with 224 additions and 340 deletions

View file

@ -34,3 +34,21 @@ zero_shot_model = loader.get_zero_shot_model()
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
# Text patterns that mark where the function name, parameter names and
# parameter values begin and end inside a generated tool call. Every constant
# is a tuple so it can be handed directly to ``str.endswith`` or used in an
# ``in`` test; each one lists a double-quoted and a single-quoted spelling.
# The patterns contain no spaces.
FUNC_NAME_START_PATTERN = (
    '<tool_call>\n{"name":"',
    "<tool_call>\n{'name':'",
)
FUNC_NAME_END_TOKEN = (
    '",',
    "',",
)
FIRST_PARAM_NAME_START_PATTERN = (
    '"arguments":{"',
    "'arguments':{'",
)
PARAMETER_NAME_END_TOKENS = (
    '":',
    ':"',
    "':",
    ":'",
)
PARAMETER_NAME_START_PATTERN = (
    ',"',
    ",'",
)
PARAMETER_VALUE_START_PATTERN = (
    '":',
    "':",
)
PARAMETER_VALUE_END_TOKEN = (
    '",',
    "}}\n",
    "',",
)

# Upper limits for entropy / varentropy, keyed by a one-letter token class
# ("t" and "v"). A token whose entropy or varentropy exceeds the limit of its
# class is treated as uncertain.
HALLUCINATION_THRESHOLD_DICT = {
    "t": dict(entropy=0.1, varentropy=0.5),
    "v": dict(entropy=0.5, varentropy=2.5),
}

View file

@ -1,164 +1,31 @@
import torch
import numpy as np
import math
import app.commons.constants as const
import random
from typing import List, Dict, Any, Tuple
import json
import ast
import os
import json
import math
import torch
import random
from typing import Any, Dict, List, Tuple
import app.commons.constants as const
import itertools
def filter_tokens_and_probs(
tokens: List[str], probs: List[float]
) -> Tuple[List[str], List[float]]:
def check_threshold(entropy, varentropy, thd):
"""
Filters out special tokens from the list of tokens and their corresponding probabilities.
Check if the given entropy or variance of entropy exceeds the specified thresholds.
Args:
tokens (list): List of tokens.
probs (list): List of probabilities corresponding to the tokens.
entropy (float): The entropy value to check.
varentropy (float): The variance of entropy value to check.
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
Returns:
tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities.
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
"""
# Use regex to identify tokens without special characters
special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n']
filtered_tokens = [token for token in tokens if token not in special_tokens]
filtered_probs = [
prob for token, prob in zip(tokens, probs) if token not in special_tokens
]
return filtered_tokens, filtered_probs
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
def get_all_parameter_values(
tokens: List[str], probs: List[float], parameter_names: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Extracts parameter values and their corresponding probabilities from the tokens.
Consecutive tokens are concatenated until the combined text equals one of the
parameter names listed for any function in ``parameter_names``. The tokens that
follow are then collected as that parameter's value until a token is reached that
is itself one of that function's parameter names or is "</tool_call>".
If the same parameter name is matched more than once, the later match overwrites
the earlier entry.
Args:
tokens (list): List of tokens.
probs (list): Per-token scores, indexed in parallel with ``tokens``.
parameter_names (dict): Maps each function name to the collection of its parameter names.
Returns:
tuple: Two dictionaries keyed by the matched parameter name - the list of value
tokens, and the list of ``probs`` entries at the same token positions.
"""
parameter_values = {}
probs_values = {}
i = 0
while i < len(tokens):
# Try to form parameter names by combining tokens
combined_token = ""
# Remember where this attempt started so a failed attempt can resume one token later
start = i
found_param = False
# Incrementally combine tokens to find a full match with any parameter name
while i < len(tokens):
if combined_token:
combined_token += tokens[
i
] # Append next token to the current combination
else:
combined_token = tokens[i] # Start a new combination
# Check if the combined token matches any parameter name
# (``func`` is not used; only each function's parameter collection is consulted)
for func, params in parameter_names.items():
if combined_token in params:
# Collect values associated with this parameter
values = []
prob_values = []
i += 1 # Move past the parameter name
# Collect tokens as values until the next parameter or end marker
while (
i < len(tokens)
and tokens[i] not in params
and tokens[i] != "</tool_call>"
):
values.append(tokens[i])
prob_values.append(probs[i])
i += 1
# Store the parameter values and probabilities
parameter_values[combined_token] = values
probs_values[combined_token] = prob_values
found_param = True
break # Stop combining further once a parameter is matched
# NOTE(review): indentation is not preserved in this view; presumably this
# break leaves the token-combining loop rather than the outermost scan - confirm.
if found_param:
break # Exit the combining loop if parameter was matched
i += 1 # Move to the next token if no match was found yet
# Reset to the next token if no parameter match was found
if not found_param:
i = start + 1
return parameter_values, probs_values
def calculate_stats(
    data: Dict[str, Any], function_description: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Calculates statistical metrics for the given data.

    For every key in ``data`` that has at least one value, the result holds the
    first, maximum, minimum and average value, plus two flags telling whether
    the parameter of the same name declares a "format" or a "default" property
    in ``function_description``.

    Keys with no values are left out. A key whose values cannot be summarised
    is reported on stdout and skipped, so one bad entry no longer hides the
    statistics of the keys that follow it.

    Args:
        data (dict): Maps a parameter name to the list of numeric values recorded for it.
        function_description (dict): Description of the function containing parameter properties.

    Returns:
        dict: Dictionary containing statistical metrics for each parameter.
    """
    stats = {}
    for key, values in data.items():
        if len(values) == 0:
            continue
        try:
            stats[key] = {
                "first": values[0],
                "max": max(values),
                "min": min(values),
                "avg": sum(values) / len(values),
                "has_format": check_parameter_property(
                    function_description, key, "format"
                ),
                "has_default": check_parameter_property(
                    function_description, key, "default"
                ),
            }
        except (TypeError, AttributeError) as exc:
            # Stay best-effort like the previous blanket handler, but say
            # which key failed and why instead of dumping the whole input.
            print(f"calculate_stats: skipping parameter {key!r}: {exc}")
    return stats
def check_parameter_property(
    api_description: Dict[str, Any], parameter_name: str, property_name: str
) -> bool:
    """
    Tell whether a parameter of an API description declares a given property.

    The parameter is looked up under ``api_description["parameters"]["properties"]``;
    any missing level is treated as an empty mapping.

    Args:
        api_description (dict): The API description in JSON format.
        parameter_name (str): The name of the parameter to check.
        property_name (str): The property to look for (e.g., 'format', 'default').

    Returns:
        bool: True if the parameter has the specified property, False otherwise.
    """
    schema = api_description.get("parameters", {})
    declared = schema.get("properties", {})
    return property_name in declared.get(parameter_name, {})
def calculate_entropy(log_probs):
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
"""
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
@ -178,185 +45,196 @@ def calculate_entropy(log_probs):
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
dim=-1,
)
return log_probs.tolist(), entropy.item(), varentropy.item()
return entropy.item(), varentropy.item()
def hallucination_detect(
token: str,
log_probs: List[float],
current_state: Dict[str, Any],
entropy_thd: float = 0.7,
varentropy_thd: float = 4.0,
) -> bool:
def check_parameter_property(api_description, parameter_name, property_name):
"""
Detects hallucinations in the token sequence based on entropy and varentropy thresholds.
Check if a parameter in an API description has a specific property.
Args:
token (str): The current token.
log_probs (list): List of log probabilities for the current token.
current_state (dict): The current state of the detection process.
entropy_thd (float): Entropy threshold for detecting hallucinations.
varentropy_thd (float): Variance of entropy threshold for detecting hallucinations.
api_description (dict): The API description in JSON format.
parameter_name (str): The name of the parameter to check.
property_name (str): The property to look for (e.g., 'format', 'default').
Returns:
bool: True if a hallucination is detected, False otherwise.
bool: True if the parameter has the specified property, False otherwise.
"""
parameters = api_description.get("properties", {})
parameter_info = parameters.get(parameter_name, {})
return property_name in parameter_info
class HallucinationStateHandler:
"""
A class to handle the state of hallucination detection in token processing.
Attributes:
tokens (list): List of tokens processed.
logprobs (list): List of log probabilities for each token.
state (str): Current state of the handler.
mask (list): List of masks indicating the type of each token.
parameter_name_done (bool): Flag indicating if parameter name extraction is done.
hallucination (bool): Flag indicating if a hallucination is detected.
hallucination_message (str): Message describing the hallucination.
parameter_name (list): List of extracted parameter names.
function_description (dict): Description of functions and their parameters.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
current_token (str): The current token being processed.
"""
if token:
# check if there is content in token
current_state["tokens"].append(token)
current_state["content"] += token
current_state["logprobs"].append(log_probs)
# keep track of entropy and varentropy
_, entropy, varentropy = calculate_entropy(log_probs)
current_state["entropy"].append(entropy)
current_state["varentropy"].append(varentropy)
# first check if tool call token is certain
if token == "<tool_call>":
if entropy > entropy_thd or varentropy > varentropy_thd:
current_state["hallucination"] = True
current_state[
"hallucination_message"
] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}"
return True
elif token == "</tool_call>":
current_state["state"] = "tool_call_end"
# try to extract tool call, else raise error
try:
current_state[
"tool_call"
] = const.arch_function_hanlder.extract_tool_calls(
current_state["content"]
)[
0
]
current_state["tool_call_process"] = True
except:
current_state["tool_call_process"] = False
print(f"cant process tool")
return True
# check if function name is valid
if (
current_state["tool_call"]["function"]["name"]
not in current_state["parameter_names"].keys()
):
current_state["hallucination"] = True
current_state[
"hallucination_message"
] = f"function name {current_state['tool_call']['name']} not found"
return True
def __init__(self):
"""
Initializes the HallucinationStateHandler with default values.
"""
self.tokens = []
self.logprobs = []
self.state = None
self.mask = []
self.parameter_name_done = False
self.hallucination = False
self.hallucination_message = ""
self.parameter_name = []
# check if parameter names are from the given function tools
current_parameter_names = current_state["tool_call"]["function"][
"arguments"
].keys()
given_parameter_names = current_state["parameter_names"][
current_state["tool_call"]["function"]["name"]
]
if not set(current_parameter_names).issubset(given_parameter_names):
missing_keys = set(current_parameter_names) - set(given_parameter_names)
self.token_probs_map = []
self.current_token = None
current_state["hallucination"] = True
current_state[
"hallucination_message"
] = f"parameter names {missing_keys} not found"
return True
def process_function(self, apis):
self.apis = apis
if self.apis is None:
raise ValueError("API descriptions not set.")
parameter_names = {}
for func in self.apis:
func_name = func["name"]
parameters = func["parameters"]["properties"]
parameter_names[func_name] = list(parameters.keys())
self.function_description = parameter_names
self.function_properties = {x["name"]: x["parameters"] for x in self.apis}
# filtered special tokens that are not needed in the hallucination check for parameter values
(
current_state["filtered_tokens"],
current_state["filtered_entropy"],
) = filter_tokens_and_probs(
current_state["tokens"], current_state["entropy"]
)
(
current_state["filtered_tokens"],
current_state["filtered_varentropy"],
) = filter_tokens_and_probs(
current_state["tokens"], current_state["varentropy"]
)
parameter_values, entropy_values = get_all_parameter_values(
current_state["filtered_tokens"],
current_state["filtered_entropy"],
current_state["parameter_names"],
)
parameter_values, varentropy_values = get_all_parameter_values(
current_state["filtered_tokens"],
current_state["filtered_varentropy"],
current_state["parameter_names"],
)
def process_token(self):
"""
Processes the current token and updates the state and mask accordingly.
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.current_token == "<tool_call>":
self.mask.append("t")
self.check_logprob()
current_state["parameter_values"] = parameter_values
current_state["parameter_values_entropy"] = entropy_values
current_state["parameter_values_varentropy"] = varentropy_values
# calculate the max, first, avg of sub tokens for parameter value
current_state["parameter_value_entropy_stat"] = calculate_stats(
current_state["parameter_values_entropy"],
current_state["function_description"][0],
)
current_state["parameter_value_varentropy_stat"] = calculate_stats(
current_state["parameter_values_varentropy"],
current_state["function_description"][0],
)
# get map for debugging
current_state["token_entropy_map"] = {
x: y for x, y in zip(current_state["tokens"], current_state["entropy"])
}
current_state["token_varentropy_map"] = {
x: y
for x, y in zip(current_state["tokens"], current_state["varentropy"])
}
# Function name extraction logic
if self.state == "function_name":
if self.current_token not in const.FUNC_NAME_END_TOKEN:
self.mask.append("f")
else:
self.state = None
self.check_function_name()
# checking hallucination for parameter value
current_state["parameter_value_check"] = {
x: {"hallucination": False, "message": ""}
for x in current_state["parameter_values"].keys()
}
for key in current_state["parameter_value_check"].keys():
# if parameter is given a format, check the first token
if current_state["parameter_value_entropy_stat"][key]["has_format"]:
if (
current_state["parameter_value_entropy_stat"][key]["first"]
> entropy_thd
or current_state["parameter_value_varentropy_stat"][key][
"first"
]
> varentropy_thd
):
current_state["parameter_value_check"][key][
"hallucination"
] = True
current_state["hallucination"] = True
current_state["parameter_value_check"][key][
"message"
] = f"parameter {key} with formatting doesn't pass threshold"
# if parameter gis given a default value, we can always use default
elif current_state["parameter_value_entropy_stat"][key]["has_default"]:
current_state["parameter_value_check"][key]["hallucination"] = False
current_state["parameter_value_check"][key][
"message"
] = f"parameter {key} with default"
# check if max sub token is > thresholds
else:
if (
current_state["parameter_value_entropy_stat"][key]["max"]
> entropy_thd
or current_state["parameter_value_varentropy_stat"][key]["max"]
> varentropy_thd
):
current_state["parameter_value_check"][key][
"hallucination"
] = True
current_state["parameter_value_check"][key][
"message"
] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold"
current_state["hallucination"] = True
if current_state["hallucination"] == True:
current_state["hallucination_message"] = "\n".join(
[
current_state["parameter_value_check"][key]["message"]
for key in current_state["parameter_value_check"].keys()
]
)
return True
return False
if content.endswith(const.FUNC_NAME_START_PATTERN):
print("function name entered")
self.state = "function_name"
# Parameter name extraction logic
if self.state == "parameter_name" and not content.endswith(
const.PARAMETER_NAME_END_TOKENS
):
self.mask.append("p")
elif self.state == "parameter_name" and content.endswith(
const.PARAMETER_NAME_END_TOKENS
):
self.state = None
self.check_parameter_name()
self.parameter_name_done = True
elif self.parameter_name_done and content.endswith(
const.PARAMETER_NAME_START_PATTERN
):
self.state = "parameter_name"
if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN):
self.state = "parameter_name"
# Parameter value extraction logic
if self.state == "parameter_value" and not content.endswith(
const.PARAMETER_VALUE_END_TOKEN
):
if self.current_token.strip() not in ['"', ""]:
self.mask.append("v")
if (
len(self.mask) > 1
and self.mask[-2] == "v"
and not check_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"default",
)
):
self.check_logprob()
else:
self.mask.append("e")
elif self.state == "parameter_value" and content.endswith(
const.PARAMETER_VALUE_END_TOKEN
):
self.state = None
elif self.parameter_name_done and content.endswith(
const.PARAMETER_VALUE_START_PATTERN
):
self.state = "parameter_value"
# Maintain consistency between stack and mask
if len(self.mask) != len(self.tokens):
self.mask.append("e")
def check_logprob(self):
    """
    Scores the most recent token and flags it when it is too uncertain.

    Computes entropy and varentropy from the last entry of ``self.logprobs``,
    records them with the last token in ``self.token_probs_map``, and compares
    them with the limits that ``const.HALLUCINATION_THRESHOLD_DICT`` holds for
    the last mask entry. When a limit is exceeded, ``self.hallucination`` is
    set and ``self.hallucination_message`` names the current token.
    """
    entropy, varentropy = calculate_entropy(self.logprobs[-1])
    self.token_probs_map.append((self.tokens[-1], entropy, varentropy))

    # The class of the last mask entry selects which limits apply.
    limits = const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]]
    if not check_threshold(entropy, varentropy, limits):
        return

    self.hallucination = True
    self.hallucination_message = f"Token '{self.current_token}' is uncertain."
def count_consecutive_token(self, token="v") -> int:
"""
Counts the number of consecutive occurrences of a given token in the mask.
Args:
token (str): The token to count in the mask.
Returns:
int: The number of consecutive occurrences of the token.
"""
return (
len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask))))
if self.mask and self.mask[-1] == token
else 0
)
def check_function_name(self):
"""
Checks the extracted function name against the function descriptions.
Detects hallucinations if the function name is not found.
"""
f_len = self.count_consecutive_token("f")
self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys():
self.hallucination = True
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
def check_parameter_name(self):
"""
Checks the extracted parameter name against the function descriptions.
Detects hallucinations if the parameter name is not found.
"""
p_len = self.count_consecutive_token("p")
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]:
self.hallucination = True
self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions."

View file

@ -1,5 +1,5 @@
import json
from app.function_calling.hallucination_handler import hallucination_detect
from app.function_calling.hallucination_handler import HallucinationStateHandler
import pytest
import os
@ -44,27 +44,15 @@ function_description = get_weather_api["function"]
if type(function_description) != list:
function_description = [get_weather_api["function"]]
parameter_names = {}
for func in function_description:
func_name = func["name"]
parameters = func["parameters"]["properties"]
parameter_names[func_name] = list(parameters.keys())
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
current_state = {
"state": "start",
"tool_call": "",
"entropy": [],
"varentropy": [],
"logprobs": [],
"tokens": [],
"content": "",
"hallucination": False,
"parameter_names": parameter_names,
"function_description": function_description,
}
for token_content, logprobs in zip(case["tokens"], case["logprobs"]):
result = hallucination_detect(token_content, logprobs, current_state, 0.7, 4)
assert result == case["expect"]
state = HallucinationStateHandler()
state.process_function(function_description)
for token, logprob in zip(case["tokens"], case["logprobs"]):
if token != "</tool_call>":
state.current_token = token
state.tokens.append(token)
state.logprobs.append(logprob)
state.process_token()
assert state.hallucination == case["expect"]