move constants

This commit is contained in:
cotran 2024-11-26 17:10:29 -08:00
parent 075c94fd39
commit 4e7572f501

View file

@ -3,8 +3,37 @@ import math
import torch import torch
import random import random
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
import app.commons.constants as const
import itertools import itertools
from enum import Enum
# Constants for the token-stream state machine in HallucinationStateHandler.
#
# The "*_PATTERN" / "*_TOKENS" tuples are passed directly to str.endswith(),
# which accepts a tuple of alternatives. They are matched against the
# concatenated tokens with space characters removed (newlines are kept), which
# is why none of the patterns contain spaces. Each tuple lists the
# double-quoted and single-quoted spelling of the same JSON fragment.

# Suffix that marks the start of a function name inside a tool call.
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
# Compared against the latest token itself (membership test, not a suffix
# match): the function name ends when the token is exactly one of these.
FUNC_NAME_END_TOKEN = ('",', "',")
# Compared for equality against the latest token to label it as a tool call.
TOOL_CALL_TOKEN = "<tool_call>"
# Suffix that marks the start of the first parameter name in "arguments".
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
# Suffixes that end a parameter name.
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
# Suffix that starts a subsequent parameter name; only consulted once a
# previous parameter name has been completed.
PARAMETER_NAME_START_PATTERN = (',"', ",'")
# Suffix that starts a parameter value; only consulted once a parameter name
# has been completed.
PARAMETER_VALUE_START_PATTERN = ('":', "':")
# Suffixes that end a parameter value ("}}\n" closes the whole call object).
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
# Thresholds
class MaskToken(Enum):
    """Label recorded in the handler's ``mask`` for each generated token.

    The hallucination state handler appends one label per token so that the
    mask stays the same length as the token list. The single-character
    ``value`` of a member is also the key used to look up its thresholds in
    ``HALLUCINATION_THRESHOLD_DICT``.
    """

    # Token is part of a function name (state "function_name", not an end token).
    FUNCTION_NAME = "f"
    # Token is part of a parameter value (non-empty, not a bare quote).
    PARAMETER_VALUE = "v"
    # Token is part of a parameter name (state "parameter_name", no end suffix).
    PARAMETER_NAME = "p"
    # Filler label for tokens that carry no tracked meaning; also appended
    # whenever the mask would otherwise be shorter than the token list.
    NOT_USED = "e"
    # Token is exactly the tool-call marker (TOOL_CALL_TOKEN).
    TOOL_CALL = "t"
# Per-label uncertainty thresholds, keyed by MaskToken *value* (the
# single-character string, not the enum member).
#
# HallucinationStateHandler._check_logprob looks up the entry for the most
# recent mask label and passes it, together with the latest token's entropy
# and varentropy, to check_threshold(); a True result flags a hallucination.
#
# Only TOOL_CALL and PARAMETER_VALUE have entries, matching the two places
# where _check_logprob is invoked; indexing with any other label's value
# raises KeyError.
#
# NOTE(review): how check_threshold combines the "entropy" and "varentropy"
# bounds (and/or, strict or inclusive) is not visible here -- confirm before
# retuning these numbers.
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.5,
"varentropy": 2.5,
},
}
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
@ -159,59 +188,59 @@ class HallucinationStateHandler:
Detects hallucinations based on the token type and log probabilities. Detects hallucinations based on the token type and log probabilities.
""" """
content = "".join(self.tokens).replace(" ", "") content = "".join(self.tokens).replace(" ", "")
if self.tokens[-1] == const.TOOL_CALL_TOKEN: if self.tokens[-1] == TOOL_CALL_TOKEN:
self.mask.append(const.MaskToken.TOOL_CALL) self.mask.append(MaskToken.TOOL_CALL)
self._check_logprob() self._check_logprob()
# Function name extraction logic # Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask # If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name": if self.state == "function_name":
if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN: if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
self.mask.append(const.MaskToken.FUNCTION_NAME) self.mask.append(MaskToken.FUNCTION_NAME)
else: else:
self.state = None self.state = None
self._is_function_name_hallucinated() self._is_function_name_hallucinated()
# Check if the token is a function name start token, change the state # Check if the token is a function name start token, change the state
if content.endswith(const.FUNC_NAME_START_PATTERN): if content.endswith(FUNC_NAME_START_PATTERN):
self.state = "function_name" self.state = "function_name"
# Parameter name extraction logic # Parameter name extraction logic
# if the state is parameter name and the token is not an end token, add to the mask # if the state is parameter name and the token is not an end token, add to the mask
if self.state == "parameter_name" and not content.endswith( if self.state == "parameter_name" and not content.endswith(
const.PARAMETER_NAME_END_TOKENS PARAMETER_NAME_END_TOKENS
): ):
self.mask.append(const.MaskToken.PARAMETER_NAME) self.mask.append(MaskToken.PARAMETER_NAME)
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done # if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
# The need for parameter name done is to allow the check of parameter value pattern # The need for parameter name done is to allow the check of parameter value pattern
elif self.state == "parameter_name" and content.endswith( elif self.state == "parameter_name" and content.endswith(
const.PARAMETER_NAME_END_TOKENS PARAMETER_NAME_END_TOKENS
): ):
self.state = None self.state = None
self._is_parameter_name_hallucinated() self._is_parameter_name_hallucinated()
self.parameter_name_done = True self.parameter_name_done = True
# if the parameter name is done and the token is a parameter name start token, change the state # if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith( elif self.parameter_name_done and content.endswith(
const.PARAMETER_NAME_START_PATTERN PARAMETER_NAME_START_PATTERN
): ):
self.state = "parameter_name" self.state = "parameter_name"
# if token is a first parameter value start token, change the state # if token is a first parameter value start token, change the state
if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN): if content.endswith(FIRST_PARAM_NAME_START_PATTERN):
self.state = "parameter_name" self.state = "parameter_name"
# Parameter value extraction logic # Parameter value extraction logic
# if the state is parameter value and the token is not an end token, add to the mask # if the state is parameter value and the token is not an end token, add to the mask
if self.state == "parameter_value" and not content.endswith( if self.state == "parameter_value" and not content.endswith(
const.PARAMETER_VALUE_END_TOKEN PARAMETER_VALUE_END_TOKEN
): ):
# checking if the token is a value token and is not empty # checking if the token is a value token and is not empty
if self.tokens[-1].strip() not in ['"', ""]: if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(const.MaskToken.PARAMETER_VALUE) self.mask.append(MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have default and the token is the first parameter value token # checking if the parameter doesn't have default and the token is the first parameter value token
if ( if (
len(self.mask) > 1 len(self.mask) > 1
and self.mask[-2] != const.MaskToken.PARAMETER_VALUE and self.mask[-2] != MaskToken.PARAMETER_VALUE
and not is_parameter_property( and not is_parameter_property(
self.function_properties[self.function_name], self.function_properties[self.function_name],
self.parameter_name[-1], self.parameter_name[-1],
@ -220,22 +249,22 @@ class HallucinationStateHandler:
): ):
self._check_logprob() self._check_logprob()
else: else:
self.mask.append(const.MaskToken.NOT_USED) self.mask.append(MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state # if the state is parameter value and the token is an end token, change the state
elif self.state == "parameter_value" and content.endswith( elif self.state == "parameter_value" and content.endswith(
const.PARAMETER_VALUE_END_TOKEN PARAMETER_VALUE_END_TOKEN
): ):
self.state = None self.state = None
# if the parameter name is done and the token is a parameter value start token, change the state # if the parameter name is done and the token is a parameter value start token, change the state
elif self.parameter_name_done and content.endswith( elif self.parameter_name_done and content.endswith(
const.PARAMETER_VALUE_START_PATTERN PARAMETER_VALUE_START_PATTERN
): ):
self.state = "parameter_value" self.state = "parameter_value"
# Maintain consistency between stack and mask # Maintain consistency between stack and mask
# If the mask length is less than tokens, add an not used (e) token to the mask # If the mask length is less than tokens, add an not used (e) token to the mask
if len(self.mask) != len(self.tokens): if len(self.mask) != len(self.tokens):
self.mask.append(const.MaskToken.NOT_USED) self.mask.append(MaskToken.NOT_USED)
def _check_logprob(self): def _check_logprob(self):
""" """
@ -247,7 +276,7 @@ class HallucinationStateHandler:
self.token_probs_map.append((self.tokens[-1], entropy, varentropy)) self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
if check_threshold( if check_threshold(
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value] entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
): ):
self.hallucination = True self.hallucination = True
self.error_type = "Hallucination" self.error_type = "Hallucination"
@ -255,7 +284,7 @@ class HallucinationStateHandler:
f"Hallucination: token '{self.tokens[-1]}' is uncertain." f"Hallucination: token '{self.tokens[-1]}' is uncertain."
) )
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int: def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
""" """
Counts the number of consecutive occurrences of a given token in the mask. Counts the number of consecutive occurrences of a given token in the mask.
@ -276,7 +305,7 @@ class HallucinationStateHandler:
Checks the extracted function name against the function descriptions. Checks the extracted function name against the function descriptions.
Detects hallucinations if the function name is not found. Detects hallucinations if the function name is not found.
""" """
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME) f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:]) self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys(): if self.function_name not in self.function_description.keys():
self.error_type = "function_name" self.error_type = "function_name"
@ -287,7 +316,7 @@ class HallucinationStateHandler:
Checks the extracted parameter name against the function descriptions. Checks the extracted parameter name against the function descriptions.
Detects hallucinations if the parameter name is not found. Detects hallucinations if the parameter name is not found.
""" """
p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME) p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:]) parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name) self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]: if parameter_name not in self.function_description[self.function_name]: