From abfc81b0e7740125bfc2808c3618bdc233e08e19 Mon Sep 17 00:00:00 2001 From: cotran Date: Fri, 22 Nov 2024 11:11:26 -0800 Subject: [PATCH] new implemenetation --- model_server/app/commons/constants.py | 18 + .../function_calling/hallucination_handler.py | 514 +++++++----------- model_server/app/tests/test_hallucination.py | 32 +- 3 files changed, 224 insertions(+), 340 deletions(-) diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index d4e01d12..977d7b9c 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -34,3 +34,21 @@ zero_shot_model = loader.get_zero_shot_model() prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE]) arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict) +# Patterns for function name and parameter parsing +FUNC_NAME_START_PATTERN = ('\n{"name":"', "\n{'name':'") +FUNC_NAME_END_TOKEN = ('",', "',") + +FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'") +PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'") +PARAMETER_NAME_START_PATTERN = (',"', ",'") +PARAMETER_VALUE_START_PATTERN = ('":', "':") +PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',") + +# Thresholds +HALLUCINATION_THRESHOLD_DICT = { + "t": {"entropy": 0.1, "varentropy": 0.5}, + "v": { + "entropy": 0.5, + "varentropy": 2.5, + }, +} diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py index 67bc1be4..92fd1326 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/function_calling/hallucination_handler.py @@ -1,164 +1,31 @@ -import torch -import numpy as np -import math -import app.commons.constants as const -import random -from typing import List, Dict, Any, Tuple import json +import ast +import os +import json +import math +import torch +import random +from typing import Any, Dict, List, Tuple +import app.commons.constants as const +import itertools -def filter_tokens_and_probs( - tokens: List[str], probs: List[float] -) -> Tuple[List[str], List[float]]: +def check_threshold(entropy, varentropy, thd): """ - Filters out special tokens from the list of tokens and their corresponding probabilities. + Check if the given entropy or variance of entropy exceeds the specified thresholds. Args: - tokens (list): List of tokens. - probs (list): List of probabilities corresponding to the tokens. + entropy (float): The entropy value to check. + varentropy (float): The variance of entropy value to check. + thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'. Returns: - tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities. + bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise. """ - # Use regex to identify tokens without special characters - special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n'] - filtered_tokens = [token for token in tokens if token not in special_tokens] - filtered_probs = [ - prob for token, prob in zip(tokens, probs) if token not in special_tokens - ] - return filtered_tokens, filtered_probs + return entropy > thd["entropy"] or varentropy > thd["varentropy"] -def get_all_parameter_values( - tokens: List[str], probs: List[float], parameter_names: Dict[str, Any] -) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """ - Extracts parameter values and their corresponding probabilities from the tokens. - - Args: - tokens (list): List of tokens. - probs (list): List of probabilities corresponding to the tokens. - parameter_names (dict): Dictionary of parameter names for each function. - - Returns: - tuple: A tuple containing two dictionaries - parameter values and their corresponding probabilities. - """ - parameter_values = {} - probs_values = {} - i = 0 - - while i < len(tokens): - # Try to form parameter names by combining tokens - combined_token = "" - start = i - found_param = False - - # Incrementally combine tokens to find a full match with any parameter name - while i < len(tokens): - if combined_token: - combined_token += tokens[ - i - ] # Append next token to the current combination - else: - combined_token = tokens[i] # Start a new combination - - # Check if the combined token matches any parameter name - for func, params in parameter_names.items(): - if combined_token in params: - # Collect values associated with this parameter - values = [] - prob_values = [] - i += 1 # Move past the parameter name - - # Collect tokens as values until the next parameter or end marker - while ( - i < len(tokens) - and tokens[i] not in params - and tokens[i] != "" - ): - values.append(tokens[i]) - prob_values.append(probs[i]) - i += 1 - - # Store the parameter values and probabilities - parameter_values[combined_token] = values - probs_values[combined_token] = prob_values - - found_param = True - break # Stop combining further once a parameter is matched - - if found_param: - break # Exit the outer loop if parameter was matched - i += 1 # Move to the next token if no match was found yet - - # Reset to the next token if no parameter match was found - if not found_param: - i = start + 1 - - return parameter_values, probs_values - - -def calculate_stats( - data: Dict[str, Any], function_description: Dict[str, Any] -) -> Dict[str, Any]: - """ - Calculates statistical metrics for the given data. - - Args: - data (dict): Dictionary containing parameter values and their corresponding probabilities. - function_description (dict): Description of the function containing parameter properties. - - Returns: - dict: Dictionary containing statistical metrics for each parameter. - """ - stats = {} - try: - for key, values in data.items(): - if len(data[key]) >= 1: - first = values[0] - max_value = max(values) - min_value = min(values) - avg_value = sum(values) / len(values) - has_format = check_parameter_property( - function_description, key, "format" - ) - has_default = check_parameter_property( - function_description, key, "default" - ) - stats[key] = { - "first": first, - "max": max_value, - "min": min_value, - "avg": avg_value, - "has_format": has_format, - "has_default": has_default, - } - except Exception as e: - print(data) - return stats - - -def check_parameter_property( - api_description: Dict[str, Any], parameter_name: str, property_name: str -) -> bool: - """ - Check if a parameter in an API description has a specific property. - - Args: - api_description (dict): The API description in JSON format. - parameter_name (str): The name of the parameter to check. - property_name (str): The property to look for (e.g., 'format', 'default'). - - Returns: - bool: True if the parameter has the specified property, False otherwise. - """ - parameters = api_description.get("parameters", {}).get("properties", {}) - parameter_info = parameters.get(parameter_name, {}) - - return property_name in parameter_info - - -def calculate_entropy(log_probs): +def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]: """ Calculate the entropy and variance of entropy (varentropy) from log probabilities. @@ -178,185 +45,196 @@ def calculate_entropy(log_probs): token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2, dim=-1, ) - return log_probs.tolist(), entropy.item(), varentropy.item() + return entropy.item(), varentropy.item() -def hallucination_detect( - token: str, - log_probs: List[float], - current_state: Dict[str, Any], - entropy_thd: float = 0.7, - varentropy_thd: float = 4.0, -) -> bool: +def check_parameter_property(api_description, parameter_name, property_name): """ - Detects hallucinations in the token sequence based on entropy and varentropy thresholds. + Check if a parameter in an API description has a specific property. Args: - token (str): The current token. - log_probs (list): List of log probabilities for the current token. - current_state (dict): The current state of the detection process. - entropy_thd (float): Entropy threshold for detecting hallucinations. - varentropy_thd (float): Variance of entropy threshold for detecting hallucinations. + api_description (dict): The API description in JSON format. + parameter_name (str): The name of the parameter to check. + property_name (str): The property to look for (e.g., 'format', 'default'). Returns: - bool: True if a hallucination is detected, False otherwise. + bool: True if the parameter has the specified property, False otherwise. + """ + parameters = api_description.get("properties", {}) + parameter_info = parameters.get(parameter_name, {}) + + return property_name in parameter_info + + +class HallucinationStateHandler: + """ + A class to handle the state of hallucination detection in token processing. + + Attributes: + tokens (list): List of tokens processed. + logprobs (list): List of log probabilities for each token. + state (str): Current state of the handler. + mask (list): List of masks indicating the type of each token. + parameter_name_done (bool): Flag indicating if parameter name extraction is done. + hallucination (bool): Flag indicating if a hallucination is detected. + hallucination_message (str): Message describing the hallucination. + parameter_name (list): List of extracted parameter names. + function_description (dict): Description of functions and their parameters. + token_probs_map (list): List mapping tokens to their entropy and variance of entropy. + current_token (str): The current token being processed. """ - if token: - # check if there is content in token - current_state["tokens"].append(token) - current_state["content"] += token - current_state["logprobs"].append(log_probs) - # keep track of entropy and varentropy - _, entropy, varentropy = calculate_entropy(log_probs) - current_state["entropy"].append(entropy) - current_state["varentropy"].append(varentropy) - # first check if tool call token is certain - if token == "": - if entropy > entropy_thd or varentropy > varentropy_thd: - current_state["hallucination"] = True - current_state[ - "hallucination_message" - ] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}" - return True - elif token == "": - current_state["state"] = "tool_call_end" - # try to extract tool call, else raise error - try: - current_state[ - "tool_call" - ] = const.arch_function_hanlder.extract_tool_calls( - current_state["content"] - )[ - 0 - ] - current_state["tool_call_process"] = True - except: - current_state["tool_call_process"] = False - print(f"cant process tool") - return True - # check if function name is valid - if ( - current_state["tool_call"]["function"]["name"] - not in current_state["parameter_names"].keys() - ): - current_state["hallucination"] = True - current_state[ - "hallucination_message" - ] = f"function name {current_state['tool_call']['name']} not found" - return True + def __init__(self): + """ + Initializes the HallucinationStateHandler with default values. + """ + self.tokens = [] + self.logprobs = [] + self.state = None + self.mask = [] + self.parameter_name_done = False + self.hallucination = False + self.hallucination_message = "" + self.parameter_name = [] - # check if parameter names are from the given function tools - current_parameter_names = current_state["tool_call"]["function"][ - "arguments" - ].keys() - given_parameter_names = current_state["parameter_names"][ - current_state["tool_call"]["function"]["name"] - ] - if not set(current_parameter_names).issubset(given_parameter_names): - missing_keys = set(current_parameter_names) - set(given_parameter_names) + self.token_probs_map = [] + self.current_token = None - current_state["hallucination"] = True - current_state[ - "hallucination_message" - ] = f"parameter names {missing_keys} not found" - return True + def process_function(self, apis): + self.apis = apis + if self.apis is None: + raise ValueError("API descriptions not set.") + parameter_names = {} + for func in self.apis: + func_name = func["name"] + parameters = func["parameters"]["properties"] + parameter_names[func_name] = list(parameters.keys()) + self.function_description = parameter_names + self.function_properties = {x["name"]: x["parameters"] for x in self.apis} - # filtered special tokens that are not needed in the hallucination check for parameter values - ( - current_state["filtered_tokens"], - current_state["filtered_entropy"], - ) = filter_tokens_and_probs( - current_state["tokens"], current_state["entropy"] - ) - ( - current_state["filtered_tokens"], - current_state["filtered_varentropy"], - ) = filter_tokens_and_probs( - current_state["tokens"], current_state["varentropy"] - ) - parameter_values, entropy_values = get_all_parameter_values( - current_state["filtered_tokens"], - current_state["filtered_entropy"], - current_state["parameter_names"], - ) - parameter_values, varentropy_values = get_all_parameter_values( - current_state["filtered_tokens"], - current_state["filtered_varentropy"], - current_state["parameter_names"], - ) + def process_token(self): + """ + Processes the current token and updates the state and mask accordingly. + Detects hallucinations based on the token type and log probabilities. + """ + content = "".join(self.tokens).replace(" ", "") + if self.current_token == "": + self.mask.append("t") + self.check_logprob() - current_state["parameter_values"] = parameter_values - current_state["parameter_values_entropy"] = entropy_values - current_state["parameter_values_varentropy"] = varentropy_values - # calculate the max, first, avg of sub tokens for parameter value - current_state["parameter_value_entropy_stat"] = calculate_stats( - current_state["parameter_values_entropy"], - current_state["function_description"][0], - ) - current_state["parameter_value_varentropy_stat"] = calculate_stats( - current_state["parameter_values_varentropy"], - current_state["function_description"][0], - ) - # get map for debugging - current_state["token_entropy_map"] = { - x: y for x, y in zip(current_state["tokens"], current_state["entropy"]) - } - current_state["token_varentropy_map"] = { - x: y - for x, y in zip(current_state["tokens"], current_state["varentropy"]) - } + # Function name extraction logic + if self.state == "function_name": + if self.current_token not in const.FUNC_NAME_END_TOKEN: + self.mask.append("f") + else: + self.state = None + self.check_function_name() - # checking hallucination for parameter value - current_state["parameter_value_check"] = { - x: {"hallucination": False, "message": ""} - for x in current_state["parameter_values"].keys() - } - for key in current_state["parameter_value_check"].keys(): - # if parameter is given a format, check the first token - if current_state["parameter_value_entropy_stat"][key]["has_format"]: - if ( - current_state["parameter_value_entropy_stat"][key]["first"] - > entropy_thd - or current_state["parameter_value_varentropy_stat"][key][ - "first" - ] - > varentropy_thd - ): - current_state["parameter_value_check"][key][ - "hallucination" - ] = True - current_state["hallucination"] = True - current_state["parameter_value_check"][key][ - "message" - ] = f"parameter {key} with formatting doesn't pass threshold" - # if parameter gis given a default value, we can always use default - elif current_state["parameter_value_entropy_stat"][key]["has_default"]: - current_state["parameter_value_check"][key]["hallucination"] = False - current_state["parameter_value_check"][key][ - "message" - ] = f"parameter {key} with default" - # check if max sub token is > thresholds - else: - if ( - current_state["parameter_value_entropy_stat"][key]["max"] - > entropy_thd - or current_state["parameter_value_varentropy_stat"][key]["max"] - > varentropy_thd - ): - current_state["parameter_value_check"][key][ - "hallucination" - ] = True - current_state["parameter_value_check"][key][ - "message" - ] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold" - current_state["hallucination"] = True - if current_state["hallucination"] == True: - current_state["hallucination_message"] = "\n".join( - [ - current_state["parameter_value_check"][key]["message"] - for key in current_state["parameter_value_check"].keys() - ] - ) - return True - return False + if content.endswith(const.FUNC_NAME_START_PATTERN): + print("function name entered") + self.state = "function_name" + + # Parameter name extraction logic + if self.state == "parameter_name" and not content.endswith( + const.PARAMETER_NAME_END_TOKENS + ): + self.mask.append("p") + elif self.state == "parameter_name" and content.endswith( + const.PARAMETER_NAME_END_TOKENS + ): + self.state = None + self.check_parameter_name() + self.parameter_name_done = True + elif self.parameter_name_done and content.endswith( + const.PARAMETER_NAME_START_PATTERN + ): + self.state = "parameter_name" + + if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN): + self.state = "parameter_name" + + # Parameter value extraction logic + if self.state == "parameter_value" and not content.endswith( + const.PARAMETER_VALUE_END_TOKEN + ): + if self.current_token.strip() not in ['"', ""]: + self.mask.append("v") + if ( + len(self.mask) > 1 + and self.mask[-2] == "v" + and not check_parameter_property( + self.function_properties[self.function_name], + self.parameter_name[-1], + "default", + ) + ): + self.check_logprob() + else: + self.mask.append("e") + + elif self.state == "parameter_value" and content.endswith( + const.PARAMETER_VALUE_END_TOKEN + ): + self.state = None + elif self.parameter_name_done and content.endswith( + const.PARAMETER_VALUE_START_PATTERN + ): + self.state = "parameter_value" + + # Maintain consistency between stack and mask + if len(self.mask) != len(self.tokens): + self.mask.append("e") + + def check_logprob(self): + """ + Checks the log probability of the current token and updates the token probability map. + Detects hallucinations based on entropy and variance of entropy. + """ + probs = self.logprobs[-1] + entropy, varentropy = calculate_entropy(probs) + self.token_probs_map.append((self.tokens[-1], entropy, varentropy)) + + if check_threshold( + entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]] + ): + self.hallucination = True + self.hallucination_message = f"Token '{self.current_token}' is uncertain." + + def count_consecutive_token(self, token="v") -> int: + """ + Counts the number of consecutive occurrences of a given token in the mask. + + Args: + token (str): The token to count in the mask. + + Returns: + int: The number of consecutive occurrences of the token. + """ + return ( + len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask)))) + if self.mask and self.mask[-1] == token + else 0 + ) + + def check_function_name(self): + """ + Checks the extracted function name against the function descriptions. + Detects hallucinations if the function name is not found. + """ + f_len = self.count_consecutive_token("f") + self.function_name = "".join(self.tokens[:-1][-f_len:]) + if self.function_name not in self.function_description.keys(): + self.hallucination = True + self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions." + + def check_parameter_name(self): + """ + Checks the extracted parameter name against the function descriptions. + Detects hallucinations if the parameter name is not found. + """ + p_len = self.count_consecutive_token("p") + parameter_name = "".join(self.tokens[:-1][-p_len:]) + self.parameter_name.append(parameter_name) + if parameter_name not in self.function_description[self.function_name]: + self.hallucination = True + self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions." diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py index 21e8ad60..25206558 100644 --- a/model_server/app/tests/test_hallucination.py +++ b/model_server/app/tests/test_hallucination.py @@ -1,5 +1,5 @@ import json -from app.function_calling.hallucination_handler import hallucination_detect +from app.function_calling.hallucination_handler import HallucinationStateHandler import pytest import os @@ -44,27 +44,15 @@ function_description = get_weather_api["function"] if type(function_description) != list: function_description = [get_weather_api["function"]] -parameter_names = {} -for func in function_description: - func_name = func["name"] - parameters = func["parameters"]["properties"] - parameter_names[func_name] = list(parameters.keys()) - @pytest.mark.parametrize("case", test_cases) def test_hallucination(case): - current_state = { - "state": "start", - "tool_call": "", - "entropy": [], - "varentropy": [], - "logprobs": [], - "tokens": [], - "content": "", - "hallucination": False, - "parameter_names": parameter_names, - "function_description": function_description, - } - for token_content, logprobs in zip(case["tokens"], case["logprobs"]): - result = hallucination_detect(token_content, logprobs, current_state, 0.7, 4) - assert result == case["expect"] + state = HallucinationStateHandler() + state.process_function(function_description) + for token, logprob in zip(case["tokens"], case["logprobs"]): + if token != "": + state.current_token = token + state.tokens.append(token) + state.logprobs.append(logprob) + state.process_token() + assert state.hallucination == case["expect"]