diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py index 136a842a..544782fb 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/function_calling/hallucination_handler.py @@ -3,8 +3,37 @@ import math import torch import random from typing import Any, Dict, List, Tuple -import app.commons.constants as const import itertools +from enum import Enum + +# constants +FUNC_NAME_START_PATTERN = ('\n{"name":"', "\n{'name':'") +FUNC_NAME_END_TOKEN = ('",', "',") +TOOL_CALL_TOKEN = "" + +FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'") +PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'") +PARAMETER_NAME_START_PATTERN = (',"', ",'") +PARAMETER_VALUE_START_PATTERN = ('":', "':") +PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',") + + +# Thresholds +class MaskToken(Enum): + FUNCTION_NAME = "f" + PARAMETER_VALUE = "v" + PARAMETER_NAME = "p" + NOT_USED = "e" + TOOL_CALL = "t" + + +HALLUCINATION_THRESHOLD_DICT = { + MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5}, + MaskToken.PARAMETER_VALUE.value: { + "entropy": 0.5, + "varentropy": 2.5, + }, +} def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: @@ -159,59 +188,59 @@ class HallucinationStateHandler: Detects hallucinations based on the token type and log probabilities. 
""" content = "".join(self.tokens).replace(" ", "") - if self.tokens[-1] == const.TOOL_CALL_TOKEN: - self.mask.append(const.MaskToken.TOOL_CALL) + if self.tokens[-1] == TOOL_CALL_TOKEN: + self.mask.append(MaskToken.TOOL_CALL) self._check_logprob() # Function name extraction logic # If the state is function name and the token is not an end token, add to the mask if self.state == "function_name": - if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN: - self.mask.append(const.MaskToken.FUNCTION_NAME) + if self.tokens[-1] not in FUNC_NAME_END_TOKEN: + self.mask.append(MaskToken.FUNCTION_NAME) else: self.state = None self._is_function_name_hallucinated() # Check if the token is a function name start token, change the state - if content.endswith(const.FUNC_NAME_START_PATTERN): + if content.endswith(FUNC_NAME_START_PATTERN): self.state = "function_name" # Parameter name extraction logic # if the state is parameter name and the token is not an end token, add to the mask if self.state == "parameter_name" and not content.endswith( - const.PARAMETER_NAME_END_TOKENS + PARAMETER_NAME_END_TOKENS ): - self.mask.append(const.MaskToken.PARAMETER_NAME) + self.mask.append(MaskToken.PARAMETER_NAME) # if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done # The need for parameter name done is to allow the check of parameter value pattern elif self.state == "parameter_name" and content.endswith( - const.PARAMETER_NAME_END_TOKENS + PARAMETER_NAME_END_TOKENS ): self.state = None self._is_parameter_name_hallucinated() self.parameter_name_done = True # if the parameter name is done and the token is a parameter name start token, change the state elif self.parameter_name_done and content.endswith( - const.PARAMETER_NAME_START_PATTERN + PARAMETER_NAME_START_PATTERN ): self.state = "parameter_name" # if token is a first parameter value start token, change the state - if 
content.endswith(const.FIRST_PARAM_NAME_START_PATTERN): + if content.endswith(FIRST_PARAM_NAME_START_PATTERN): self.state = "parameter_name" # Parameter value extraction logic # if the state is parameter value and the token is not an end token, add to the mask if self.state == "parameter_value" and not content.endswith( - const.PARAMETER_VALUE_END_TOKEN + PARAMETER_VALUE_END_TOKEN ): # checking if the token is a value token and is not empty if self.tokens[-1].strip() not in ['"', ""]: - self.mask.append(const.MaskToken.PARAMETER_VALUE) + self.mask.append(MaskToken.PARAMETER_VALUE) # checking if the parameter doesn't have default and the token is the first parameter value token if ( len(self.mask) > 1 - and self.mask[-2] != const.MaskToken.PARAMETER_VALUE + and self.mask[-2] != MaskToken.PARAMETER_VALUE and not is_parameter_property( self.function_properties[self.function_name], self.parameter_name[-1], @@ -220,22 +249,22 @@ class HallucinationStateHandler: ): self._check_logprob() else: - self.mask.append(const.MaskToken.NOT_USED) + self.mask.append(MaskToken.NOT_USED) # if the state is parameter value and the token is an end token, change the state elif self.state == "parameter_value" and content.endswith( - const.PARAMETER_VALUE_END_TOKEN + PARAMETER_VALUE_END_TOKEN ): self.state = None # if the parameter name is done and the token is a parameter value start token, change the state elif self.parameter_name_done and content.endswith( - const.PARAMETER_VALUE_START_PATTERN + PARAMETER_VALUE_START_PATTERN ): self.state = "parameter_value" # Maintain consistency between stack and mask # If the mask length is less than tokens, add a not-used (e) token to the mask if len(self.mask) != len(self.tokens): - self.mask.append(const.MaskToken.NOT_USED) + self.mask.append(MaskToken.NOT_USED) def _check_logprob(self): """ @@ -247,7 +276,7 @@ class HallucinationStateHandler: self.token_probs_map.append((self.tokens[-1], entropy, varentropy)) if check_threshold( - entropy, 
varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value] + entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value] ): self.hallucination = True self.error_type = "Hallucination" @@ -255,7 +284,7 @@ class HallucinationStateHandler: f"Hallucination: token '{self.tokens[-1]}' is uncertain." ) - def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int: + def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int: """ Counts the number of consecutive occurrences of a given token in the mask. @@ -276,7 +305,7 @@ class HallucinationStateHandler: Checks the extracted function name against the function descriptions. Detects hallucinations if the function name is not found. """ - f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME) + f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME) self.function_name = "".join(self.tokens[:-1][-f_len:]) if self.function_name not in self.function_description.keys(): self.error_type = "function_name" @@ -287,7 +316,7 @@ class HallucinationStateHandler: Checks the extracted parameter name against the function descriptions. Detects hallucinations if the parameter name is not found. """ - p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME) + p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME) parameter_name = "".join(self.tokens[:-1][-p_len:]) self.parameter_name.append(parameter_name) if parameter_name not in self.function_description[self.function_name]: