move constatns

This commit is contained in:
cotran 2024-11-26 17:10:29 -08:00
parent 075c94fd39
commit 4e7572f501

View file

@ -3,8 +3,37 @@ import math
import torch
import random
from typing import Any, Dict, List, Tuple
import app.commons.constants as const
import itertools
from enum import Enum
# constants
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
# Thresholds
class MaskToken(Enum):
FUNCTION_NAME = "f"
PARAMETER_VALUE = "v"
PARAMETER_NAME = "p"
NOT_USED = "e"
TOOL_CALL = "t"
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.5,
"varentropy": 2.5,
},
}
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
@ -159,59 +188,59 @@ class HallucinationStateHandler:
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.tokens[-1] == const.TOOL_CALL_TOKEN:
self.mask.append(const.MaskToken.TOOL_CALL)
if self.tokens[-1] == TOOL_CALL_TOKEN:
self.mask.append(MaskToken.TOOL_CALL)
self._check_logprob()
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name":
if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN:
self.mask.append(const.MaskToken.FUNCTION_NAME)
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
self.mask.append(MaskToken.FUNCTION_NAME)
else:
self.state = None
self._is_function_name_hallucinated()
# Check if the token is a function name start token, change the state
if content.endswith(const.FUNC_NAME_START_PATTERN):
if content.endswith(FUNC_NAME_START_PATTERN):
self.state = "function_name"
# Parameter name extraction logic
# if the state is parameter name and the token is not an end token, add to the mask
if self.state == "parameter_name" and not content.endswith(
const.PARAMETER_NAME_END_TOKENS
PARAMETER_NAME_END_TOKENS
):
self.mask.append(const.MaskToken.PARAMETER_NAME)
self.mask.append(MaskToken.PARAMETER_NAME)
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
# The need for parameter name done is to allow the check of parameter value pattern
elif self.state == "parameter_name" and content.endswith(
const.PARAMETER_NAME_END_TOKENS
PARAMETER_NAME_END_TOKENS
):
self.state = None
self._is_parameter_name_hallucinated()
self.parameter_name_done = True
# if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith(
const.PARAMETER_NAME_START_PATTERN
PARAMETER_NAME_START_PATTERN
):
self.state = "parameter_name"
# if token is a first parameter value start token, change the state
if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN):
if content.endswith(FIRST_PARAM_NAME_START_PATTERN):
self.state = "parameter_name"
# Parameter value extraction logic
# if the state is parameter value and the token is not an end token, add to the mask
if self.state == "parameter_value" and not content.endswith(
const.PARAMETER_VALUE_END_TOKEN
PARAMETER_VALUE_END_TOKEN
):
# checking if the token is a value token and is not empty
if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(const.MaskToken.PARAMETER_VALUE)
self.mask.append(MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != const.MaskToken.PARAMETER_VALUE
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
@ -220,22 +249,22 @@ class HallucinationStateHandler:
):
self._check_logprob()
else:
self.mask.append(const.MaskToken.NOT_USED)
self.mask.append(MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state
elif self.state == "parameter_value" and content.endswith(
const.PARAMETER_VALUE_END_TOKEN
PARAMETER_VALUE_END_TOKEN
):
self.state = None
# if the parameter name is done and the token is a parameter value start token, change the state
elif self.parameter_name_done and content.endswith(
const.PARAMETER_VALUE_START_PATTERN
PARAMETER_VALUE_START_PATTERN
):
self.state = "parameter_value"
# Maintain consistency between stack and mask
# If the mask length is less than tokens, add an not used (e) token to the mask
if len(self.mask) != len(self.tokens):
self.mask.append(const.MaskToken.NOT_USED)
self.mask.append(MaskToken.NOT_USED)
def _check_logprob(self):
"""
@ -247,7 +276,7 @@ class HallucinationStateHandler:
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
if check_threshold(
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
):
self.hallucination = True
self.error_type = "Hallucination"
@ -255,7 +284,7 @@ class HallucinationStateHandler:
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
)
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
"""
Counts the number of consecutive occurrences of a given token in the mask.
@ -276,7 +305,7 @@ class HallucinationStateHandler:
Checks the extracted function name against the function descriptions.
Detects hallucinations if the function name is not found.
"""
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys():
self.error_type = "function_name"
@ -287,7 +316,7 @@ class HallucinationStateHandler:
Checks the extracted parameter name against the function descriptions.
Detects hallucinations if the parameter name is not found.
"""
p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME)
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]: