From 665dbc2d4eef196aaab4f529cff77a1d0e7ecb27 Mon Sep 17 00:00:00 2001 From: cotran Date: Mon, 18 Nov 2024 00:53:49 -0800 Subject: [PATCH] fix test --- .../function_calling/hallucination_handler.py | 250 ++++++++++++++---- .../app/function_calling/model_handler.py | 2 +- model_server/app/tests/test_hallucination.py | 12 +- 3 files changed, 203 insertions(+), 61 deletions(-) diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py index 52e5fbac..67bc1be4 100644 --- a/model_server/app/function_calling/hallucination_handler.py +++ b/model_server/app/function_calling/hallucination_handler.py @@ -1,8 +1,15 @@ import torch import numpy as np -from typing import List, Dict +import math +import app.commons.constants as const +import random +from typing import List, Dict, Any, Tuple +import json -def filter_tokens_and_probs(tokens: List[str], probs: List[float]) -> Tuple[List[], List[float]]: + +def filter_tokens_and_probs( + tokens: List[str], probs: List[float] +) -> Tuple[List[str], List[float]]: """ Filters out special tokens from the list of tokens and their corresponding probabilities. @@ -14,17 +21,17 @@ def filter_tokens_and_probs(tokens: List[str], probs: List[float]) -> Tuple[List tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities. """ # Use regex to identify tokens without special characters - special_tokens = ['\\n', '{"', '":', ' "', '",', ' {"', '"}}\\n', ' ', '"}}\n'] - filtered_tokens = [ - token for token in tokens - if token not in special_tokens - ] + special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n'] + filtered_tokens = [token for token in tokens if token not in special_tokens] filtered_probs = [ - prob for token, prob in zip(tokens, probs) - if token not in special_tokens + prob for token, prob in zip(tokens, probs) if token not in special_tokens ] return filtered_tokens, filtered_probs -def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_names: Dict[str, List[str]]) -> Tuple[Dict[str, List[str]], Dict[str, List[float]]]: + + +def get_all_parameter_values( + tokens: List[str], probs: List[float], parameter_names: Dict[str, Any] +) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Extracts parameter values and their corresponding probabilities from the tokens. @@ -49,7 +56,9 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na # Incrementally combine tokens to find a full match with any parameter name while i < len(tokens): if combined_token: - combined_token += tokens[i] # Append next token to the current combination + combined_token += tokens[ + i + ] # Append next token to the current combination else: combined_token = tokens[i] # Start a new combination @@ -62,7 +71,11 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na i += 1 # Move past the parameter name # Collect tokens as values until the next parameter or end marker - while i < len(tokens) and tokens[i] not in params and tokens[i] != '': + while ( + i < len(tokens) + and tokens[i] not in params + and tokens[i] != "" + ): values.append(tokens[i]) prob_values.append(probs[i]) i += 1 @@ -83,7 +96,11 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na i = start + 1 return parameter_values, probs_values -def calculate_stats(data: Dict, function_description: Dict) -> Dict: + + +def calculate_stats( + data: Dict[str, Any], function_description: Dict[str, Any] +) -> Dict[str, Any]: """ Calculates statistical metrics for the given data. @@ -97,19 +114,33 @@ def calculate_stats(data: Dict, function_description: Dict) -> Dict: stats = {} try: for key, values in data.items(): - if len(data[key])>=1: + if len(data[key]) >= 1: first = values[0] max_value = max(values) min_value = min(values) avg_value = sum(values) / len(values) - has_format = check_parameter_property(function_description, key, "format") - has_default = check_parameter_property(function_description, key , "default") - stats[key] = {'first':first, 'max': max_value, 'min': min_value, 'avg': avg_value, 'has_format': has_format, 'has_default': has_default} + has_format = check_parameter_property( + function_description, key, "format" + ) + has_default = check_parameter_property( + function_description, key, "default" + ) + stats[key] = { + "first": first, + "max": max_value, + "min": min_value, + "avg": avg_value, + "has_format": has_format, + "has_default": has_default, + } except Exception as e: print(data) return stats -def check_parameter_property(api_description: Dict, parameter_name: str, property_name: str)-> bool: + +def check_parameter_property( + api_description: Dict[str, Any], parameter_name: str, property_name: str +) -> bool: """ Check if a parameter in an API description has a specific property. @@ -127,8 +158,36 @@ def check_parameter_property(api_description: Dict, parameter_name: str, propert return property_name in parameter_info +def calculate_entropy(log_probs): + """ + Calculate the entropy and variance of entropy (varentropy) from log probabilities. -def hallucination_detect(token:str, log_probs:List[float], current_state: Dict, entropy_thd : float= 0.7, varentropy_thd :float = 4.0) -> bool: + Args: + log_probs (list of float): A list of log probabilities. + + Returns: + tuple: A tuple containing: + - log_probs (list of float): The input log probabilities as a list. + - entropy (float): The calculated entropy. + - varentropy (float): The calculated variance of entropy. + """ + log_probs = torch.tensor(log_probs) + token_probs = torch.exp(log_probs) + entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e) + varentropy = torch.sum( + token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2, + dim=-1, + ) + return log_probs.tolist(), entropy.item(), varentropy.item() + + +def hallucination_detect( + token: str, + log_probs: List[float], + current_state: Dict[str, Any], + entropy_thd: float = 0.7, + varentropy_thd: float = 4.0, +) -> bool: """ Detects hallucinations in the token sequence based on entropy and varentropy thresholds. @@ -142,12 +201,12 @@ def hallucination_detect(token:str, log_probs:List[float], current_state: Dict, Returns: bool: True if a hallucination is detected, False otherwise. """ - + if token: # check if there is content in token current_state["tokens"].append(token) - current_state['content'] += token - current_state['logprobs'].append(log_probs) + current_state["content"] += token + current_state["logprobs"].append(log_probs) # keep track of entropy and varentropy _, entropy, varentropy = calculate_entropy(log_probs) current_state["entropy"].append(entropy) @@ -156,71 +215,148 @@ def hallucination_detect(token:str, log_probs:List[float], current_state: Dict, if token == "": if entropy > entropy_thd or varentropy > varentropy_thd: current_state["hallucination"] = True - current_state["hallucination_message"] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}" + current_state[ + "hallucination_message" + ] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}" return True elif token == "": current_state["state"] = "tool_call_end" # try to extract tool call, else raise error try: - current_state['tool_call'] = extract_tool_calls(current_state["content"])[0] - current_state['tool_call_process'] = True + current_state[ + "tool_call" + ] = const.arch_function_hanlder.extract_tool_calls( + current_state["content"] + )[ + 0 + ] + current_state["tool_call_process"] = True except: - current_state['tool_call_process'] = False + current_state["tool_call_process"] = False print(f"cant process tool") return True # check if function name is valid - if current_state['tool_call']['function']['name'] not in current_state['parameter_names'].keys(): + if ( + current_state["tool_call"]["function"]["name"] + not in current_state["parameter_names"].keys() + ): current_state["hallucination"] = True - current_state["hallucination_message"] = f"function name {current_state['tool_call']['name']} not found" + current_state[ + "hallucination_message" + ] = f"function name {current_state['tool_call']['name']} not found" return True # check if parameter names are from the given function tools - current_parameter_names = current_state['tool_call']['function']['arguments'].keys() - given_parameter_names = current_state['parameter_names'][current_state['tool_call']['function']['name']] + current_parameter_names = current_state["tool_call"]["function"][ + "arguments" + ].keys() + given_parameter_names = current_state["parameter_names"][ + current_state["tool_call"]["function"]["name"] + ] if not set(current_parameter_names).issubset(given_parameter_names): - missing_keys = set(current_parameter_names) - set(given_parameter_names) current_state["hallucination"] = True - current_state["hallucination_message"] = f"parameter names {missing_keys} not found" + current_state[ + "hallucination_message" + ] = f"parameter names {missing_keys} not found" return True # filtered special tokens that are not needed in the hallucination check for parameter values - current_state["filtered_tokens"], current_state["filtered_entropy"] = filter_tokens_and_probs(current_state["tokens"], current_state["entropy"]) - current_state["filtered_tokens"], current_state["filtered_varentropy"] = filter_tokens_and_probs(current_state["tokens"], current_state["varentropy"]) - parameter_values, entropy_values = get_all_parameter_values(current_state["filtered_tokens"], current_state["filtered_entropy"], current_state['parameter_names']) - parameter_values, varentropy_values = get_all_parameter_values(current_state["filtered_tokens"], current_state["filtered_varentropy"], current_state['parameter_names']) + ( + current_state["filtered_tokens"], + current_state["filtered_entropy"], + ) = filter_tokens_and_probs( + current_state["tokens"], current_state["entropy"] + ) + ( + current_state["filtered_tokens"], + current_state["filtered_varentropy"], + ) = filter_tokens_and_probs( + current_state["tokens"], current_state["varentropy"] + ) + parameter_values, entropy_values = get_all_parameter_values( + current_state["filtered_tokens"], + current_state["filtered_entropy"], + current_state["parameter_names"], + ) + parameter_values, varentropy_values = get_all_parameter_values( + current_state["filtered_tokens"], + current_state["filtered_varentropy"], + current_state["parameter_names"], + ) - current_state['parameter_values'] = parameter_values - current_state['parameter_values_entropy'] = entropy_values - current_state['parameter_values_varentropy'] = varentropy_values + current_state["parameter_values"] = parameter_values + current_state["parameter_values_entropy"] = entropy_values + current_state["parameter_values_varentropy"] = varentropy_values # calculate the max, first, avg of sub tokens for parameter value - current_state['parameter_value_entropy_stat'] = calculate_stats(current_state['parameter_values_entropy'], current_state['function_description'][0]) - current_state['parameter_value_varentropy_stat'] = calculate_stats(current_state['parameter_values_varentropy'], current_state['function_description'][0]) + current_state["parameter_value_entropy_stat"] = calculate_stats( + current_state["parameter_values_entropy"], + current_state["function_description"][0], + ) + current_state["parameter_value_varentropy_stat"] = calculate_stats( + current_state["parameter_values_varentropy"], + current_state["function_description"][0], + ) # get map for debugging - current_state['token_entropy_map'] = {x : y for x,y in zip(current_state['tokens'], current_state['entropy'])} - current_state['token_varentropy_map'] = {x : y for x,y in zip(current_state['tokens'], current_state['varentropy'])} + current_state["token_entropy_map"] = { + x: y for x, y in zip(current_state["tokens"], current_state["entropy"]) + } + current_state["token_varentropy_map"] = { + x: y + for x, y in zip(current_state["tokens"], current_state["varentropy"]) + } # checking hallucination for parameter value - current_state['parameter_value_check'] = {x : {'hallucination': False, 'message': ''} for x in current_state['parameter_values'].keys()} - for key in current_state['parameter_value_check'].keys(): + current_state["parameter_value_check"] = { + x: {"hallucination": False, "message": ""} + for x in current_state["parameter_values"].keys() + } + for key in current_state["parameter_value_check"].keys(): # if parameter is given a format, check the first token - if current_state['parameter_value_entropy_stat'][key]['has_format']: - if current_state['parameter_value_entropy_stat'][key]['first'] > entropy_thd or current_state['parameter_value_varentropy_stat'][key]['first'] > varentropy_thd: - current_state['parameter_value_check'][key]['hallucination'] = True + if current_state["parameter_value_entropy_stat"][key]["has_format"]: + if ( + current_state["parameter_value_entropy_stat"][key]["first"] + > entropy_thd + or current_state["parameter_value_varentropy_stat"][key][ + "first" + ] + > varentropy_thd + ): + current_state["parameter_value_check"][key][ + "hallucination" + ] = True current_state["hallucination"] = True - current_state['parameter_value_check'][key]['message'] = f"parameter {key} with formatting doesn't pass threshold" + current_state["parameter_value_check"][key][ + "message" + ] = f"parameter {key} with formatting doesn't pass threshold" # if parameter gis given a default value, we can always use default - elif current_state['parameter_value_entropy_stat'][key]['has_default']: - current_state['parameter_value_check'][key]['hallucination'] = False - current_state['parameter_value_check'][key]['message'] = f"parameter {key} with default" + elif current_state["parameter_value_entropy_stat"][key]["has_default"]: + current_state["parameter_value_check"][key]["hallucination"] = False + current_state["parameter_value_check"][key][ + "message" + ] = f"parameter {key} with default" # check if max sub token is > thresholds else: - if current_state['parameter_value_entropy_stat'][key]['max'] > entropy_thd or current_state['parameter_value_varentropy_stat'][key]['max'] > varentropy_thd: - current_state['parameter_value_check'][key]['hallucination'] = True - current_state['parameter_value_check'][key]['message'] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold" + if ( + current_state["parameter_value_entropy_stat"][key]["max"] + > entropy_thd + or current_state["parameter_value_varentropy_stat"][key]["max"] + > varentropy_thd + ): + current_state["parameter_value_check"][key][ + "hallucination" + ] = True + current_state["parameter_value_check"][key][ + "message" + ] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold" current_state["hallucination"] = True if current_state["hallucination"] == True: - current_state["hallucination_message"] = "\n".join([current_state['parameter_value_check'][key]['message'] for key in current_state['parameter_value_check'].keys()]) + current_state["hallucination_message"] = "\n".join( + [ + current_state["parameter_value_check"][key]["message"] + for key in current_state["parameter_value_check"].keys() + ] + ) return True - return False \ No newline at end of file + return False diff --git a/model_server/app/function_calling/model_handler.py b/model_server/app/function_calling/model_handler.py index 7b915cd4..e1da914c 100644 --- a/model_server/app/function_calling/model_handler.py +++ b/model_server/app/function_calling/model_handler.py @@ -134,4 +134,4 @@ class ArchFunctionHandler: fixed_str += opening_bracket[unmatched_opening] # Attempt to parse the corrected string to ensure it’s valid JSON - return fixed_str + return fixed_str.replace("'", '"') diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py index 08ea75f2..21e8ad60 100644 --- a/model_server/app/tests/test_hallucination.py +++ b/model_server/app/tests/test_hallucination.py @@ -1,7 +1,16 @@ import json from app.function_calling.hallucination_handler import hallucination_detect import pytest +import os +# Get the directory of the current file +current_dir = os.path.dirname(__file__) + +# Construct the full path to the JSON file +json_file_path = os.path.join(current_dir, "test_cases.json") + +with open(json_file_path) as f: + test_cases = json.load(f) get_weather_api = { "type": "function", @@ -41,9 +50,6 @@ for func in function_description: parameters = func["parameters"]["properties"] parameter_names[func_name] = list(parameters.keys()) -with open("test_cases.json") as f: - test_cases = json.load(f) - @pytest.mark.parametrize("case", test_cases) def test_hallucination(case):