plano/model_server/app/function_calling/hallucination_handler.py

241 lines
8.9 KiB
Python
Raw Normal View History

2024-11-22 11:11:26 -08:00
import json
import ast
import os
import json
2024-11-18 00:53:49 -08:00
import math
2024-11-22 11:11:26 -08:00
import torch
2024-11-18 00:53:49 -08:00
import random
2024-11-22 11:11:26 -08:00
from typing import Any, Dict, List, Tuple
import app.commons.constants as const
import itertools
2024-11-17 23:32:57 -08:00
2024-11-22 11:11:26 -08:00
def check_threshold(entropy, varentropy, thd):
    """
    Report whether a token's uncertainty exceeds the given limits.

    Args:
        entropy (float): Entropy of the token distribution.
        varentropy (float): Variance of entropy of the token distribution.
        thd (dict): Limits stored under the keys ``"entropy"`` and
            ``"varentropy"``.

    Returns:
        bool: True when ``entropy`` is strictly above ``thd["entropy"]`` or
        ``varentropy`` is strictly above ``thd["varentropy"]``, False
        otherwise. The varentropy limit is only consulted when the entropy
        limit is not exceeded.
    """
    over_entropy = entropy > thd["entropy"]
    if over_entropy:
        return over_entropy
    return varentropy > thd["varentropy"]
2024-11-17 23:32:57 -08:00
2024-11-22 11:11:26 -08:00
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
    """
    Calculate the entropy and the variance of entropy (varentropy), in bits,
    from a list of log probabilities.

    Entropy is ``H = -sum(p * log2 p)``. Varentropy is the variance of the
    surprisal ``-log2 p`` under the same probabilities:
    ``V = sum(p * (log2 p + H) ** 2)``.

    Args:
        log_probs (list of float): Natural-log probabilities of candidate
            tokens (presumably the top-k logprobs for a single position --
            confirm against the caller).

    Returns:
        tuple: ``(entropy, varentropy)`` as Python floats.
    """
    log_probs = torch.tensor(log_probs)
    token_probs = torch.exp(log_probs)
    # Convert natural-log probabilities to base 2 so both results are in bits.
    log2_probs = log_probs / math.log(2)
    entropy = -torch.sum(token_probs * log2_probs, dim=-1)
    # The square must wrap the whole deviation (log2 p + H). The previous
    # expression computed sum(p * log2 p + H**2), which is not a variance.
    # NOTE(review): thresholds in HALLUCINATION_THRESHOLD_DICT may have been
    # tuned against the old (incorrect) values -- re-check them.
    varentropy = torch.sum(
        token_probs * (log2_probs + entropy.unsqueeze(-1)) ** 2,
        dim=-1,
    )
    return entropy.item(), varentropy.item()
2024-11-18 00:53:49 -08:00
2024-11-22 11:11:26 -08:00
def check_parameter_property(api_description, parameter_name, property_name):
    """
    Tell whether a parameter's schema declares a given property.

    Args:
        api_description (dict): A function's parameter schema. Parameter
            definitions are read from its ``"properties"`` mapping.
        parameter_name (str): Parameter to look up.
        property_name (str): Key to look for in that parameter's definition
            (e.g. ``"default"``, ``"format"``).

    Returns:
        bool: True if the parameter exists and its definition contains
        ``property_name``; False if either the parameter or the property is
        missing.
    """
    param_schema = api_description.get("properties", {}).get(parameter_name, {})
    return property_name in param_schema
class HallucinationStateHandler:
    """
    Track a streamed tool-call generation token by token and flag likely
    hallucinations.

    Every processed token is tagged in ``mask`` with a one-letter type:
    ``"t"`` the ``<tool_call>`` token, ``"f"`` part of a function name,
    ``"p"`` part of a parameter name, ``"v"`` part of a parameter value and
    ``"e"`` anything else. Completed function and parameter names are checked
    against the registered API descriptions. The ``<tool_call>`` token and the
    first token of each run of value tokens (for parameters without a
    ``"default"``) are checked against entropy / varentropy thresholds.

    This class never appends to ``tokens`` or ``logprobs`` itself. They (and
    ``current_token``) must be filled in externally before each
    ``process_token`` call.

    Attributes:
        tokens (list): Tokens processed so far.
        logprobs (list): Per-token lists of log probabilities; only the last
            entry is read.
        state (str | None): Current extraction state: ``"function_name"``,
            ``"parameter_name"``, ``"parameter_value"`` or None.
        mask (list): One type letter per token (see above).
        parameter_name_done (bool): True once at least one parameter name has
            been extracted and checked.
        hallucination (bool): True once any check has failed.
        hallucination_message (str): Message from the most recent failed check.
        parameter_name (list): Extracted parameter names, in order.
        token_probs_map (list): ``(token, entropy, varentropy)`` tuples for
            every token whose log probabilities were checked.
        current_token (str | None): The token currently being processed.
        apis (list | None): Function descriptions passed to
            ``process_function``.
        function_description (dict): Function name -> list of its parameter
            names.
        function_properties (dict): Function name -> its ``"parameters"``
            schema.
        function_name (str | None): The most recently extracted function name.
    """

    def __init__(self):
        """
        Initialize an empty handler. Call ``process_function`` before
        processing tokens.
        """
        self.tokens = []
        self.logprobs = []
        self.state = None
        self.mask = []
        self.parameter_name_done = False
        self.hallucination = False
        self.hallucination_message = ""
        self.parameter_name = []
        self.token_probs_map = []
        self.current_token = None
        # Set by process_function / check_function_name. They are initialized
        # here so an out-of-order call fails a lookup (and is flagged) instead
        # of raising AttributeError.
        self.apis = None
        self.function_description = {}
        self.function_properties = {}
        self.function_name = None

    def process_function(self, apis):
        """
        Register the functions the model is allowed to call.

        Args:
            apis (list of dict): Function descriptions. Each has a ``"name"``
                and a ``"parameters"`` dict containing a ``"properties"``
                mapping of parameter name -> parameter schema.

        Raises:
            ValueError: If ``apis`` is None.
        """
        self.apis = apis
        if self.apis is None:
            raise ValueError("API descriptions not set.")
        # function name -> list of its parameter names
        self.function_description = {
            func["name"]: list(func["parameters"]["properties"].keys())
            for func in self.apis
        }
        # function name -> full "parameters" schema (used for "default" lookups)
        self.function_properties = {x["name"]: x["parameters"] for x in self.apis}

    def process_token(self):
        """
        Classify the current token, advance the extraction state machine and
        run whichever hallucination checks apply to it.

        Expects ``self.current_token`` to be set and ``self.tokens`` /
        ``self.logprobs`` to already include it. If no branch tags the token,
        it is tagged ``"e"`` so ``mask`` stays aligned with ``tokens``.
        """
        # Patterns are matched against the generated text with spaces removed.
        content = "".join(self.tokens).replace(" ", "")

        if self.current_token == "<tool_call>":
            self.mask.append("t")
            self.check_logprob()

        # Function name extraction
        if self.state == "function_name":
            if self.current_token not in const.FUNC_NAME_END_TOKEN:
                self.mask.append("f")
            else:
                self.state = None
                self.check_function_name()
        if content.endswith(const.FUNC_NAME_START_PATTERN):
            self.state = "function_name"

        # Parameter name extraction
        if self.state == "parameter_name" and not content.endswith(
            const.PARAMETER_NAME_END_TOKENS
        ):
            self.mask.append("p")
        elif self.state == "parameter_name" and content.endswith(
            const.PARAMETER_NAME_END_TOKENS
        ):
            self.state = None
            self.check_parameter_name()
            self.parameter_name_done = True
        elif self.parameter_name_done and content.endswith(
            const.PARAMETER_NAME_START_PATTERN
        ):
            self.state = "parameter_name"
        # parameter_name_done is only set after a name has been checked, so the
        # first parameter name is detected by its own start pattern.
        if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN):
            self.state = "parameter_name"

        # Parameter value extraction
        if self.state == "parameter_value" and not content.endswith(
            const.PARAMETER_VALUE_END_TOKEN
        ):
            if self.current_token.strip() not in ['"', ""]:
                self.mask.append("v")
                # Entropy-check only the first token of a run of value tokens,
                # and only for parameters whose schema has no "default".
                if (
                    len(self.mask) > 1
                    and self.mask[-2] != "v"
                    and not check_parameter_property(
                        self.function_properties.get(self.function_name, {}),
                        self.parameter_name[-1],
                        "default",
                    )
                ):
                    self.check_logprob()
            else:
                self.mask.append("e")
        elif self.state == "parameter_value" and content.endswith(
            const.PARAMETER_VALUE_END_TOKEN
        ):
            self.state = None
        elif self.parameter_name_done and content.endswith(
            const.PARAMETER_VALUE_START_PATTERN
        ):
            self.state = "parameter_value"

        # Keep mask aligned one-to-one with tokens.
        if len(self.mask) != len(self.tokens):
            self.mask.append("e")

    def check_logprob(self):
        """
        Compute entropy / varentropy of the latest token's log probabilities,
        record them in ``token_probs_map`` and flag a hallucination when they
        exceed the threshold for the latest mask type.
        """
        probs = self.logprobs[-1]
        entropy, varentropy = calculate_entropy(probs)
        self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
        # Thresholds are looked up by mask letter (e.g. "t" or "v").
        if check_threshold(
            entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]]
        ):
            self.hallucination = True
            self.hallucination_message = f"Token '{self.current_token}' is uncertain."

    def count_consecutive_token(self, token="v") -> int:
        """
        Count how many entries at the end of the mask equal ``token``.

        Args:
            token (str): The mask letter to count.

        Returns:
            int: Length of the trailing run of ``token`` in ``mask`` (0 if the
            mask is empty or does not end with ``token``).
        """
        return (
            len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask))))
            if self.mask and self.mask[-1] == token
            else 0
        )

    def check_function_name(self):
        """
        Extract the function name from the trailing ``"f"`` tokens and flag a
        hallucination if it is not a registered function.

        The final token (the one that ended the name, not yet masked) is
        excluded from the name.
        """
        f_len = self.count_consecutive_token("f")
        # Guard f_len == 0: seq[-0:] is the whole sequence, not an empty slice.
        self.function_name = "".join(self.tokens[:-1][-f_len:]) if f_len else ""
        if self.function_name not in self.function_description:
            self.hallucination = True
            self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."

    def check_parameter_name(self):
        """
        Extract the parameter name from the trailing ``"p"`` tokens, record it,
        and flag a hallucination if the current function has no such parameter.

        An unknown (already flagged) function name is treated as having no
        parameters instead of raising ``KeyError``.
        """
        p_len = self.count_consecutive_token("p")
        # Guard p_len == 0: seq[-0:] is the whole sequence, not an empty slice.
        parameter_name = "".join(self.tokens[:-1][-p_len:]) if p_len else ""
        self.parameter_name.append(parameter_name)
        known_parameters = self.function_description.get(self.function_name, [])
        if parameter_name not in known_parameters:
            self.hallucination = True
            self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions."