This commit is contained in:
cotran 2024-11-18 00:53:49 -08:00
parent d776ed0117
commit 665dbc2d4e
3 changed files with 203 additions and 61 deletions

View file

@ -1,8 +1,15 @@
import torch import torch
import numpy as np import numpy as np
from typing import List, Dict import math
import app.commons.constants as const
import random
from typing import List, Dict, Any, Tuple
import json
def filter_tokens_and_probs(tokens: List[str], probs: List[float]) -> Tuple[List[], List[float]]:
def filter_tokens_and_probs(
tokens: List[str], probs: List[float]
) -> Tuple[List[str], List[float]]:
""" """
Filters out special tokens from the list of tokens and their corresponding probabilities. Filters out special tokens from the list of tokens and their corresponding probabilities.
@ -14,17 +21,17 @@ def filter_tokens_and_probs(tokens: List[str], probs: List[float]) -> Tuple[List
tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities. tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities.
""" """
# Use regex to identify tokens without special characters # Use regex to identify tokens without special characters
special_tokens = ['\\n', '{"', '":', ' "', '",', ' {"', '"}}\\n', ' ', '"}}\n'] special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n']
filtered_tokens = [ filtered_tokens = [token for token in tokens if token not in special_tokens]
token for token in tokens
if token not in special_tokens
]
filtered_probs = [ filtered_probs = [
prob for token, prob in zip(tokens, probs) prob for token, prob in zip(tokens, probs) if token not in special_tokens
if token not in special_tokens
] ]
return filtered_tokens, filtered_probs return filtered_tokens, filtered_probs
def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_names: Dict[str, List[str]]) -> Tuple[Dict[str, List[str]], Dict[str, List[float]]]:
def get_all_parameter_values(
tokens: List[str], probs: List[float], parameter_names: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
""" """
Extracts parameter values and their corresponding probabilities from the tokens. Extracts parameter values and their corresponding probabilities from the tokens.
@ -49,7 +56,9 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na
# Incrementally combine tokens to find a full match with any parameter name # Incrementally combine tokens to find a full match with any parameter name
while i < len(tokens): while i < len(tokens):
if combined_token: if combined_token:
combined_token += tokens[i] # Append next token to the current combination combined_token += tokens[
i
] # Append next token to the current combination
else: else:
combined_token = tokens[i] # Start a new combination combined_token = tokens[i] # Start a new combination
@ -62,7 +71,11 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na
i += 1 # Move past the parameter name i += 1 # Move past the parameter name
# Collect tokens as values until the next parameter or end marker # Collect tokens as values until the next parameter or end marker
while i < len(tokens) and tokens[i] not in params and tokens[i] != '</tool_call>': while (
i < len(tokens)
and tokens[i] not in params
and tokens[i] != "</tool_call>"
):
values.append(tokens[i]) values.append(tokens[i])
prob_values.append(probs[i]) prob_values.append(probs[i])
i += 1 i += 1
@ -83,7 +96,11 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na
i = start + 1 i = start + 1
return parameter_values, probs_values return parameter_values, probs_values
def calculate_stats(data: Dict, function_description: Dict) -> Dict:
def calculate_stats(
data: Dict[str, Any], function_description: Dict[str, Any]
) -> Dict[str, Any]:
""" """
Calculates statistical metrics for the given data. Calculates statistical metrics for the given data.
@ -97,19 +114,33 @@ def calculate_stats(data: Dict, function_description: Dict) -> Dict:
stats = {} stats = {}
try: try:
for key, values in data.items(): for key, values in data.items():
if len(data[key])>=1: if len(data[key]) >= 1:
first = values[0] first = values[0]
max_value = max(values) max_value = max(values)
min_value = min(values) min_value = min(values)
avg_value = sum(values) / len(values) avg_value = sum(values) / len(values)
has_format = check_parameter_property(function_description, key, "format") has_format = check_parameter_property(
has_default = check_parameter_property(function_description, key , "default") function_description, key, "format"
stats[key] = {'first':first, 'max': max_value, 'min': min_value, 'avg': avg_value, 'has_format': has_format, 'has_default': has_default} )
has_default = check_parameter_property(
function_description, key, "default"
)
stats[key] = {
"first": first,
"max": max_value,
"min": min_value,
"avg": avg_value,
"has_format": has_format,
"has_default": has_default,
}
except Exception as e: except Exception as e:
print(data) print(data)
return stats return stats
def check_parameter_property(api_description: Dict, parameter_name: str, property_name: str)-> bool:
def check_parameter_property(
api_description: Dict[str, Any], parameter_name: str, property_name: str
) -> bool:
""" """
Check if a parameter in an API description has a specific property. Check if a parameter in an API description has a specific property.
@ -127,8 +158,36 @@ def check_parameter_property(api_description: Dict, parameter_name: str, propert
return property_name in parameter_info return property_name in parameter_info
def calculate_entropy(log_probs: List[float]) -> Tuple[List[float], float, float]:
    """
    Calculate the entropy and variance of entropy (varentropy) from log probabilities.

    The inputs are exponentiated with ``torch.exp`` to recover probabilities, so
    they are treated as natural-log probabilities. Both results are expressed in
    bits (base 2).

    Args:
        log_probs (list of float): A list of log probabilities.

    Returns:
        tuple: A tuple containing:
            - log_probs (list of float): The input log probabilities as a list
              (round-tripped through a torch tensor).
            - entropy (float): ``-sum(p * log2(p))``.
            - varentropy (float): ``sum(p * (log2(p) + entropy) ** 2)``, i.e. the
              probability-weighted squared deviation of each token's surprisal
              from the entropy.
    """
    ln_2 = math.log(2)
    log_probs = torch.tensor(log_probs)
    token_probs = torch.exp(log_probs)
    # Convert once to base-2 so entropy and varentropy share the same unit.
    log2_probs = log_probs / ln_2
    entropy = -torch.sum(token_probs * log2_probs, dim=-1)
    # The square must wrap the whole deviation (log2 p + H). Squaring only the
    # entropy term and adding it per token yields -H + n * H**2, which is not a
    # variance and grows with the number of candidates.
    # NOTE(review): thresholds tuned against the previous expression (e.g. the
    # default varentropy threshold) may need to be re-checked.
    varentropy = torch.sum(
        token_probs * (log2_probs + entropy.unsqueeze(-1)) ** 2,
        dim=-1,
    )
    return log_probs.tolist(), entropy.item(), varentropy.item()
def hallucination_detect(
token: str,
log_probs: List[float],
current_state: Dict[str, Any],
entropy_thd: float = 0.7,
varentropy_thd: float = 4.0,
) -> bool:
""" """
Detects hallucinations in the token sequence based on entropy and varentropy thresholds. Detects hallucinations in the token sequence based on entropy and varentropy thresholds.
@ -146,8 +205,8 @@ def hallucination_detect(token:str, log_probs:List[float], current_state: Dict,
if token: if token:
# check if there is content in token # check if there is content in token
current_state["tokens"].append(token) current_state["tokens"].append(token)
current_state['content'] += token current_state["content"] += token
current_state['logprobs'].append(log_probs) current_state["logprobs"].append(log_probs)
# keep track of entropy and varentropy # keep track of entropy and varentropy
_, entropy, varentropy = calculate_entropy(log_probs) _, entropy, varentropy = calculate_entropy(log_probs)
current_state["entropy"].append(entropy) current_state["entropy"].append(entropy)
@ -156,71 +215,148 @@ def hallucination_detect(token:str, log_probs:List[float], current_state: Dict,
if token == "<tool_call>": if token == "<tool_call>":
if entropy > entropy_thd or varentropy > varentropy_thd: if entropy > entropy_thd or varentropy > varentropy_thd:
current_state["hallucination"] = True current_state["hallucination"] = True
current_state["hallucination_message"] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}" current_state[
"hallucination_message"
] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}"
return True return True
elif token == "</tool_call>": elif token == "</tool_call>":
current_state["state"] = "tool_call_end" current_state["state"] = "tool_call_end"
# try to extract tool call, else raise error # try to extract tool call, else raise error
try: try:
current_state['tool_call'] = extract_tool_calls(current_state["content"])[0] current_state[
current_state['tool_call_process'] = True "tool_call"
] = const.arch_function_hanlder.extract_tool_calls(
current_state["content"]
)[
0
]
current_state["tool_call_process"] = True
except: except:
current_state['tool_call_process'] = False current_state["tool_call_process"] = False
print(f"cant process tool") print(f"cant process tool")
return True return True
# check if function name is valid # check if function name is valid
if current_state['tool_call']['function']['name'] not in current_state['parameter_names'].keys(): if (
current_state["tool_call"]["function"]["name"]
not in current_state["parameter_names"].keys()
):
current_state["hallucination"] = True current_state["hallucination"] = True
current_state["hallucination_message"] = f"function name {current_state['tool_call']['name']} not found" current_state[
"hallucination_message"
] = f"function name {current_state['tool_call']['name']} not found"
return True return True
# check if parameter names are from the given function tools # check if parameter names are from the given function tools
current_parameter_names = current_state['tool_call']['function']['arguments'].keys() current_parameter_names = current_state["tool_call"]["function"][
given_parameter_names = current_state['parameter_names'][current_state['tool_call']['function']['name']] "arguments"
].keys()
given_parameter_names = current_state["parameter_names"][
current_state["tool_call"]["function"]["name"]
]
if not set(current_parameter_names).issubset(given_parameter_names): if not set(current_parameter_names).issubset(given_parameter_names):
missing_keys = set(current_parameter_names) - set(given_parameter_names) missing_keys = set(current_parameter_names) - set(given_parameter_names)
current_state["hallucination"] = True current_state["hallucination"] = True
current_state["hallucination_message"] = f"parameter names {missing_keys} not found" current_state[
"hallucination_message"
] = f"parameter names {missing_keys} not found"
return True return True
# filtered special tokens that are not needed in the hallucination check for parameter values # filtered special tokens that are not needed in the hallucination check for parameter values
current_state["filtered_tokens"], current_state["filtered_entropy"] = filter_tokens_and_probs(current_state["tokens"], current_state["entropy"]) (
current_state["filtered_tokens"], current_state["filtered_varentropy"] = filter_tokens_and_probs(current_state["tokens"], current_state["varentropy"]) current_state["filtered_tokens"],
parameter_values, entropy_values = get_all_parameter_values(current_state["filtered_tokens"], current_state["filtered_entropy"], current_state['parameter_names']) current_state["filtered_entropy"],
parameter_values, varentropy_values = get_all_parameter_values(current_state["filtered_tokens"], current_state["filtered_varentropy"], current_state['parameter_names']) ) = filter_tokens_and_probs(
current_state["tokens"], current_state["entropy"]
)
(
current_state["filtered_tokens"],
current_state["filtered_varentropy"],
) = filter_tokens_and_probs(
current_state["tokens"], current_state["varentropy"]
)
parameter_values, entropy_values = get_all_parameter_values(
current_state["filtered_tokens"],
current_state["filtered_entropy"],
current_state["parameter_names"],
)
parameter_values, varentropy_values = get_all_parameter_values(
current_state["filtered_tokens"],
current_state["filtered_varentropy"],
current_state["parameter_names"],
)
current_state['parameter_values'] = parameter_values current_state["parameter_values"] = parameter_values
current_state['parameter_values_entropy'] = entropy_values current_state["parameter_values_entropy"] = entropy_values
current_state['parameter_values_varentropy'] = varentropy_values current_state["parameter_values_varentropy"] = varentropy_values
# calculate the max, first, avg of sub tokens for parameter value # calculate the max, first, avg of sub tokens for parameter value
current_state['parameter_value_entropy_stat'] = calculate_stats(current_state['parameter_values_entropy'], current_state['function_description'][0]) current_state["parameter_value_entropy_stat"] = calculate_stats(
current_state['parameter_value_varentropy_stat'] = calculate_stats(current_state['parameter_values_varentropy'], current_state['function_description'][0]) current_state["parameter_values_entropy"],
current_state["function_description"][0],
)
current_state["parameter_value_varentropy_stat"] = calculate_stats(
current_state["parameter_values_varentropy"],
current_state["function_description"][0],
)
# get map for debugging # get map for debugging
current_state['token_entropy_map'] = {x : y for x,y in zip(current_state['tokens'], current_state['entropy'])} current_state["token_entropy_map"] = {
current_state['token_varentropy_map'] = {x : y for x,y in zip(current_state['tokens'], current_state['varentropy'])} x: y for x, y in zip(current_state["tokens"], current_state["entropy"])
}
current_state["token_varentropy_map"] = {
x: y
for x, y in zip(current_state["tokens"], current_state["varentropy"])
}
# checking hallucination for parameter value # checking hallucination for parameter value
current_state['parameter_value_check'] = {x : {'hallucination': False, 'message': ''} for x in current_state['parameter_values'].keys()} current_state["parameter_value_check"] = {
for key in current_state['parameter_value_check'].keys(): x: {"hallucination": False, "message": ""}
for x in current_state["parameter_values"].keys()
}
for key in current_state["parameter_value_check"].keys():
# if parameter is given a format, check the first token # if parameter is given a format, check the first token
if current_state['parameter_value_entropy_stat'][key]['has_format']: if current_state["parameter_value_entropy_stat"][key]["has_format"]:
if current_state['parameter_value_entropy_stat'][key]['first'] > entropy_thd or current_state['parameter_value_varentropy_stat'][key]['first'] > varentropy_thd: if (
current_state['parameter_value_check'][key]['hallucination'] = True current_state["parameter_value_entropy_stat"][key]["first"]
> entropy_thd
or current_state["parameter_value_varentropy_stat"][key][
"first"
]
> varentropy_thd
):
current_state["parameter_value_check"][key][
"hallucination"
] = True
current_state["hallucination"] = True current_state["hallucination"] = True
current_state['parameter_value_check'][key]['message'] = f"parameter {key} with formatting doesn't pass threshold" current_state["parameter_value_check"][key][
"message"
] = f"parameter {key} with formatting doesn't pass threshold"
# if parameter gis given a default value, we can always use default # if parameter gis given a default value, we can always use default
elif current_state['parameter_value_entropy_stat'][key]['has_default']: elif current_state["parameter_value_entropy_stat"][key]["has_default"]:
current_state['parameter_value_check'][key]['hallucination'] = False current_state["parameter_value_check"][key]["hallucination"] = False
current_state['parameter_value_check'][key]['message'] = f"parameter {key} with default" current_state["parameter_value_check"][key][
"message"
] = f"parameter {key} with default"
# check if max sub token is > thresholds # check if max sub token is > thresholds
else: else:
if current_state['parameter_value_entropy_stat'][key]['max'] > entropy_thd or current_state['parameter_value_varentropy_stat'][key]['max'] > varentropy_thd: if (
current_state['parameter_value_check'][key]['hallucination'] = True current_state["parameter_value_entropy_stat"][key]["max"]
current_state['parameter_value_check'][key]['message'] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold" > entropy_thd
or current_state["parameter_value_varentropy_stat"][key]["max"]
> varentropy_thd
):
current_state["parameter_value_check"][key][
"hallucination"
] = True
current_state["parameter_value_check"][key][
"message"
] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold"
current_state["hallucination"] = True current_state["hallucination"] = True
if current_state["hallucination"] == True: if current_state["hallucination"] == True:
current_state["hallucination_message"] = "\n".join([current_state['parameter_value_check'][key]['message'] for key in current_state['parameter_value_check'].keys()]) current_state["hallucination_message"] = "\n".join(
[
current_state["parameter_value_check"][key]["message"]
for key in current_state["parameter_value_check"].keys()
]
)
return True return True
return False return False

View file

@ -134,4 +134,4 @@ class ArchFunctionHandler:
fixed_str += opening_bracket[unmatched_opening] fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON # Attempt to parse the corrected string to ensure its valid JSON
return fixed_str return fixed_str.replace("'", '"')

View file

@ -1,7 +1,16 @@
import json import json
from app.function_calling.hallucination_handler import hallucination_detect from app.function_calling.hallucination_handler import hallucination_detect
import pytest import pytest
import os
# Get the directory of the current file
current_dir = os.path.dirname(__file__)
# Construct the full path to the JSON file
json_file_path = os.path.join(current_dir, "test_cases.json")
with open(json_file_path) as f:
test_cases = json.load(f)
get_weather_api = { get_weather_api = {
"type": "function", "type": "function",
@ -41,9 +50,6 @@ for func in function_description:
parameters = func["parameters"]["properties"] parameters = func["parameters"]["properties"]
parameter_names[func_name] = list(parameters.keys()) parameter_names[func_name] = list(parameters.keys())
with open("test_cases.json") as f:
test_cases = json.load(f)
@pytest.mark.parametrize("case", test_cases) @pytest.mark.parametrize("case", test_cases)
def test_hallucination(case): def test_hallucination(case):