mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix error in function name + new thresholds
This commit is contained in:
parent
2bd61d628c
commit
151a3f2aaa
1 changed files with 26 additions and 19 deletions
|
|
@ -43,8 +43,8 @@ HALLUCINATION_THRESHOLD_DICT = {
|
|||
"probability": 0.8,
|
||||
},
|
||||
MaskToken.PARAMETER_VALUE.value: {
|
||||
"entropy": 0.28,
|
||||
"varentropy": 1.2,
|
||||
"entropy": 0.35,
|
||||
"varentropy": 1.7,
|
||||
"probability": 0.8,
|
||||
},
|
||||
}
|
||||
|
|
@ -62,7 +62,7 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
|||
Returns:
|
||||
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
|
||||
"""
|
||||
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
|
||||
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
|
||||
|
||||
|
||||
def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
|
||||
|
|
@ -82,7 +82,7 @@ def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
|
|||
token_probs = torch.exp(log_probs)
|
||||
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
|
||||
varentropy = torch.sum(
|
||||
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
|
||||
token_probs * (log_probs / math.log(2, math.e) + entropy.unsqueeze(-1)) ** 2,
|
||||
dim=-1,
|
||||
)
|
||||
return entropy.item(), varentropy.item(), token_probs[0].item()
|
||||
|
|
@ -303,22 +303,29 @@ class HallucinationState:
|
|||
self.mask.append(MaskToken.PARAMETER_VALUE)
|
||||
|
||||
# checking if the parameter doesn't have enum and the token is the first parameter value token
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||
and is_parameter_required(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
# check if function name is in function properties
|
||||
if self.function_name in self.function_properties:
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||
and is_parameter_required(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
)
|
||||
and not is_parameter_property(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
"enum",
|
||||
)
|
||||
):
|
||||
if self.parameter_name[-1] not in self.check_parameter_name:
|
||||
self._check_logprob()
|
||||
self.check_parameter_name[self.parameter_name[-1]] = True
|
||||
else:
|
||||
self.error_message = f"Function name {self.function_name} not found in function properties"
|
||||
logger.warning(
|
||||
f"Function name {self.function_name} not found in function properties"
|
||||
)
|
||||
and not is_parameter_property(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
"enum",
|
||||
)
|
||||
):
|
||||
if self.parameter_name[-1] not in self.check_parameter_name:
|
||||
self._check_logprob()
|
||||
self.check_parameter_name[self.parameter_name[-1]] = True
|
||||
else:
|
||||
self.mask.append(MaskToken.NOT_USED)
|
||||
# if the state is parameter value and the token is an end token, change the state
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue