Fix error when function name is not found in function properties + new parameter-value thresholds

This commit is contained in:
cotran 2025-02-12 10:51:51 -08:00
parent 2bd61d628c
commit 151a3f2aaa

View file

@ -43,8 +43,8 @@ HALLUCINATION_THRESHOLD_DICT = {
"probability": 0.8,
},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.28,
"varentropy": 1.2,
"entropy": 0.35,
"varentropy": 1.7,
"probability": 0.8,
},
}
@ -62,7 +62,7 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
Returns:
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
"""
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
@ -82,7 +82,7 @@ def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
token_probs = torch.exp(log_probs)
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
varentropy = torch.sum(
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
token_probs * (log_probs / math.log(2, math.e) + entropy.unsqueeze(-1)) ** 2,
dim=-1,
)
return entropy.item(), varentropy.item(), token_probs[0].item()
@ -303,22 +303,29 @@ class HallucinationState:
self.mask.append(MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have enum and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and is_parameter_required(
self.function_properties[self.function_name],
self.parameter_name[-1],
# check if function name is in function properties
if self.function_name in self.function_properties:
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and is_parameter_required(
self.function_properties[self.function_name],
self.parameter_name[-1],
)
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"enum",
)
):
if self.parameter_name[-1] not in self.check_parameter_name:
self._check_logprob()
self.check_parameter_name[self.parameter_name[-1]] = True
else:
self.error_message = f"Function name {self.function_name} not found in function properties"
logger.warning(
f"Function name {self.function_name} not found in function properties"
)
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"enum",
)
):
if self.parameter_name[-1] not in self.check_parameter_name:
self._check_logprob()
self.check_parameter_name[self.parameter_name[-1]] = True
else:
self.mask.append(MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state