From 075c94fd39e498abf368d5ab7d7d4f777c91a06e Mon Sep 17 00:00:00 2001 From: cotran Date: Tue, 26 Nov 2024 12:13:21 -0800 Subject: [PATCH] fix --- .../function_calling/hallucination_handler.py | 47 +++++++++---------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py index add3fc52..136a842a 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/function_calling/hallucination_handler.py @@ -1,14 +1,10 @@ import json -import ast -import os -import json import math import torch import random from typing import Any, Dict, List, Tuple import app.commons.constants as const import itertools -from enum import Enum def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: @@ -84,24 +80,22 @@ class HallucinationStateHandler: parameter_name (list): List of extracted parameter names. function_description (dict): Description of functions and their parameters. token_probs_map (list): List mapping tokens to their entropy and variance of entropy. - current_token (str): The current token being processed. """ def __init__(self, response_iterator=None, function=None): """ Initializes the HallucinationStateHandler with default values. """ - self.tokens = [] - self.logprobs = [] - self.state = None - self.mask = [] - self.parameter_name_done = False - self.hallucination = False - self.hallucination_message = "" - self.parameter_name = [] - - self.token_probs_map = [] - self.current_token = None + self.tokens: List[str] = [] + self.logprobs: List[float] = [] + self.state: str = None + self.mask: List[str] = [] + self.parameter_name_done: bool = False + self.hallucination: bool = False + self.error_message: str = "" + self.error_type: str = "" + self.parameter_name: List[str] = [] + self.token_probs_map: List[Tuple[str, float, float]] = [] self.response_iterator = response_iterator self._process_function(function) @@ -128,7 +122,6 @@ class HallucinationStateHandler: Returns: bool: True if the token is hallucinated, False otherwise. """ - self.current_token = token self.tokens.append(token) self.logprobs.append(logprob) self._process_token() @@ -166,14 +159,14 @@ class HallucinationStateHandler: Detects hallucinations based on the token type and log probabilities. """ content = "".join(self.tokens).replace(" ", "") - if self.current_token == const.TOOL_CALL_TOKEN: + if self.tokens[-1] == const.TOOL_CALL_TOKEN: self.mask.append(const.MaskToken.TOOL_CALL) self._check_logprob() # Function name extraction logic # If the state is function name and the token is not an end token, add to the mask if self.state == "function_name": - if self.current_token not in const.FUNC_NAME_END_TOKEN: + if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN: self.mask.append(const.MaskToken.FUNCTION_NAME) else: self.state = None @@ -181,7 +174,6 @@ class HallucinationStateHandler: # Check if the token is a function name start token, change the state if content.endswith(const.FUNC_NAME_START_PATTERN): - print("function name entered") self.state = "function_name" # Parameter name extraction logic @@ -214,7 +206,7 @@ class HallucinationStateHandler: const.PARAMETER_VALUE_END_TOKEN ): # checking if the token is a value token and is not empty - if self.current_token.strip() not in ['"', ""]: + if self.tokens[-1].strip() not in ['"', ""]: self.mask.append(const.MaskToken.PARAMETER_VALUE) # checking if the parameter doesn't have default and the token is the first parameter value token if ( @@ -258,7 +250,10 @@ class HallucinationStateHandler: entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value] ): self.hallucination = True - self.hallucination_message = f"Token '{self.current_token}' is uncertain." + self.error_type = "Hallucination" + self.error_message = ( + f"Hallucination: token '{self.tokens[-1]}' is uncertain." + ) def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int: """ @@ -284,8 +279,8 @@ class HallucinationStateHandler: f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME) self.function_name = "".join(self.tokens[:-1][-f_len:]) if self.function_name not in self.function_description.keys(): - self.hallucination = True - self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions." + self.error_type = "function_name" + self.error_message = f"Function name '{self.function_name}' not found in given function descriptions." def _is_parameter_name_hallucinated(self): """ @@ -296,5 +291,5 @@ class HallucinationStateHandler: parameter_name = "".join(self.tokens[:-1][-p_len:]) self.parameter_name.append(parameter_name) if parameter_name not in self.function_description[self.function_name]: - self.hallucination = True - self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions." + self.error_type = "parameter_name" + self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."