diff --git a/model_server/src/core/utils/hallucination_utils.py b/model_server/src/core/utils/hallucination_utils.py index 50006224..91effc92 100644 --- a/model_server/src/core/utils/hallucination_utils.py +++ b/model_server/src/core/utils/hallucination_utils.py @@ -44,7 +44,7 @@ HALLUCINATION_THRESHOLD_DICT = { }, MaskToken.PARAMETER_VALUE.value: { "entropy": 0.28, - "varentropy": 1.2, + "varentropy": 1.4, "probability": 0.8, }, } @@ -60,7 +60,7 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool: thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'. Returns: - bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise. + bool: True if both the entropy and varentropy exceed their respective thresholds, False otherwise. """ return entropy > thd["entropy"] and varentropy > thd["varentropy"] @@ -82,7 +82,7 @@ def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]: token_probs = torch.exp(log_probs) entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e) varentropy = torch.sum( - token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2, + token_probs * (log_probs / math.log(2, math.e) + entropy.unsqueeze(-1)) ** 2, dim=-1, ) return entropy.item(), varentropy.item(), token_probs[0].item() @@ -303,22 +303,30 @@ class HallucinationState: self.mask.append(MaskToken.PARAMETER_VALUE) # checking if the parameter doesn't have enum and the token is the first parameter value token - if ( - len(self.mask) > 1 - and self.mask[-2] != MaskToken.PARAMETER_VALUE - and is_parameter_required( - self.function_properties[self.function_name], - self.parameter_name[-1], + # check if function name is in function properties + if self.function_name in self.function_properties: + if ( + len(self.mask) > 1 + and self.mask[-2] != MaskToken.PARAMETER_VALUE + and is_parameter_required( + self.function_properties[self.function_name], + 
self.parameter_name[-1], + ) + and not is_parameter_property( + self.function_properties[self.function_name], + self.parameter_name[-1], + "enum", + ) + ): + if self.parameter_name[-1] not in self.check_parameter_name: + self._check_logprob() + self.check_parameter_name[self.parameter_name[-1]] = True + else: + self._check_logprob() + self.error_message = f"Function name {self.function_name} not found in function properties" + logger.warning( + f"Function name {self.function_name} not found in function properties" ) - and not is_parameter_property( - self.function_properties[self.function_name], - self.parameter_name[-1], - "enum", - ) - ): - if self.parameter_name[-1] not in self.check_parameter_name: - self._check_logprob() - self.check_parameter_name[self.parameter_name[-1]] = True else: self.mask.append(MaskToken.NOT_USED) # if the state is parameter value and the token is an end token, change the state diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py index 3f7971d8..0f2c9995 100644 --- a/model_server/tests/core/test_function_calling.py +++ b/model_server/tests/core/test_function_calling.py @@ -54,20 +54,6 @@ def get_hallucination_data_complex(): return req, True, True, True -def get_hallucination_data_easy(): - # Create instances of the Message class - message1 = Message(role="user", content="How is the weather in Seattle?") - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1], tools=tools) - - # model will hallucinate - return req, True, True, True - - def get_hallucination_data_medium(): # Create instances of the Message class message1 = Message(role="user", content="How is the weather in?") @@ -142,7 +128,6 @@ def get_greeting_data(): "get_data_func", [ get_hallucination_data_complex, - get_hallucination_data_easy, get_complete_data, get_irrelevant_data, get_complete_data_2, diff --git 
a/tests/modelserver/test_hallucination_data.yaml b/tests/modelserver/test_hallucination_data.yaml index 8e6e6b26..935a8f5f 100644 --- a/tests/modelserver/test_hallucination_data.yaml +++ b/tests/modelserver/test_hallucination_data.yaml @@ -63,7 +63,7 @@ test_cases: - role: "assistant" content: "Can you please provide me the days for the weather forecast?" - role: "user" - content: "los angeles in 5 days" + content: "5 days" tools: - type: "function" function: @@ -82,7 +82,7 @@ test_cases: required: ["location", "days"] expected: - type: "metadata" - hallucination: true + hallucination: false - id: "[WEATHER AGENT] - multi turn, single tool, clarification" input: