From 4da2184d56af09713c332e5ea4b9525a2c59236f Mon Sep 17 00:00:00 2001 From: cotran Date: Mon, 25 Nov 2024 14:15:49 -0800 Subject: [PATCH] address issue --- .../function_calling/hallucination_handler.py | 82 +++++++++++++++---- 1 file changed, 67 insertions(+), 15 deletions(-) diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py index eb4a56d1..acbda584 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/function_calling/hallucination_handler.py @@ -10,7 +10,7 @@ import app.commons.constants as const import itertools -def check_threshold(entropy, varentropy, thd): +def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: """ Check if the given entropy or variance of entropy exceeds the specified thresholds. @@ -48,19 +48,21 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]: return entropy.item(), varentropy.item() -def check_parameter_property(api_description, parameter_name, property_name): +def is_parameter_property( + function_description: Dict, parameter_name: str, property_name: str +) -> bool: """ Check if a parameter in an API description has a specific property. Args: - api_description (dict): The API description in JSON format. + function_description (dict): The API description in JSON format. parameter_name (str): The name of the parameter to check. property_name (str): The property to look for (e.g., 'format', 'default'). Returns: bool: True if the parameter has the specified property, False otherwise. """ - parameters = api_description.get("properties", {}) + parameters = function_description.get("properties", {}) parameter_info = parameters.get(parameter_name, {}) return property_name in parameter_info @@ -84,7 +86,7 @@ class HallucinationStateHandler: current_token (str): The current token being processed. """ - def __init__(self): + def __init__(self, response_iterator=None, function=None): """ Initializes the HallucinationStateHandler with default values. """ @@ -99,18 +101,56 @@ class HallucinationStateHandler: self.token_probs_map = [] self.current_token = None + self.response_iterator = response_iterator + self.process_function(function) - def process_function(self, apis): - self.apis = apis - if self.apis is None: + def process_function(self, function): + self.function = function + if self.function is None: raise ValueError("API descriptions not set.") parameter_names = {} - for func in self.apis: + for func in self.function: func_name = func["name"] parameters = func["parameters"]["properties"] parameter_names[func_name] = list(parameters.keys()) self.function_description = parameter_names - self.function_properties = {x["name"]: x["parameters"] for x in self.apis} + self.function_properties = {x["name"]: x["parameters"] for x in self.function} + + def check_token_hallucination(self, token, logprob): + """ + Check if the given token is hallucinated based on the log probability. + + Args: + token (str): The token to check. + logprob (float): The log probability of the token. + + Returns: + bool: True if the token is hallucinated, False otherwise. + """ + self.current_token = token + self.tokens.append(token) + self.logprobs.append(logprob) + self.process_token() + return self.hallucination + + def __iter__(self): + return self + + def __next__(self): + if self.response_iterator is not None: + try: + r = next(self.response_iterator) + if hasattr(r.choices[0].delta, "content"): + token_content = r.choices[0].delta.content + if token_content: + logprobs = [ + p.logprob + for p in r.choices[0].logprobs.content[0].top_logprobs + ] + self.check_token_hallucination(token_content, logprobs) + return token_content + except StopIteration: + raise StopIteration def process_token(self): """ @@ -123,42 +163,52 @@ class HallucinationStateHandler: self.check_logprob() # Function name extraction logic + # If the state is function name and the token is not an end token, add to the mask if self.state == "function_name": if self.current_token not in const.FUNC_NAME_END_TOKEN: self.mask.append("f") else: self.state = None - self.check_function_name() + self.is_function_name_hallucinated() + # Check if the token is a function name start token, change the state if content.endswith(const.FUNC_NAME_START_PATTERN): print("function name entered") self.state = "function_name" # Parameter name extraction logic + # if the state is parameter name and the token is not an end token, add to the mask if self.state == "parameter_name" and not content.endswith( const.PARAMETER_NAME_END_TOKENS ): self.mask.append("p") + # if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done + # The need for parameter name done is to allow the check of parameter value pattern elif self.state == "parameter_name" and content.endswith( const.PARAMETER_NAME_END_TOKENS ): self.state = None - self.check_parameter_name() + self.is_parameter_name_hallucinated() self.parameter_name_done = True + # if the parameter name is done and the token is a parameter name start token, change the state elif self.parameter_name_done and content.endswith( const.PARAMETER_NAME_START_PATTERN ): self.state = "parameter_name" + # if token is a first parameter value start token, change the state if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN): self.state = "parameter_name" # Parameter value extraction logic + # if the state is parameter value and the token is not an end token, add to the mask if self.state == "parameter_value" and not content.endswith( const.PARAMETER_VALUE_END_TOKEN ): + # checking if the token is a value token and is not empty if self.current_token.strip() not in ['"', ""]: self.mask.append("v") + # checking if the parameter doesn't have default and the token is the first parameter value token if ( len(self.mask) > 1 and self.mask[-2] != "v" @@ -171,17 +221,19 @@ class HallucinationStateHandler: self.check_logprob() else: self.mask.append("e") - + # if the state is parameter value and the token is an end token, change the state elif self.state == "parameter_value" and content.endswith( const.PARAMETER_VALUE_END_TOKEN ): self.state = None + # if the parameter name is done and the token is a parameter value start token, change the state elif self.parameter_name_done and content.endswith( const.PARAMETER_VALUE_START_PATTERN ): self.state = "parameter_value" # Maintain consistency between stack and mask + # If the mask length is less than tokens, add an not used (e) token to the mask if len(self.mask) != len(self.tokens): self.mask.append("e") @@ -216,7 +268,7 @@ class HallucinationStateHandler: else 0 ) - def check_function_name(self): + def is_function_name_hallucinated(self): """ Checks the extracted function name against the function descriptions. Detects hallucinations if the function name is not found. @@ -227,7 +279,7 @@ class HallucinationStateHandler: self.hallucination = True self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions." - def check_parameter_name(self): + def is_parameter_name_hallucinated(self): """ Checks the extracted parameter name against the function descriptions. Detects hallucinations if the parameter name is not found.