address comments

This commit is contained in:
cotran 2024-11-26 09:45:27 -08:00
parent e44d189d86
commit e30bbe39e7
3 changed files with 51 additions and 32 deletions

View file

@ -4,6 +4,7 @@ import app.loader as loader
from app.function_calling.model_handler import ArchFunctionHandler from app.function_calling.model_handler import ArchFunctionHandler
from app.prompt_guard.model_handler import ArchGuardHanlder from app.prompt_guard.model_handler import ArchGuardHanlder
from enum import Enum
logger = utils.get_model_server_logger() logger = utils.get_model_server_logger()
@ -38,6 +39,7 @@ arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
# Patterns for function name and parameter parsing # Patterns for function name and parameter parsing
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'") FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',") FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'") FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'") PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
@ -45,10 +47,19 @@ PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':") PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',") PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
# Thresholds # Thresholds
class MaskToken(Enum):
FUNCTION_NAME = "f"
PARAMETER_VALUE = "v"
PARAMETER_NAME = "p"
NOT_USED = "e"
TOOL_CALL = "t"
HALLUCINATION_THRESHOLD_DICT = { HALLUCINATION_THRESHOLD_DICT = {
"t": {"entropy": 0.1, "varentropy": 0.5}, MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
"v": { MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.5, "entropy": 0.5,
"varentropy": 2.5, "varentropy": 2.5,
}, },

View file

@ -8,6 +8,7 @@ import random
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
import app.commons.constants as const import app.commons.constants as const
import itertools import itertools
from enum import Enum
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
@ -102,9 +103,9 @@ class HallucinationStateHandler:
self.token_probs_map = [] self.token_probs_map = []
self.current_token = None self.current_token = None
self.response_iterator = response_iterator self.response_iterator = response_iterator
self.process_function(function) self._process_function(function)
def process_function(self, function): def _process_function(self, function):
self.function = function self.function = function
if self.function is None: if self.function is None:
raise ValueError("API descriptions not set.") raise ValueError("API descriptions not set.")
@ -116,7 +117,7 @@ class HallucinationStateHandler:
self.function_description = parameter_names self.function_description = parameter_names
self.function_properties = {x["name"]: x["parameters"] for x in self.function} self.function_properties = {x["name"]: x["parameters"] for x in self.function}
def check_token_hallucination(self, token, logprob): def append_and_check_token_hallucination(self, token, logprob):
""" """
Check if the given token is hallucinated based on the log probability. Check if the given token is hallucinated based on the log probability.
@ -130,7 +131,7 @@ class HallucinationStateHandler:
self.current_token = token self.current_token = token
self.tokens.append(token) self.tokens.append(token)
self.logprobs.append(logprob) self.logprobs.append(logprob)
self.process_token() self._process_token()
return self.hallucination return self.hallucination
def __iter__(self): def __iter__(self):
@ -143,33 +144,40 @@ class HallucinationStateHandler:
if hasattr(r.choices[0].delta, "content"): if hasattr(r.choices[0].delta, "content"):
token_content = r.choices[0].delta.content token_content = r.choices[0].delta.content
if token_content: if token_content:
logprobs = [ try:
p.logprob logprobs = [
for p in r.choices[0].logprobs.content[0].top_logprobs p.logprob
] for p in r.choices[0].logprobs.content[0].top_logprobs
self.check_token_hallucination(token_content, logprobs) ]
except Exception as e:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
self.append_and_check_token_hallucination(
token_content, logprobs
)
return token_content return token_content
except StopIteration: except StopIteration:
raise StopIteration raise StopIteration
def process_token(self): def _process_token(self):
""" """
Processes the current token and updates the state and mask accordingly. Processes the current token and updates the state and mask accordingly.
Detects hallucinations based on the token type and log probabilities. Detects hallucinations based on the token type and log probabilities.
""" """
content = "".join(self.tokens).replace(" ", "") content = "".join(self.tokens).replace(" ", "")
if self.current_token == "<tool_call>": if self.current_token == const.TOOL_CALL_TOKEN:
self.mask.append("t") self.mask.append(const.MaskToken.TOOL_CALL)
self.check_logprob() self._check_logprob()
# Function name extraction logic # Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask # If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name": if self.state == "function_name":
if self.current_token not in const.FUNC_NAME_END_TOKEN: if self.current_token not in const.FUNC_NAME_END_TOKEN:
self.mask.append("f") self.mask.append(const.MaskToken.FUNCTION_NAME)
else: else:
self.state = None self.state = None
self.is_function_name_hallucinated() self._is_function_name_hallucinated()
# Check if the token is a function name start token, change the state # Check if the token is a function name start token, change the state
if content.endswith(const.FUNC_NAME_START_PATTERN): if content.endswith(const.FUNC_NAME_START_PATTERN):
@ -181,14 +189,14 @@ class HallucinationStateHandler:
if self.state == "parameter_name" and not content.endswith( if self.state == "parameter_name" and not content.endswith(
const.PARAMETER_NAME_END_TOKENS const.PARAMETER_NAME_END_TOKENS
): ):
self.mask.append("p") self.mask.append(const.MaskToken.PARAMETER_NAME)
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done # if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
# The need for parameter name done is to allow the check of parameter value pattern # The need for parameter name done is to allow the check of parameter value pattern
elif self.state == "parameter_name" and content.endswith( elif self.state == "parameter_name" and content.endswith(
const.PARAMETER_NAME_END_TOKENS const.PARAMETER_NAME_END_TOKENS
): ):
self.state = None self.state = None
self.is_parameter_name_hallucinated() self._is_parameter_name_hallucinated()
self.parameter_name_done = True self.parameter_name_done = True
# if the parameter name is done and the token is a parameter name start token, change the state # if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith( elif self.parameter_name_done and content.endswith(
@ -207,20 +215,20 @@ class HallucinationStateHandler:
): ):
# checking if the token is a value token and is not empty # checking if the token is a value token and is not empty
if self.current_token.strip() not in ['"', ""]: if self.current_token.strip() not in ['"', ""]:
self.mask.append("v") self.mask.append(const.MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have default and the token is the first parameter value token # checking if the parameter doesn't have default and the token is the first parameter value token
if ( if (
len(self.mask) > 1 len(self.mask) > 1
and self.mask[-2] != "v" and self.mask[-2] != const.MaskToken.PARAMETER_VALUE
and not is_parameter_property( and not is_parameter_property(
self.function_properties[self.function_name], self.function_properties[self.function_name],
self.parameter_name[-1], self.parameter_name[-1],
"default", "default",
) )
): ):
self.check_logprob() self._check_logprob()
else: else:
self.mask.append("e") self.mask.append(const.MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state # if the state is parameter value and the token is an end token, change the state
elif self.state == "parameter_value" and content.endswith( elif self.state == "parameter_value" and content.endswith(
const.PARAMETER_VALUE_END_TOKEN const.PARAMETER_VALUE_END_TOKEN
@ -235,9 +243,9 @@ class HallucinationStateHandler:
# Maintain consistency between stack and mask # Maintain consistency between stack and mask
# If the mask length is less than tokens, add an not used (e) token to the mask # If the mask length is less than tokens, add an not used (e) token to the mask
if len(self.mask) != len(self.tokens): if len(self.mask) != len(self.tokens):
self.mask.append("e") self.mask.append(const.MaskToken.NOT_USED)
def check_logprob(self): def _check_logprob(self):
""" """
Checks the log probability of the current token and updates the token probability map. Checks the log probability of the current token and updates the token probability map.
Detects hallucinations based on entropy and variance of entropy. Detects hallucinations based on entropy and variance of entropy.
@ -247,12 +255,12 @@ class HallucinationStateHandler:
self.token_probs_map.append((self.tokens[-1], entropy, varentropy)) self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
if check_threshold( if check_threshold(
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]] entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
): ):
self.hallucination = True self.hallucination = True
self.hallucination_message = f"Token '{self.current_token}' is uncertain." self.hallucination_message = f"Token '{self.current_token}' is uncertain."
def count_consecutive_token(self, token="v") -> int: def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
""" """
Counts the number of consecutive occurrences of a given token in the mask. Counts the number of consecutive occurrences of a given token in the mask.
@ -268,23 +276,23 @@ class HallucinationStateHandler:
else 0 else 0
) )
def is_function_name_hallucinated(self): def _is_function_name_hallucinated(self):
""" """
Checks the extracted function name against the function descriptions. Checks the extracted function name against the function descriptions.
Detects hallucinations if the function name is not found. Detects hallucinations if the function name is not found.
""" """
f_len = self.count_consecutive_token("f") f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:]) self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys(): if self.function_name not in self.function_description.keys():
self.hallucination = True self.hallucination = True
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions." self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
def is_parameter_name_hallucinated(self): def _is_parameter_name_hallucinated(self):
""" """
Checks the extracted parameter name against the function descriptions. Checks the extracted parameter name against the function descriptions.
Detects hallucinations if the parameter name is not found. Detects hallucinations if the parameter name is not found.
""" """
p_len = self.count_consecutive_token("p") p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:]) parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name) self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]: if parameter_name not in self.function_description[self.function_name]:

View file

@ -52,7 +52,7 @@ def test_hallucination(case):
) )
for token, logprob in zip(case["tokens"], case["logprobs"]): for token, logprob in zip(case["tokens"], case["logprobs"]):
if token != "</tool_call>": if token != "</tool_call>":
state.check_token_hallucination(token, logprob) state.append_and_check_token_hallucination(token, logprob)
if state.hallucination: if state.hallucination:
break break
assert state.hallucination == case["expect"] assert state.hallucination == case["expect"]