2024-11-17 23:32:57 -08:00
import torch
import numpy as np
2024-11-18 00:53:49 -08:00
import math
import app . commons . constants as const
import random
from typing import List , Dict , Any , Tuple
import json
2024-11-17 23:32:57 -08:00
2024-11-18 00:53:49 -08:00
def filter_tokens_and_probs(
    tokens: List[str], probs: List[float]
) -> Tuple[List[str], List[float]]:
    """
    Drop structural (special) tokens and the probabilities paired with them.

    Args:
        tokens (list): List of tokens.
        probs (list): List of probabilities corresponding to the tokens.

    Returns:
        tuple: Two lists - the tokens that are not special tokens, and the
        probabilities whose paired token is not a special token.
    """
    # Exact-match list of JSON/tool-call scaffolding tokens (no regex involved).
    # NOTE(review): confirm the exact whitespace inside each literal against
    # the tokenizer output; two entries currently read identically.
    special_tokens = ["\\n", '{"', '":', '"', '",', '{"', '"}}\\n', "", '"}}\n']

    def is_regular(token: str) -> bool:
        return token not in special_tokens

    # Tokens and probabilities are filtered in two separate passes: the token
    # pass sees every token, the probability pass only the zipped pairs.
    filtered_tokens = list(filter(is_regular, tokens))
    filtered_probs = [pair[1] for pair in zip(tokens, probs) if is_regular(pair[0])]
    return filtered_tokens, filtered_probs
2024-11-18 00:53:49 -08:00
def get_all_parameter_values(
    tokens: List[str], probs: List[float], parameter_names: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Locate parameter names in the token stream and gather the tokens that follow them.

    A parameter name may be spread over several consecutive tokens, so starting
    at each position the tokens are concatenated one by one until the running
    string equals a known parameter name. The tokens after the name are taken as
    its value until a token that is itself a parameter name of the same
    function, the "</tool_call>" token, or the end of the stream is reached.

    Args:
        tokens (list): List of tokens.
        probs (list): List of probabilities corresponding to the tokens.
        parameter_names (dict): Dictionary of parameter names for each function.

    Returns:
        tuple: Two dictionaries keyed by parameter name - the value tokens and
        the probabilities at the same positions.
    """
    values_by_param = {}
    probs_by_param = {}
    total = len(tokens)
    cursor = 0

    while cursor < total:
        candidate = ""
        matched_names = None

        # Grow the candidate name token by token until some function declares it.
        for end in range(cursor, total):
            candidate += tokens[end]
            matched_names = next(
                (names for names in parameter_names.values() if candidate in names),
                None,
            )
            if matched_names is not None:
                break

        if matched_names is None:
            # Nothing starting at this position forms a parameter name.
            cursor += 1
            continue

        # Everything after the name, up to the next boundary, is the value.
        stop = end + 1
        while (
            stop < total
            and tokens[stop] not in matched_names
            and tokens[stop] != "</tool_call>"
        ):
            stop += 1

        values_by_param[candidate] = list(tokens[end + 1 : stop])
        probs_by_param[candidate] = [probs[pos] for pos in range(end + 1, stop)]

        # Resume scanning at the boundary token itself.
        cursor = stop

    return values_by_param, probs_by_param
2024-11-18 00:53:49 -08:00
def calculate_stats(
    data: Dict[str, Any], function_description: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Calculates statistical metrics for the given data.

    Args:
        data (dict): Maps each parameter name to the list of per-token scores
            collected for its value.
        function_description (dict): Description of the function containing
            parameter properties; passed to ``check_parameter_property``.

    Returns:
        dict: For every parameter with at least one score, a dictionary with
        the keys "first", "max", "min", "avg", "has_format" and "has_default".
        Parameters with an empty score list are omitted. If an error occurs
        while processing, the entries computed so far are returned.
    """
    stats = {}
    try:
        for key, values in data.items():
            # Nothing to summarise for a parameter without value tokens.
            if len(values) < 1:
                continue
            stats[key] = {
                "first": values[0],
                "max": max(values),
                "min": min(values),
                "avg": sum(values) / len(values),
                "has_format": check_parameter_property(
                    function_description, key, "format"
                ),
                "has_default": check_parameter_property(
                    function_description, key, "default"
                ),
            }
    except Exception as exc:
        # Best effort: never raise to the caller, but report what went wrong
        # together with the offending input instead of hiding the error.
        print(f"calculate_stats failed with {exc!r}")
        print(data)
    return stats
2024-11-18 00:53:49 -08:00
def check_parameter_property(
    api_description: Dict[str, Any], parameter_name: str, property_name: str
) -> bool:
    """
    Tell whether a parameter of an API description declares a given property.

    The lookup follows api_description["parameters"]["properties"][parameter_name];
    any level that is absent is treated as an empty dictionary.

    Args:
        api_description (dict): The API description in JSON format.
        parameter_name (str): The name of the parameter to check.
        property_name (str): The property to look for (e.g., 'format', 'default').

    Returns:
        bool: True if the parameter has the specified property, False otherwise.
    """
    node = api_description
    # Walk down the nested description, defaulting each missing level to {}.
    for step in ("parameters", "properties", parameter_name):
        node = node.get(step, {})
    return property_name in node
2024-11-18 00:53:49 -08:00
def calculate_entropy(log_probs):
    """
    Calculate the entropy and variance of entropy (varentropy) from log probabilities.

    Both quantities are expressed in bits. With p_i = exp(log_probs[i]):

        entropy    = -sum_i p_i * log2(p_i)
        varentropy =  sum_i p_i * (log2(p_i) + entropy) ** 2

    The probabilities are used as given; they are not renormalised.

    Args:
        log_probs (list of float): A list of natural-log probabilities.

    Returns:
        tuple: A tuple containing:
            - log_probs (list of float): The input log probabilities as a list.
            - entropy (float): The calculated entropy.
            - varentropy (float): The calculated variance of entropy.
    """
    ln2 = math.log(2)  # divisor converting nats to bits
    log_probs = torch.tensor(log_probs)
    token_probs = torch.exp(log_probs)
    entropy = -torch.sum(log_probs * token_probs, dim=-1) / ln2
    # log2(p) + H is the negated deviation of a token's surprisal (-log2 p)
    # from the mean surprisal H. The square must cover the whole deviation;
    # squaring only H yields -H + n * H**2 rather than a variance.
    deviation = log_probs / ln2 + entropy.unsqueeze(-1)
    varentropy = torch.sum(token_probs * deviation**2, dim=-1)
    return log_probs.tolist(), entropy.item(), varentropy.item()
def _check_parameter_values(
    current_state: Dict[str, Any], entropy_thd: float, varentropy_thd: float
) -> None:
    """Score each extracted parameter value against the thresholds, updating current_state in place."""
    entropy_stats = current_state["parameter_value_entropy_stat"]
    varentropy_stats = current_state["parameter_value_varentropy_stat"]
    checks = {
        name: {"hallucination": False, "message": ""}
        for name in current_state["parameter_values"]
    }
    current_state["parameter_value_check"] = checks

    for key, verdict in checks.items():
        # Parameters whose value produced no tokens have no statistics
        # (calculate_stats omits them), so there is nothing to score.
        if key not in entropy_stats or key not in varentropy_stats:
            continue
        ent = entropy_stats[key]
        var = varentropy_stats[key]

        if ent["has_format"]:
            # if parameter is given a format, check the first token
            if ent["first"] > entropy_thd or var["first"] > varentropy_thd:
                verdict["hallucination"] = True
                verdict["message"] = (
                    f"parameter {key} with formatting doesn't pass threshold"
                )
                current_state["hallucination"] = True
        elif ent["has_default"]:
            # if parameter is given a default value, we can always use default
            verdict["hallucination"] = False
            verdict["message"] = f"parameter {key} with default"
        elif ent["max"] > entropy_thd or var["max"] > varentropy_thd:
            # otherwise the most uncertain sub token decides
            verdict["hallucination"] = True
            verdict["message"] = (
                f"parameter {key} with {ent['max']} and {var['max']} doesnt pass threshold"
            )
            current_state["hallucination"] = True


def hallucination_detect(
    token: str,
    log_probs: List[float],
    current_state: Dict[str, Any],
    entropy_thd: float = 0.7,
    varentropy_thd: float = 4.0,
) -> bool:
    """
    Detects hallucinations in the token sequence based on entropy and varentropy thresholds.

    Every non-empty token is appended to the running state together with its
    entropy and varentropy. Two tokens trigger checks:

    - "<tool_call>": the token's own entropy/varentropy must not exceed the
      thresholds.
    - "</tool_call>": the accumulated content is parsed into a tool call, the
      function name and argument names are validated against
      current_state["parameter_names"], and each parameter value is scored.

    Args:
        token (str): The current token. An empty token is ignored.
        log_probs (list): List of log probabilities for the current token.
        current_state (dict): The current state of the detection process. It is
            mutated in place. Keys read: "tokens", "content", "logprobs",
            "entropy", "varentropy", "parameter_names", "function_description"
            and "hallucination". Keys written include "state", "tool_call",
            "tool_call_process", "hallucination", "hallucination_message" and
            the "parameter_value*" / "filtered_*" / "token_*_map" entries.
        entropy_thd (float): Entropy threshold for detecting hallucinations.
        varentropy_thd (float): Variance of entropy threshold for detecting hallucinations.

    Returns:
        bool: True if a hallucination is detected or the tool call cannot be
        parsed, False otherwise.
    """
    if not token:
        return False

    current_state["tokens"].append(token)
    current_state["content"] += token
    current_state["logprobs"].append(log_probs)

    # keep track of entropy and varentropy
    _, entropy, varentropy = calculate_entropy(log_probs)
    current_state["entropy"].append(entropy)
    current_state["varentropy"].append(varentropy)

    # first check if tool call token is certain
    if token == "<tool_call>":
        if entropy > entropy_thd or varentropy > varentropy_thd:
            current_state["hallucination"] = True
            current_state["hallucination_message"] = (
                f"{token} with entropy {entropy}, varentropy {varentropy} "
                f"doesn't pass the threshold {entropy_thd} | {varentropy_thd}"
            )
            return True
        return False

    if token != "</tool_call>":
        return False

    current_state["state"] = "tool_call_end"

    # try to extract the tool call; report failure to the caller via True
    try:
        current_state["tool_call"] = const.arch_function_hanlder.extract_tool_calls(
            current_state["content"]
        )[0]
    except Exception:
        # Exception (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        current_state["tool_call_process"] = False
        print("cant process tool")
        return True
    current_state["tool_call_process"] = True

    # check if function name is valid
    function_call = current_state["tool_call"]["function"]
    function_name = function_call["name"]
    known_parameters = current_state["parameter_names"]
    if function_name not in known_parameters:
        current_state["hallucination"] = True
        # Use the same path as the check above; the tool call keeps its name
        # under "function".
        current_state["hallucination_message"] = (
            f"function name {function_name} not found"
        )
        return True

    # check if parameter names are from the given function tools
    current_parameter_names = set(function_call["arguments"].keys())
    given_parameter_names = known_parameters[function_name]
    if not current_parameter_names.issubset(given_parameter_names):
        missing_keys = current_parameter_names - set(given_parameter_names)
        current_state["hallucination"] = True
        current_state["hallucination_message"] = (
            f"parameter names {missing_keys} not found"
        )
        return True

    # filter special tokens that are not needed in the hallucination check for parameter values
    (
        current_state["filtered_tokens"],
        current_state["filtered_entropy"],
    ) = filter_tokens_and_probs(current_state["tokens"], current_state["entropy"])
    (
        current_state["filtered_tokens"],
        current_state["filtered_varentropy"],
    ) = filter_tokens_and_probs(current_state["tokens"], current_state["varentropy"])

    parameter_values, entropy_values = get_all_parameter_values(
        current_state["filtered_tokens"],
        current_state["filtered_entropy"],
        current_state["parameter_names"],
    )
    parameter_values, varentropy_values = get_all_parameter_values(
        current_state["filtered_tokens"],
        current_state["filtered_varentropy"],
        current_state["parameter_names"],
    )
    current_state["parameter_values"] = parameter_values
    current_state["parameter_values_entropy"] = entropy_values
    current_state["parameter_values_varentropy"] = varentropy_values

    # calculate the max, first, avg of sub tokens for parameter value
    # NOTE(review): only the first function description is consulted, whichever
    # function was called - confirm this is intended when several tools exist.
    current_state["parameter_value_entropy_stat"] = calculate_stats(
        current_state["parameter_values_entropy"],
        current_state["function_description"][0],
    )
    current_state["parameter_value_varentropy_stat"] = calculate_stats(
        current_state["parameter_values_varentropy"],
        current_state["function_description"][0],
    )

    # get map for debugging (a repeated token keeps its last score)
    current_state["token_entropy_map"] = dict(
        zip(current_state["tokens"], current_state["entropy"])
    )
    current_state["token_varentropy_map"] = dict(
        zip(current_state["tokens"], current_state["varentropy"])
    )

    # checking hallucination for parameter value
    _check_parameter_values(current_state, entropy_thd, varentropy_thd)

    if current_state["hallucination"]:
        current_state["hallucination_message"] = "\n".join(
            check["message"]
            for check in current_state["parameter_value_check"].values()
        )
        return True
    return False