address comments

This commit is contained in:
cotran 2024-11-26 09:45:27 -08:00
parent e44d189d86
commit e30bbe39e7
3 changed files with 51 additions and 32 deletions

View file

@ -8,6 +8,7 @@ import random
from typing import Any, Dict, List, Tuple
import app.commons.constants as const
import itertools
from enum import Enum
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
@ -102,9 +103,9 @@ class HallucinationStateHandler:
self.token_probs_map = []
self.current_token = None
self.response_iterator = response_iterator
self.process_function(function)
self._process_function(function)
def process_function(self, function):
def _process_function(self, function):
self.function = function
if self.function is None:
raise ValueError("API descriptions not set.")
@ -116,7 +117,7 @@ class HallucinationStateHandler:
self.function_description = parameter_names
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
def check_token_hallucination(self, token, logprob):
def append_and_check_token_hallucination(self, token, logprob):
"""
Check if the given token is hallucinated based on the log probability.
@ -130,7 +131,7 @@ class HallucinationStateHandler:
self.current_token = token
self.tokens.append(token)
self.logprobs.append(logprob)
self.process_token()
self._process_token()
return self.hallucination
def __iter__(self):
@ -143,33 +144,40 @@ class HallucinationStateHandler:
if hasattr(r.choices[0].delta, "content"):
token_content = r.choices[0].delta.content
if token_content:
logprobs = [
p.logprob
for p in r.choices[0].logprobs.content[0].top_logprobs
]
self.check_token_hallucination(token_content, logprobs)
try:
logprobs = [
p.logprob
for p in r.choices[0].logprobs.content[0].top_logprobs
]
except Exception as e:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
self.append_and_check_token_hallucination(
token_content, logprobs
)
return token_content
except StopIteration:
raise StopIteration
def process_token(self):
def _process_token(self):
"""
Processes the current token and updates the state and mask accordingly.
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.current_token == "<tool_call>":
self.mask.append("t")
self.check_logprob()
if self.current_token == const.TOOL_CALL_TOKEN:
self.mask.append(const.MaskToken.TOOL_CALL)
self._check_logprob()
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name":
if self.current_token not in const.FUNC_NAME_END_TOKEN:
self.mask.append("f")
self.mask.append(const.MaskToken.FUNCTION_NAME)
else:
self.state = None
self.is_function_name_hallucinated()
self._is_function_name_hallucinated()
# Check if the token is a function name start token, change the state
if content.endswith(const.FUNC_NAME_START_PATTERN):
@ -181,14 +189,14 @@ class HallucinationStateHandler:
if self.state == "parameter_name" and not content.endswith(
const.PARAMETER_NAME_END_TOKENS
):
self.mask.append("p")
self.mask.append(const.MaskToken.PARAMETER_NAME)
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
# The need for parameter name done is to allow the check of parameter value pattern
elif self.state == "parameter_name" and content.endswith(
const.PARAMETER_NAME_END_TOKENS
):
self.state = None
self.is_parameter_name_hallucinated()
self._is_parameter_name_hallucinated()
self.parameter_name_done = True
# if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith(
@ -207,20 +215,20 @@ class HallucinationStateHandler:
):
# checking if the token is a value token and is not empty
if self.current_token.strip() not in ['"', ""]:
self.mask.append("v")
self.mask.append(const.MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != "v"
and self.mask[-2] != const.MaskToken.PARAMETER_VALUE
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"default",
)
):
self.check_logprob()
self._check_logprob()
else:
self.mask.append("e")
self.mask.append(const.MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state
elif self.state == "parameter_value" and content.endswith(
const.PARAMETER_VALUE_END_TOKEN
@ -235,9 +243,9 @@ class HallucinationStateHandler:
# Maintain consistency between stack and mask
# If the mask length is less than tokens, add an not used (e) token to the mask
if len(self.mask) != len(self.tokens):
self.mask.append("e")
self.mask.append(const.MaskToken.NOT_USED)
def check_logprob(self):
def _check_logprob(self):
"""
Checks the log probability of the current token and updates the token probability map.
Detects hallucinations based on entropy and variance of entropy.
@ -247,12 +255,12 @@ class HallucinationStateHandler:
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
if check_threshold(
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]]
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
):
self.hallucination = True
self.hallucination_message = f"Token '{self.current_token}' is uncertain."
def count_consecutive_token(self, token="v") -> int:
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
"""
Counts the number of consecutive occurrences of a given token in the mask.
@ -268,23 +276,23 @@ class HallucinationStateHandler:
else 0
)
def is_function_name_hallucinated(self):
def _is_function_name_hallucinated(self):
"""
Checks the extracted function name against the function descriptions.
Detects hallucinations if the function name is not found.
"""
f_len = self.count_consecutive_token("f")
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys():
self.hallucination = True
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
def is_parameter_name_hallucinated(self):
def _is_parameter_name_hallucinated(self):
"""
Checks the extracted parameter name against the function descriptions.
Detects hallucinations if the parameter name is not found.
"""
p_len = self.count_consecutive_token("p")
p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]: