diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index 3b937aa5..e62bf229 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -4,6 +4,7 @@ import app.loader as loader from app.function_calling.model_handler import ArchFunctionHandler from app.prompt_guard.model_handler import ArchGuardHanlder +from enum import Enum logger = utils.get_model_server_logger() @@ -38,6 +39,7 @@ arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict) # Patterns for function name and parameter parsing FUNC_NAME_START_PATTERN = ('\n{"name":"', "\n{'name':'") FUNC_NAME_END_TOKEN = ('",', "',") +TOOL_CALL_TOKEN = "" FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'") PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'") @@ -45,10 +47,19 @@ PARAMETER_NAME_START_PATTERN = (',"', ",'") PARAMETER_VALUE_START_PATTERN = ('":', "':") PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',") + # Thresholds +class MaskToken(Enum): + FUNCTION_NAME = "f" + PARAMETER_VALUE = "v" + PARAMETER_NAME = "p" + NOT_USED = "e" + TOOL_CALL = "t" + + HALLUCINATION_THRESHOLD_DICT = { - "t": {"entropy": 0.1, "varentropy": 0.5}, - "v": { + MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5}, + MaskToken.PARAMETER_VALUE.value: { "entropy": 0.5, "varentropy": 2.5, }, diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py index dc8e8ee0..add3fc52 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/function_calling/hallucination_handler.py @@ -8,6 +8,7 @@ import random from typing import Any, Dict, List, Tuple import app.commons.constants as const import itertools +from enum import Enum def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: @@ -102,9 +103,9 @@ class HallucinationStateHandler: self.token_probs_map = [] self.current_token = None self.response_iterator = response_iterator - self.process_function(function) + self._process_function(function) - def process_function(self, function): + def _process_function(self, function): self.function = function if self.function is None: raise ValueError("API descriptions not set.") @@ -116,7 +117,7 @@ class HallucinationStateHandler: self.function_description = parameter_names self.function_properties = {x["name"]: x["parameters"] for x in self.function} - def check_token_hallucination(self, token, logprob): + def append_and_check_token_hallucination(self, token, logprob): """ Check if the given token is hallucinated based on the log probability. @@ -130,7 +131,7 @@ class HallucinationStateHandler: self.current_token = token self.tokens.append(token) self.logprobs.append(logprob) - self.process_token() + self._process_token() return self.hallucination def __iter__(self): @@ -143,33 +144,40 @@ class HallucinationStateHandler: if hasattr(r.choices[0].delta, "content"): token_content = r.choices[0].delta.content if token_content: - logprobs = [ - p.logprob - for p in r.choices[0].logprobs.content[0].top_logprobs - ] - self.check_token_hallucination(token_content, logprobs) + try: + logprobs = [ + p.logprob + for p in r.choices[0].logprobs.content[0].top_logprobs + ] + except Exception as e: + raise ValueError( + f"Error extracting logprobs from response: {e}" + ) + self.append_and_check_token_hallucination( + token_content, logprobs + ) return token_content except StopIteration: raise StopIteration - def process_token(self): + def _process_token(self): """ Processes the current token and updates the state and mask accordingly. Detects hallucinations based on the token type and log probabilities. """ content = "".join(self.tokens).replace(" ", "") - if self.current_token == "": - self.mask.append("t") - self.check_logprob() + if self.current_token == const.TOOL_CALL_TOKEN: + self.mask.append(const.MaskToken.TOOL_CALL) + self._check_logprob() # Function name extraction logic # If the state is function name and the token is not an end token, add to the mask if self.state == "function_name": if self.current_token not in const.FUNC_NAME_END_TOKEN: - self.mask.append("f") + self.mask.append(const.MaskToken.FUNCTION_NAME) else: self.state = None - self.is_function_name_hallucinated() + self._is_function_name_hallucinated() # Check if the token is a function name start token, change the state if content.endswith(const.FUNC_NAME_START_PATTERN): @@ -181,14 +189,14 @@ class HallucinationStateHandler: if self.state == "parameter_name" and not content.endswith( const.PARAMETER_NAME_END_TOKENS ): - self.mask.append("p") + self.mask.append(const.MaskToken.PARAMETER_NAME) # if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done # The need for parameter name done is to allow the check of parameter value pattern elif self.state == "parameter_name" and content.endswith( const.PARAMETER_NAME_END_TOKENS ): self.state = None - self.is_parameter_name_hallucinated() + self._is_parameter_name_hallucinated() self.parameter_name_done = True # if the parameter name is done and the token is a parameter name start token, change the state elif self.parameter_name_done and content.endswith( @@ -207,20 +215,20 @@ class HallucinationStateHandler: ): # checking if the token is a value token and is not empty if self.current_token.strip() not in ['"', ""]: - self.mask.append("v") + self.mask.append(const.MaskToken.PARAMETER_VALUE) # checking if the parameter doesn't have default and the token is the first parameter value token if ( len(self.mask) > 1 - and self.mask[-2] != "v" + and self.mask[-2] != const.MaskToken.PARAMETER_VALUE and not is_parameter_property( self.function_properties[self.function_name], self.parameter_name[-1], "default", ) ): - self.check_logprob() + self._check_logprob() else: - self.mask.append("e") + self.mask.append(const.MaskToken.NOT_USED) # if the state is parameter value and the token is an end token, change the state elif self.state == "parameter_value" and content.endswith( const.PARAMETER_VALUE_END_TOKEN @@ -235,9 +243,9 @@ class HallucinationStateHandler: # Maintain consistency between stack and mask # If the mask length is less than tokens, add an not used (e) token to the mask if len(self.mask) != len(self.tokens): - self.mask.append("e") + self.mask.append(const.MaskToken.NOT_USED) - def check_logprob(self): + def _check_logprob(self): """ Checks the log probability of the current token and updates the token probability map. Detects hallucinations based on entropy and variance of entropy. @@ -247,12 +255,12 @@ class HallucinationStateHandler: self.token_probs_map.append((self.tokens[-1], entropy, varentropy)) if check_threshold( - entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]] + entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value] ): self.hallucination = True self.hallucination_message = f"Token '{self.current_token}' is uncertain." - def count_consecutive_token(self, token="v") -> int: + def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int: """ Counts the number of consecutive occurrences of a given token in the mask. @@ -268,23 +276,23 @@ class HallucinationStateHandler: else 0 ) - def is_function_name_hallucinated(self): + def _is_function_name_hallucinated(self): """ Checks the extracted function name against the function descriptions. Detects hallucinations if the function name is not found. """ - f_len = self.count_consecutive_token("f") + f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME) self.function_name = "".join(self.tokens[:-1][-f_len:]) if self.function_name not in self.function_description.keys(): self.hallucination = True self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions." - def is_parameter_name_hallucinated(self): + def _is_parameter_name_hallucinated(self): """ Checks the extracted parameter name against the function descriptions. Detects hallucinations if the parameter name is not found. """ - p_len = self.count_consecutive_token("p") + p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME) parameter_name = "".join(self.tokens[:-1][-p_len:]) self.parameter_name.append(parameter_name) if parameter_name not in self.function_description[self.function_name]: diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py index 8c23324b..60483ad6 100644 --- a/model_server/app/tests/test_hallucination.py +++ b/model_server/app/tests/test_hallucination.py @@ -52,7 +52,7 @@ def test_hallucination(case): ) for token, logprob in zip(case["tokens"], case["logprobs"]): if token != "": - state.check_token_hallucination(token, logprob) + state.append_and_check_token_hallucination(token, logprob) if state.hallucination: break assert state.hallucination == case["expect"]