This commit is contained in:
cotran 2024-11-26 12:13:21 -08:00
parent 673b187eb5
commit 075c94fd39

View file

@ -1,14 +1,10 @@
import json import json
import ast
import os
import json
import math import math
import torch import torch
import random import random
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
import app.commons.constants as const import app.commons.constants as const
import itertools import itertools
from enum import Enum
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
@ -84,24 +80,22 @@ class HallucinationStateHandler:
parameter_name (list): List of extracted parameter names. parameter_name (list): List of extracted parameter names.
function_description (dict): Description of functions and their parameters. function_description (dict): Description of functions and their parameters.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy. token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
current_token (str): The current token being processed.
""" """
def __init__(self, response_iterator=None, function=None): def __init__(self, response_iterator=None, function=None):
""" """
Initializes the HallucinationStateHandler with default values. Initializes the HallucinationStateHandler with default values.
""" """
self.tokens = [] self.tokens: List[str] = []
self.logprobs = [] self.logprobs: List[float] = []
self.state = None self.state: str = None
self.mask = [] self.mask: List[str] = []
self.parameter_name_done = False self.parameter_name_done: bool = False
self.hallucination = False self.hallucination: bool = False
self.hallucination_message = "" self.error_message: str = ""
self.parameter_name = [] self.error_type: str = ""
self.parameter_name: List[str] = []
self.token_probs_map = [] self.token_probs_map: List[Tuple[str, float, float]] = []
self.current_token = None
self.response_iterator = response_iterator self.response_iterator = response_iterator
self._process_function(function) self._process_function(function)
@ -128,7 +122,6 @@ class HallucinationStateHandler:
Returns: Returns:
bool: True if the token is hallucinated, False otherwise. bool: True if the token is hallucinated, False otherwise.
""" """
self.current_token = token
self.tokens.append(token) self.tokens.append(token)
self.logprobs.append(logprob) self.logprobs.append(logprob)
self._process_token() self._process_token()
@ -166,14 +159,14 @@ class HallucinationStateHandler:
Detects hallucinations based on the token type and log probabilities. Detects hallucinations based on the token type and log probabilities.
""" """
content = "".join(self.tokens).replace(" ", "") content = "".join(self.tokens).replace(" ", "")
if self.current_token == const.TOOL_CALL_TOKEN: if self.tokens[-1] == const.TOOL_CALL_TOKEN:
self.mask.append(const.MaskToken.TOOL_CALL) self.mask.append(const.MaskToken.TOOL_CALL)
self._check_logprob() self._check_logprob()
# Function name extraction logic # Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask # If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name": if self.state == "function_name":
if self.current_token not in const.FUNC_NAME_END_TOKEN: if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN:
self.mask.append(const.MaskToken.FUNCTION_NAME) self.mask.append(const.MaskToken.FUNCTION_NAME)
else: else:
self.state = None self.state = None
@ -181,7 +174,6 @@ class HallucinationStateHandler:
# Check if the token is a function name start token, change the state # Check if the token is a function name start token, change the state
if content.endswith(const.FUNC_NAME_START_PATTERN): if content.endswith(const.FUNC_NAME_START_PATTERN):
print("function name entered")
self.state = "function_name" self.state = "function_name"
# Parameter name extraction logic # Parameter name extraction logic
@ -214,7 +206,7 @@ class HallucinationStateHandler:
const.PARAMETER_VALUE_END_TOKEN const.PARAMETER_VALUE_END_TOKEN
): ):
# checking if the token is a value token and is not empty # checking if the token is a value token and is not empty
if self.current_token.strip() not in ['"', ""]: if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(const.MaskToken.PARAMETER_VALUE) self.mask.append(const.MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have default and the token is the first parameter value token # checking if the parameter doesn't have default and the token is the first parameter value token
if ( if (
@ -258,7 +250,10 @@ class HallucinationStateHandler:
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value] entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
): ):
self.hallucination = True self.hallucination = True
self.hallucination_message = f"Token '{self.current_token}' is uncertain." self.error_type = "Hallucination"
self.error_message = (
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
)
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int: def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
""" """
@ -284,8 +279,8 @@ class HallucinationStateHandler:
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME) f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:]) self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys(): if self.function_name not in self.function_description.keys():
self.hallucination = True self.error_type = "function_name"
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions." self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."
def _is_parameter_name_hallucinated(self): def _is_parameter_name_hallucinated(self):
""" """
@ -296,5 +291,5 @@ class HallucinationStateHandler:
parameter_name = "".join(self.tokens[:-1][-p_len:]) parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name) self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]: if parameter_name not in self.function_description[self.function_name]:
self.hallucination = True self.error_type = "parameter_name"
self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions." self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."