diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index d4e01d12..0a77830b 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -19,6 +19,7 @@ arch_function_generation_params = { "top_k": 50, "max_tokens": 512, "stop_token_ids": [151645], + # "top_logprobs": 10, } arch_guard_model_type = { @@ -34,3 +35,4 @@ zero_shot_model = loader.get_zero_shot_model() prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE]) arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict) +# Patterns for function name and parameter parsing diff --git a/model_server/app/function_calling/hallucination_handler.py b/model_server/app/function_calling/hallucination_handler.py new file mode 100644 index 00000000..544782fb --- /dev/null +++ b/model_server/app/function_calling/hallucination_handler.py @@ -0,0 +1,324 @@ +import json +import math +import torch +import random +from typing import Any, Dict, List, Tuple +import itertools +from enum import Enum + +# constants +FUNC_NAME_START_PATTERN = ('\n{"name":"', "\n{'name':'") +FUNC_NAME_END_TOKEN = ('",', "',") +TOOL_CALL_TOKEN = "" + +FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'") +PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'") +PARAMETER_NAME_START_PATTERN = (',"', ",'") +PARAMETER_VALUE_START_PATTERN = ('":', "':") +PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',") + + +# Thresholds +class MaskToken(Enum): + FUNCTION_NAME = "f" + PARAMETER_VALUE = "v" + PARAMETER_NAME = "p" + NOT_USED = "e" + TOOL_CALL = "t" + + +HALLUCINATION_THRESHOLD_DICT = { + MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5}, + MaskToken.PARAMETER_VALUE.value: { + "entropy": 0.5, + "varentropy": 2.5, + }, +} + + +def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: + """ + Check if the given entropy or variance of entropy exceeds the specified thresholds. + + Args: + entropy (float): The entropy value to check. + varentropy (float): The variance of entropy value to check. + thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'. + + Returns: + bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise. + """ + return entropy > thd["entropy"] or varentropy > thd["varentropy"] + + +def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]: + """ + Calculate the entropy and variance of entropy (varentropy) from log probabilities. + + Args: + log_probs (list of float): A list of log probabilities. + + Returns: + tuple: A tuple containing: + - log_probs (list of float): The input log probabilities as a list. + - entropy (float): The calculated entropy. + - varentropy (float): The calculated variance of entropy. + """ + log_probs = torch.tensor(log_probs) + token_probs = torch.exp(log_probs) + entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e) + varentropy = torch.sum( + token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2, + dim=-1, + ) + return entropy.item(), varentropy.item() + + +def is_parameter_property( + function_description: Dict, parameter_name: str, property_name: str +) -> bool: + """ + Check if a parameter in an API description has a specific property. + + Args: + function_description (dict): The API description in JSON format. + parameter_name (str): The name of the parameter to check. + property_name (str): The property to look for (e.g., 'format', 'default'). + + Returns: + bool: True if the parameter has the specified property, False otherwise. + """ + parameters = function_description.get("properties", {}) + parameter_info = parameters.get(parameter_name, {}) + + return property_name in parameter_info + + +class HallucinationStateHandler: + """ + A class to handle the state of hallucination detection in token processing. + + Attributes: + tokens (list): List of tokens processed. + logprobs (list): List of log probabilities for each token. + state (str): Current state of the handler. + mask (list): List of masks indicating the type of each token. + parameter_name_done (bool): Flag indicating if parameter name extraction is done. + hallucination (bool): Flag indicating if a hallucination is detected. + hallucination_message (str): Message describing the hallucination. + parameter_name (list): List of extracted parameter names. + function_description (dict): Description of functions and their parameters. + token_probs_map (list): List mapping tokens to their entropy and variance of entropy. + """ + + def __init__(self, response_iterator=None, function=None): + """ + Initializes the HallucinationStateHandler with default values. + """ + self.tokens: List[str] = [] + self.logprobs: List[float] = [] + self.state: str = None + self.mask: List[str] = [] + self.parameter_name_done: bool = False + self.hallucination: bool = False + self.error_message: str = "" + self.error_type: str = "" + self.parameter_name: List[str] = [] + self.token_probs_map: List[Tuple[str, float, float]] = [] + self.response_iterator = response_iterator + self._process_function(function) + + def _process_function(self, function): + self.function = function + if self.function is None: + raise ValueError("API descriptions not set.") + parameter_names = {} + for func in self.function: + func_name = func["name"] + parameters = func["parameters"]["properties"] + parameter_names[func_name] = list(parameters.keys()) + self.function_description = parameter_names + self.function_properties = {x["name"]: x["parameters"] for x in self.function} + + def append_and_check_token_hallucination(self, token, logprob): + """ + Check if the given token is hallucinated based on the log probability. + + Args: + token (str): The token to check. + logprob (float): The log probability of the token. + + Returns: + bool: True if the token is hallucinated, False otherwise. + """ + self.tokens.append(token) + self.logprobs.append(logprob) + self._process_token() + return self.hallucination + + def __iter__(self): + return self + + def __next__(self): + if self.response_iterator is not None: + try: + r = next(self.response_iterator) + if hasattr(r.choices[0].delta, "content"): + token_content = r.choices[0].delta.content + if token_content: + try: + logprobs = [ + p.logprob + for p in r.choices[0].logprobs.content[0].top_logprobs + ] + except Exception as e: + raise ValueError( + f"Error extracting logprobs from response: {e}" + ) + self.append_and_check_token_hallucination( + token_content, logprobs + ) + return token_content + except StopIteration: + raise StopIteration + + def _process_token(self): + """ + Processes the current token and updates the state and mask accordingly. + Detects hallucinations based on the token type and log probabilities. + """ + content = "".join(self.tokens).replace(" ", "") + if self.tokens[-1] == TOOL_CALL_TOKEN: + self.mask.append(MaskToken.TOOL_CALL) + self._check_logprob() + + # Function name extraction logic + # If the state is function name and the token is not an end token, add to the mask + if self.state == "function_name": + if self.tokens[-1] not in FUNC_NAME_END_TOKEN: + self.mask.append(MaskToken.FUNCTION_NAME) + else: + self.state = None + self._is_function_name_hallucinated() + + # Check if the token is a function name start token, change the state + if content.endswith(FUNC_NAME_START_PATTERN): + self.state = "function_name" + + # Parameter name extraction logic + # if the state is parameter name and the token is not an end token, add to the mask + if self.state == "parameter_name" and not content.endswith( + PARAMETER_NAME_END_TOKENS + ): + self.mask.append(MaskToken.PARAMETER_NAME) + # if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done + # The need for parameter name done is to allow the check of parameter value pattern + elif self.state == "parameter_name" and content.endswith( + PARAMETER_NAME_END_TOKENS + ): + self.state = None + self._is_parameter_name_hallucinated() + self.parameter_name_done = True + # if the parameter name is done and the token is a parameter name start token, change the state + elif self.parameter_name_done and content.endswith( + PARAMETER_NAME_START_PATTERN + ): + self.state = "parameter_name" + + # if token is a first parameter value start token, change the state + if content.endswith(FIRST_PARAM_NAME_START_PATTERN): + self.state = "parameter_name" + + # Parameter value extraction logic + # if the state is parameter value and the token is not an end token, add to the mask + if self.state == "parameter_value" and not content.endswith( + PARAMETER_VALUE_END_TOKEN + ): + # checking if the token is a value token and is not empty + if self.tokens[-1].strip() not in ['"', ""]: + self.mask.append(MaskToken.PARAMETER_VALUE) + # checking if the parameter doesn't have default and the token is the first parameter value token + if ( + len(self.mask) > 1 + and self.mask[-2] != MaskToken.PARAMETER_VALUE + and not is_parameter_property( + self.function_properties[self.function_name], + self.parameter_name[-1], + "default", + ) + ): + self._check_logprob() + else: + self.mask.append(MaskToken.NOT_USED) + # if the state is parameter value and the token is an end token, change the state + elif self.state == "parameter_value" and content.endswith( + PARAMETER_VALUE_END_TOKEN + ): + self.state = None + # if the parameter name is done and the token is a parameter value start token, change the state + elif self.parameter_name_done and content.endswith( + PARAMETER_VALUE_START_PATTERN + ): + self.state = "parameter_value" + + # Maintain consistency between stack and mask + # If the mask length is less than tokens, add an not used (e) token to the mask + if len(self.mask) != len(self.tokens): + self.mask.append(MaskToken.NOT_USED) + + def _check_logprob(self): + """ + Checks the log probability of the current token and updates the token probability map. + Detects hallucinations based on entropy and variance of entropy. + """ + probs = self.logprobs[-1] + entropy, varentropy = calculate_entropy(probs) + self.token_probs_map.append((self.tokens[-1], entropy, varentropy)) + + if check_threshold( + entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value] + ): + self.hallucination = True + self.error_type = "Hallucination" + self.error_message = ( + f"Hallucination: token '{self.tokens[-1]}' is uncertain." + ) + + def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int: + """ + Counts the number of consecutive occurrences of a given token in the mask. + + Args: + token (str): The token to count in the mask. + + Returns: + int: The number of consecutive occurrences of the token. + """ + return ( + len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask)))) + if self.mask and self.mask[-1] == token + else 0 + ) + + def _is_function_name_hallucinated(self): + """ + Checks the extracted function name against the function descriptions. + Detects hallucinations if the function name is not found. + """ + f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME) + self.function_name = "".join(self.tokens[:-1][-f_len:]) + if self.function_name not in self.function_description.keys(): + self.error_type = "function_name" + self.error_message = f"Function name '{self.function_name}' not found in given function descriptions." + + def _is_parameter_name_hallucinated(self): + """ + Checks the extracted parameter name against the function descriptions. + Detects hallucinations if the parameter name is not found. + """ + p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME) + parameter_name = "".join(self.tokens[:-1][-p_len:]) + self.parameter_name.append(parameter_name) + if parameter_name not in self.function_description[self.function_name]: + self.error_type = "parameter_name" + self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions." diff --git a/model_server/app/function_calling/model_handler.py b/model_server/app/function_calling/model_handler.py index 7b915cd4..e1da914c 100644 --- a/model_server/app/function_calling/model_handler.py +++ b/model_server/app/function_calling/model_handler.py @@ -134,4 +134,4 @@ class ArchFunctionHandler: fixed_str += opening_bracket[unmatched_opening] # Attempt to parse the corrected string to ensure it’s valid JSON - return fixed_str + return fixed_str.replace("'", '"') diff --git a/model_server/app/tests/test_cases.json b/model_server/app/tests/test_cases.json new file mode 100644 index 00000000..8fd7ec1e --- /dev/null +++ b/model_server/app/tests/test_cases.json @@ -0,0 +1,794 @@ +[{ + "case": "tool_call_halluciation", + "tokens" : [""], + "expect": 1, + "logprobs": [[-0.3333307206630707, + -1.5310522317886353, + -3.5098977088928223, + -3.9004578590393066, + -5.775152683258057, + -5.814209461212158, + -5.9574151039123535, + -6.0094895362854, + -6.0094895362854, + -6.673445224761963]] +}, +{ + "case" : "parameter_value_hallucination", + "expect" : 0, + "tokens" : ["", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Sea", + ",", + " Australia", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "1", + "'}}\n", + ""], + "logprobs": [[-0.008103232830762863, + -5.085402488708496, + -6.777836799621582, + -7.558959007263184, + -9.850253105163574, + -10.266852378845215, + -10.540244102478027, + -10.722506523132324, + -10.800618171691895, + -10.917786598205566], + [0.0, + -23.25142478942871, + -25.139137268066406, + -26.2847843170166, + -28.992677688598633, + -29.070789337158203, + -29.55248260498047, + -29.91700553894043, + -30.20341682434082, + -30.307567596435547], + [0.0, + -21.66313934326172, + -23.06916046142578, + -23.32953453063965, + -25.65988540649414, + -25.985353469848633, + -26.519121170043945, + -27.07892417907715, + -27.977216720581055, + -28.458908081054688], + [0.0, + -28.094383239746094, + -28.56305694580078, + -29.109844207763672, + -29.44832992553711, + -31.79170036315918, + -32.0, + -32.05207443237305, + -32.31244659423828, + -32.364524841308594], + [0.0, + -30.489830017089844, + -31.140766143798828, + -31.81774139404297, + -34.525634765625, + -35.8275032043457, + -36.504478454589844, + -39.05614471435547, + -40.123680114746094, + -40.696502685546875], + [0.0, + -25.646865844726562, + -26.66232681274414, + -27.781936645507812, + -28.979660034179688, + -31.140764236450195, + -31.92188835144043, + -31.973962783813477, + -33.04149627685547, + -33.58828353881836], + [0.0, + -23.511798858642578, + -24.136695861816406, + -25.230268478393555, + -25.777053833007812, + -25.80309295654297, + -26.45402717590332, + -26.636289596557617, + -26.740440368652344, + -26.896663665771484], + [0.0, + -22.366153717041016, + -24.683483123779297, + -26.610252380371094, + -26.610252380371094, + -27.313264846801758, + -27.67778778076172, + -28.510986328125, + -28.615135192871094, + -29.13588523864746], + [0.0, + -22.52237319946289, + -24.292919158935547, + -24.344993591308594, + -24.39706802368164, + -24.73555564880371, + -29.943042755126953, + -29.969079971313477, + -30.021154403686523, + -30.0341739654541], + [0.0, + -30.17738151550293, + -30.411718368530273, + -30.88039207458496, + -30.984540939331055, + -31.270952224731445, + -31.895851135253906, + -32.46867370605469, + -32.624900817871094, + -33.484134674072266], + [0.0, + -28.146459579467773, + -29.396255493164062, + -30.099267959594727, + -31.127744674682617, + -31.179821014404297, + -32.807159423828125, + -33.7445068359375, + -33.770545959472656, + -34.069976806640625], + [0.0, + -26.323841094970703, + -26.558177947998047, + -30.515867233276367, + -30.932466506958008, + -31.37510108947754, + -31.531326293945312, + -31.70056915283203, + -32.065093994140625, + -32.364524841308594], + [0.0, + -26.922698974609375, + -30.28152847290039, + -31.505287170410156, + -33.30187225341797, + -33.73148727416992, + -34.27827453613281, + -34.33034896850586, + -34.460533142089844, + -34.720909118652344], + [0.0, + -21.532955169677734, + -26.94873809814453, + -29.109848022460938, + -30.80228042602539, + -31.55736541748047, + -33.484134674072266, + -34.681854248046875, + -35.384864807128906, + -35.853538513183594], + [0.0, + -19.502033233642578, + -20.46541976928711, + -24.110658645629883, + -24.501218795776367, + -25.256305694580078, + -25.82912826538086, + -25.881202697753906, + -26.063465118408203, + -26.063465118408203], + [0.0, + -24.37103271484375, + -25.256305694580078, + -25.933277130126953, + -26.714401245117188, + -28.2506103515625, + -31.010576248168945, + -32.07810974121094, + -34.62977981567383, + -35.241661071777344], + [-1.1920922133867862e-06, + -14.398697853088379, + -14.424736976623535, + -17.158666610717773, + -17.41904067993164, + -18.200162887573242, + -18.434499740600586, + -18.66883659362793, + -19.71033477783203, + -19.71033477783203], + [-0.0001445904199499637, + -8.98305892944336, + -11.35246467590332, + -13.1490478515625, + -13.669795989990234, + -14.073375701904297, + -14.516012191772461, + -14.555068969726562, + -15.622602462768555, + -15.635622024536133], + [-0.44747352600097656, + -1.0202960968017578, + -8.467000961303711, + -10.914518356323242, + -11.25300407409668, + -11.435266494750977, + -12.346576690673828, + -13.075624465942383, + -13.12769889831543, + -13.231849670410156], + [-3.123767137527466, + -1.1188862323760986, + -1.639634370803833, + -2.0562336444854736, + -2.8633930683135986, + -2.9675419330596924, + -3.4882919788360596, + -3.69659161567688, + -4.217339515686035, + -4.243376731872559], + [-7.199982064776123e-05, + -9.76410961151123, + -11.144091606140137, + -16.507802963256836, + -17.132701873779297, + -17.44515037536621, + -17.9138240814209, + -18.33042335510254, + -18.9162654876709, + -19.39795684814453], + [0.0, + -22.991050720214844, + -23.824249267578125, + -24.969894409179688, + -25.46460723876953, + -25.829130172729492, + -26.480066299438477, + -26.909683227539062, + -27.33930206298828, + -27.391376495361328], + [-0.21928852796554565, + -1.625309705734253, + -9.775025367736816, + -12.977627754211426, + -16.388530731201172, + -17.091541290283203, + -19.044347763061523, + -19.38283348083496, + -19.460947036743164, + -19.59113311767578], + [0.0, + -24.006507873535156, + -27.443450927734375, + -27.729862213134766, + -28.12042236328125, + -28.276647567749023, + -28.927583694458008, + -30.099267959594727, + -31.479251861572266, + -32.07810974121094], + [0.0, + -18.17412567138672, + -18.772987365722656, + -21.689178466796875, + -21.92351531982422, + -23.7200984954834, + -23.79821014404297, + -23.79821014404297, + -24.032546997070312, + -25.308382034301758], + [-0.12947827577590942, + -2.1083219051361084, + -12.419143676757812, + -15.23118782043457, + -15.595710754394531, + -15.830047607421875, + -17.001731872558594, + -17.60059356689453, + -18.121341705322266, + -18.251529693603516], + [0.0, + -19.449962615966797, + -24.371034622192383, + -24.917821884155273, + -25.529701232910156, + -25.85516929626465, + -26.037429809570312, + -26.115543365478516, + -26.623271942138672, + -26.649309158325195], + [-0.03332124650478363, + -3.4181859493255615, + -15.759925842285156, + -15.812002182006836, + -16.593124389648438, + -17.894996643066406, + -18.09027671813965, + -18.79328727722168, + -19.144792556762695, + -20.147233963012695], + [0.0, + -21.142393112182617, + -22.157852172851562, + -23.511798858642578, + -24.657445907592773, + -25.021968841552734, + -25.5427188873291, + -25.59479331970215, + -25.75101661682129, + -25.95931625366211], + [0.0, + -23.04312515258789, + -24.94385528564453, + -26.323841094970703, + -27.54759979248047, + -28.563060760498047, + -29.786819458007812, + -30.620018005371094, + -30.69812774658203, + -31.08869171142578], + [0.0, + -26.167617797851562, + -28.771360397338867, + -29.55248260498047, + -30.906429290771484, + -31.114728927612305, + -31.414159774780273, + -31.622459411621094, + -31.713590621948242, + -31.726608276367188], + [-0.05012698099017143, + -3.018392562866211, + -11.740934371948242, + -13.146955490112305, + -13.797887802124023, + -14.943536758422852, + -16.037107467651367, + -16.375595092773438, + -16.714080810546875, + -17.36501693725586], + [-0.9704352021217346, + -0.7360983490943909, + -2.1941938400268555, + -4.225115776062012, + -5.0062360763549805, + -5.2666120529174805, + -5.839434623718262, + -7.2714948654174805, + -8.33902645111084, + -8.495253562927246], + [-0.014467108063399792, + -4.258565902709961, + -8.789079666137695, + -10.429437637329102, + -10.793962478637695, + -11.835458755493164, + -11.939607620239258, + -13.31959342956543, + -13.866378784179688, + -15.038063049316406], + [0.0, + -20.08787727355957, + -21.350692749023438, + -21.415786743164062, + -21.50691795349121, + -21.50691795349121, + -22.7176570892334, + -24.13669776916504, + -24.188772201538086, + -24.34499740600586]] +}, +{ + "case": "fail_case", + "expect" : 0, + "tokens" : ["", + "\n", + "{'", + "name", + "':", + " '", + "get", + "_current", + "_weather", + "',", + " '", + "arguments", + "':", + " {'", + "location", + "':", + " '", + "Seattle", + ",", + " WA", + "',", + " '", + "unit", + "':", + " '", + "c", + "elsius", + "',", + " '", + "days", + "':", + " '", + "7", + "'}}\n", + ""], + "logprobs":[[-0.00013815402053296566, + -9.113236427307129, + -10.571331977844238, + -14.099404335021973, + -14.28166675567627, + -15.583537101745605, + -15.81787395477295, + -16.143341064453125, + -16.143341064453125, + -16.260509490966797], + [0.0, + -26.896663665771484, + -27.32628059387207, + -27.41741180419922, + -32.07810974121094, + -32.07810974121094, + -32.28641128540039, + -32.29943084716797, + -32.44263458251953, + -32.520748138427734], + [0.0, + -22.444263458251953, + -24.527257919311523, + -27.15703773498535, + -28.016273498535156, + -28.2506103515625, + -28.693246841430664, + -29.070789337158203, + -29.565500259399414, + -29.812854766845703], + [0.0, + -27.860050201416016, + -28.641170501708984, + -29.448333740234375, + -30.932466506958008, + -31.63547706604004, + -32.33848571777344, + -32.85923767089844, + -33.17168426513672, + -33.45809555053711], + [0.0, + -31.81774139404297, + -31.895854949951172, + -32.05207824707031, + -35.43694305419922, + -36.3482551574707, + -38.61351013183594, + -39.26444625854492, + -40.61839294433594, + -41.71196365356445], + [0.0, + -27.33930206298828, + -27.834014892578125, + -28.849472045898438, + -30.567943572998047, + -32.98942565917969, + -33.067535400390625, + -33.067535400390625, + -35.67127990722656, + -35.69731903076172], + [0.0, + -25.33441925048828, + -26.063465118408203, + -26.219690322875977, + -26.2457275390625, + -26.53213882446289, + -27.365337371826172, + -28.354759216308594, + -28.667207717895508, + -28.74532127380371], + [0.0, + -24.423107147216797, + -24.579330444335938, + -26.81855010986328, + -28.12042236328125, + -28.32872200012207, + -28.61513328552246, + -29.16191864013672, + -29.187957763671875, + -29.240032196044922], + [0.0, + -22.027664184570312, + -23.850284576416016, + -23.980472564697266, + -24.292922973632812, + -24.787633895874023, + -29.279088973999023, + -29.55248260498047, + -29.903987884521484, + -30.190399169921875], + [0.0, + -31.609439849853516, + -31.817739486694336, + -32.54678726196289, + -32.676971435546875, + -32.781124114990234, + -32.98942565917969, + -33.106590270996094, + -33.57526397705078, + -34.369407653808594], + [0.0, + -29.34418296813965, + -29.63059425354004, + -30.021156311035156, + -30.984540939331055, + -33.21073913574219, + -34.30431365966797, + -34.56468963623047, + -34.70789337158203, + -34.79902648925781], + [0.0, + -25.438566207885742, + -25.69894027709961, + -30.190397262573242, + -30.802276611328125, + -31.58340072631836, + -31.609437942504883, + -31.64849281311035, + -31.973960876464844, + -32.29943084716797], + [0.0, + -27.157039642333984, + -32.104148864746094, + -32.33848571777344, + -34.04393768310547, + -34.12205505371094, + -34.40846252441406, + -34.42148208618164, + -34.772987365722656, + -34.87713623046875], + [0.0, + -24.813671112060547, + -26.974777221679688, + -31.010578155517578, + -31.08869171142578, + -32.1822624206543, + -35.33279037475586, + -35.489013671875, + -36.999183654785156, + -37.88446044921875], + [0.0, + -20.46541976928711, + -20.647682189941406, + -23.069164276123047, + -24.136699676513672, + -25.438570022583008, + -25.646869659423828, + -26.193655014038086, + -26.297805786132812, + -26.506103515625], + [0.0, + -27.18307113647461, + -28.30268096923828, + -28.56305694580078, + -29.526439666748047, + -32.416595458984375, + -35.202598571777344, + -36.426361083984375, + -39.31651306152344, + -39.38160705566406], + [0.0, + -18.7469482421875, + -20.100894927978516, + -21.402767181396484, + -21.428804397583008, + -22.20992660522461, + -22.34011459350586, + -22.730674743652344, + -23.069162368774414, + -23.980472564697266], + [-3.576278118089249e-07, + -15.2579345703125, + -16.481693267822266, + -17.991863250732422, + -19.215621948242188, + -20.25712013244629, + -21.350692749023438, + -22.314077377319336, + -22.496337890625, + -22.938974380493164], + [-0.08506780862808228, + -2.506549835205078, + -14.848289489746094, + -15.473188400268555, + -16.33242416381836, + -16.358461380004883, + -16.566761016845703, + -17.03543472290039, + -17.686370849609375, + -17.816556930541992], + [-0.0194891095161438, + -4.445854187011719, + -5.591499328613281, + -5.956024169921875, + -6.685070037841797, + -13.142353057861328, + -13.558952331542969, + -15.173273086547852, + -15.303461074829102, + -15.85024642944336], + [-0.0005990855861455202, + -7.4212646484375, + -15.675132751464844, + -15.72720718383789, + -16.76870346069336, + -16.76870346069336, + -17.706050872802734, + -18.669435501098633, + -19.398483276367188, + -19.658857345581055], + [0.0, + -24.110658645629883, + -25.829130172729492, + -26.011390686035156, + -26.011390686035156, + -26.532140731811523, + -26.58421516418457, + -27.651750564575195, + -27.75589942932129, + -28.055330276489258], + [-1.1408883333206177, + -0.38580334186553955, + -7.494022369384766, + -12.519245147705078, + -14.576202392578125, + -16.034297943115234, + -16.945608139038086, + -17.908992767333984, + -18.664077758789062, + -19.34105110168457], + [0.0, + -26.688365936279297, + -29.83889389038086, + -30.177383422851562, + -30.64605712890625, + -31.244916915893555, + -31.270954132080078, + -32.83319854736328, + -34.655818939208984, + -34.89015579223633], + [0.0, + -18.929210662841797, + -19.16354751586914, + -23.589908599853516, + -24.683481216430664, + -24.995929718017578, + -25.516677856445312, + -25.542715072631836, + -25.77705192565918, + -26.063465118408203], + [-0.2519786059856415, + -1.5017764568328857, + -12.437495231628418, + -15.457839012145996, + -15.744250297546387, + -16.837820053100586, + -17.41064453125, + -17.56686782836914, + -17.61894416809082, + -18.035541534423828], + [0.0, + -20.517494201660156, + -24.683483123779297, + -25.67290496826172, + -26.58421516418457, + -27.651750564575195, + -27.781936645507812, + -27.912124633789062, + -28.09438705444336, + -28.445892333984375], + [-3.40932747349143e-05, + -10.284820556640625, + -18.252273559570312, + -20.17904281616211, + -21.663175582885742, + -22.027700424194336, + -22.288074493408203, + -22.704673767089844, + -23.12127113342285, + -23.277496337890625], + [0.0, + -22.60049057006836, + -25.46460723876953, + -25.829130172729492, + -26.063467025756836, + -27.287227630615234, + -27.391376495361328, + -27.4694881439209, + -27.67778778076172, + -28.055330276489258], + [0.0, + -23.902362823486328, + -28.823436737060547, + -29.240036010742188, + -29.31814956665039, + -29.917007446289062, + -30.021160125732422, + -31.21887969970703, + -32.416603088378906, + -32.416603088378906], + [0.0, + -28.641170501708984, + -31.947925567626953, + -32.59886169433594, + -33.848655700683594, + -34.109031677246094, + -34.73393249511719, + -35.02033996582031, + -35.02033996582031, + -36.074859619140625], + [-0.013183215633034706, + -4.335395336151123, + -19.619365692138672, + -20.035964965820312, + -20.244266510009766, + -21.311800003051758, + -21.441987991333008, + -22.561595916748047, + -23.108383178710938, + -23.264606475830078], + [-8.344646857949556e-07, + -14.190400123596191, + -15.9088716506958, + -18.17412567138672, + -18.46053695678711, + -18.46053695678711, + -18.512611389160156, + -18.90317153930664, + -19.059398651123047, + -19.085433959960938], + [0.0, + -17.70545196533203, + -18.903175354003906, + -20.829944610595703, + -22.574451446533203, + -22.860862731933594, + -23.069162368774414, + -23.32953643798828, + -23.694061279296875, + -24.188772201538086], + [0.0, + -20.022781372070312, + -21.038240432739258, + -21.220502853393555, + -22.496337890625, + -22.769729614257812, + -23.589908599853516, + -23.65500259399414, + -23.94141387939453, + -24.266881942749023]] +} +] diff --git a/model_server/app/tests/test_hallucination.py b/model_server/app/tests/test_hallucination.py new file mode 100644 index 00000000..8b6c387e --- /dev/null +++ b/model_server/app/tests/test_hallucination.py @@ -0,0 +1,148 @@ +import json +from app.function_calling.hallucination_handler import HallucinationStateHandler +import pytest +import os + +# Get the directory of the current file +current_dir = os.path.dirname(__file__) + +# Construct the full path to the JSON file +json_file_path = os.path.join(current_dir, "test_cases.json") + +with open(json_file_path) as f: + test_cases = json.load(f) + +get_weather_api = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather at a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "str", + "description": "The location to get the weather for", + "format": "City, State", + }, + "unit": { + "type": "str", + "description": "The unit to return the weather in.", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + }, + "days": { + "type": "str", + "description": "the number of days for the request.", + }, + }, + "required": ["location", "days"], + }, + }, +} +function_description = get_weather_api["function"] +if type(function_description) != list: + function_description = [get_weather_api["function"]] + + +@pytest.mark.parametrize("case", test_cases) +def test_hallucination(case): + state = HallucinationStateHandler( + response_iterator=None, function=function_description + ) + for token, logprob in zip(case["tokens"], case["logprobs"]): + if token != "": + state.append_and_check_token_hallucination(token, logprob) + if state.hallucination: + break + assert state.hallucination == case["expect"] + + +@pytest.mark.parametrize("is_hallucinate_sample", [True, False]) +def test_hallucination_prompt(is_hallucinate_sample): + TASK_PROMPT = """ + You are a helpful assistant. + """.strip() + + TOOL_PROMPT = """ + # Tools + + You may call one or more functions to assist with the user query. + + You are provided with function signatures within XML tags: + + {tool_text} + + """.strip() + + FORMAT_PROMPT = """ + For each function call, return a json object with function name and arguments within XML tags: + + {"name": , "arguments": } + + """.strip() + + def convert_tools(tools): + return "\n".join([json.dumps(tool) for tool in tools]) + + def format_prompt(tools): + tool_text = convert_tools(tools) + + return ( + TASK_PROMPT + + "\n\n" + + TOOL_PROMPT.format(tool_text=tool_text) + + "\n\n" + + FORMAT_PROMPT + + "\n" + ) + + openai_format_tools = [get_weather_api] + + system_prompt = format_prompt(openai_format_tools) + + from openai import OpenAI + + client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY") + + # List models API + model = client.models.list().data[0].id + assert model == "Arch-Function" + if not is_hallucinate_sample: + messages = [ + {"role": "system", "content": system_prompt}, + # {"role": "user", "content": "can you help me check weather?"}, + {"role": "user", "content": "How is the weather in Seattle in 7 days?"}, + # {"role": "assistant", "content": "Of course!"}, + # {"role": "user", "content": "Seattle please"} + ] + else: + messages = [ + {"role": "system", "content": system_prompt}, + # {"role": "user", "content": "can you help me check weather?"}, + {"role": "user", "content": "How is the weather in Seattle in days?"}, + # {"role": "assistant", "content": "Of course!"}, + # {"role": "user", "content": "Seattle please"} + ] + + extra_body = { + "temperature": 0.6, + "top_p": 1.0, + "top_k": 50, + # "continue_final_message": True, + # "add_generation_prompt": False, + "logprobs": True, + "top_logprobs": 10, + } + + resp = client.chat.completions.create( + model="Arch-Function", messages=messages, extra_body=extra_body, stream=True + ) + + hallu = HallucinationStateHandler( + response_iterator=resp, function=function_description + ) + + for token in hallu: + assert len(hallu.tokens) >= 0 + assert hallu.hallucination == is_hallucinate_sample