This commit is contained in:
cotran 2024-11-26 12:13:21 -08:00
parent 673b187eb5
commit 075c94fd39

View file

@ -1,14 +1,10 @@
import json
import ast
import os
import json
import math
import torch
import random
from typing import Any, Dict, List, Tuple
import app.commons.constants as const
import itertools
from enum import Enum
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
@ -84,24 +80,22 @@ class HallucinationStateHandler:
parameter_name (list): List of extracted parameter names.
function_description (dict): Description of functions and their parameters.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
current_token (str): The current token being processed.
"""
def __init__(self, response_iterator=None, function=None):
"""
Initializes the HallucinationStateHandler with default values.
"""
self.tokens = []
self.logprobs = []
self.state = None
self.mask = []
self.parameter_name_done = False
self.hallucination = False
self.hallucination_message = ""
self.parameter_name = []
self.token_probs_map = []
self.current_token = None
self.tokens: List[str] = []
self.logprobs: List[float] = []
self.state: str = None
self.mask: List[str] = []
self.parameter_name_done: bool = False
self.hallucination: bool = False
self.error_message: str = ""
self.error_type: str = ""
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
self._process_function(function)
@ -128,7 +122,6 @@ class HallucinationStateHandler:
Returns:
bool: True if the token is hallucinated, False otherwise.
"""
self.current_token = token
self.tokens.append(token)
self.logprobs.append(logprob)
self._process_token()
@ -166,14 +159,14 @@ class HallucinationStateHandler:
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.current_token == const.TOOL_CALL_TOKEN:
if self.tokens[-1] == const.TOOL_CALL_TOKEN:
self.mask.append(const.MaskToken.TOOL_CALL)
self._check_logprob()
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name":
if self.current_token not in const.FUNC_NAME_END_TOKEN:
if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN:
self.mask.append(const.MaskToken.FUNCTION_NAME)
else:
self.state = None
@ -181,7 +174,6 @@ class HallucinationStateHandler:
# Check if the token is a function name start token, change the state
if content.endswith(const.FUNC_NAME_START_PATTERN):
print("function name entered")
self.state = "function_name"
# Parameter name extraction logic
@ -214,7 +206,7 @@ class HallucinationStateHandler:
const.PARAMETER_VALUE_END_TOKEN
):
# checking if the token is a value token and is not empty
if self.current_token.strip() not in ['"', ""]:
if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(const.MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
@ -258,7 +250,10 @@ class HallucinationStateHandler:
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
):
self.hallucination = True
self.hallucination_message = f"Token '{self.current_token}' is uncertain."
self.error_type = "Hallucination"
self.error_message = (
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
)
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
"""
@ -284,8 +279,8 @@ class HallucinationStateHandler:
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys():
self.hallucination = True
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
self.error_type = "function_name"
self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."
def _is_parameter_name_hallucinated(self):
"""
@ -296,5 +291,5 @@ class HallucinationStateHandler:
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]:
self.hallucination = True
self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
self.error_type = "parameter_name"
self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."