From 151a3f2aaa7c1b32501138d76f30a98f1133efc6 Mon Sep 17 00:00:00 2001 From: cotran Date: Wed, 12 Feb 2025 10:51:51 -0800 Subject: [PATCH] fix error in function name + new thresholds --- .../src/core/utils/hallucination_utils.py | 45 +++++++++++-------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/model_server/src/core/utils/hallucination_utils.py b/model_server/src/core/utils/hallucination_utils.py index 50006224..55de4e9a 100644 --- a/model_server/src/core/utils/hallucination_utils.py +++ b/model_server/src/core/utils/hallucination_utils.py @@ -43,8 +43,8 @@ HALLUCINATION_THRESHOLD_DICT = { "probability": 0.8, }, MaskToken.PARAMETER_VALUE.value: { - "entropy": 0.28, - "varentropy": 1.2, + "entropy": 0.35, + "varentropy": 1.7, "probability": 0.8, }, } @@ -62,7 +62,7 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: Returns: bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise. """ - return entropy > thd["entropy"] and varentropy > thd["varentropy"] + return entropy > thd["entropy"] or varentropy > thd["varentropy"] def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]: @@ -82,7 +82,7 @@ def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]: token_probs = torch.exp(log_probs) entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e) varentropy = torch.sum( - token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2, + token_probs * (log_probs / math.log(2, math.e) + entropy.unsqueeze(-1)) ** 2, dim=-1, ) return entropy.item(), varentropy.item(), token_probs[0].item() @@ -303,22 +303,29 @@ class HallucinationState: self.mask.append(MaskToken.PARAMETER_VALUE) # checking if the parameter doesn't have enum and the token is the first parameter value token - if ( - len(self.mask) > 1 - and self.mask[-2] != MaskToken.PARAMETER_VALUE - and is_parameter_required( - self.function_properties[self.function_name], - self.parameter_name[-1], + # check if function name is in function properties + if self.function_name in self.function_properties: + if ( + len(self.mask) > 1 + and self.mask[-2] != MaskToken.PARAMETER_VALUE + and is_parameter_required( + self.function_properties[self.function_name], + self.parameter_name[-1], + ) + and not is_parameter_property( + self.function_properties[self.function_name], + self.parameter_name[-1], + "enum", + ) + ): + if self.parameter_name[-1] not in self.check_parameter_name: + self._check_logprob() + self.check_parameter_name[self.parameter_name[-1]] = True + else: + self.error_message = f"Function name {self.function_name} not found in function properties" + logger.warning( + f"Function name {self.function_name} not found in function properties" ) - and not is_parameter_property( - self.function_properties[self.function_name], - self.parameter_name[-1], - "enum", - ) - ): - if self.parameter_name[-1] not in self.check_parameter_name: - self._check_logprob() - self.check_parameter_name[self.parameter_name[-1]] = True else: self.mask.append(MaskToken.NOT_USED) # if the state is parameter value and the token is an end token, change the state