mirror of
https://github.com/katanemo/plano.git
synced 2026-06-02 14:35:14 +02:00
hallucination with log probs (#281)
* first init * fix * fix test * new implementation * fix bug * fix bug * fix bug * address issue * address issues * address comments * fix test * fix * move constants * remove consts
This commit is contained in:
parent
f5cdafb7c8
commit
cadd3cdaf9
5 changed files with 1269 additions and 1 deletions
|
|
@ -19,6 +19,7 @@ arch_function_generation_params = {
|
|||
"top_k": 50,
|
||||
"max_tokens": 512,
|
||||
"stop_token_ids": [151645],
|
||||
# "top_logprobs": 10,
|
||||
}
|
||||
|
||||
arch_guard_model_type = {
|
||||
|
|
@ -34,3 +35,4 @@ zero_shot_model = loader.get_zero_shot_model()
|
|||
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||
# Patterns for function name and parameter parsing
|
||||
|
|
|
|||
324
model_server/app/function_calling/hallucination_handler.py
Normal file
324
model_server/app/function_calling/hallucination_handler.py
Normal file
|
|
@ -0,0 +1,324 @@
|
|||
import json
|
||||
import math
|
||||
import torch
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import itertools
|
||||
from enum import Enum
|
||||
|
||||
# constants
# Each tuple lists the double-quoted and single-quoted spellings of a marker.
# The *_PATTERN tuples (and the parameter END tuples) are matched with
# str.endswith against the accumulated output after spaces have been removed;
# FUNC_NAME_END_TOKEN is compared against the raw last token instead.
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"

FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
|
||||
|
||||
|
||||
# Thresholds
|
||||
class MaskToken(Enum):
    """
    Label assigned to each generated token while a tool call is parsed.

    The single-character values are also used as keys of the threshold
    table consulted when a token's log probabilities are checked.
    """

    FUNCTION_NAME = "f"  # token belongs to the function name
    PARAMETER_VALUE = "v"  # token belongs to a parameter value
    PARAMETER_NAME = "p"  # token belongs to a parameter name
    NOT_USED = "e"  # structural / filler token that is not inspected
    TOOL_CALL = "t"  # the "<tool_call>" marker token itself
|
||||
|
||||
|
||||
# Upper bounds on entropy / varentropy, keyed by MaskToken value. Only the
# token kinds listed here can be looked up when log probabilities are checked.
HALLUCINATION_THRESHOLD_DICT = {
    MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
    MaskToken.PARAMETER_VALUE.value: {
        "entropy": 0.5,
        "varentropy": 2.5,
    },
}
|
||||
|
||||
|
||||
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
    """
    Check if the given entropy or variance of entropy exceeds the specified thresholds.

    Args:
        entropy (float): The entropy value to check.
        varentropy (float): The variance of entropy value to check.
        thd (dict): Threshold values stored under the keys 'entropy' and 'varentropy'.

    Returns:
        bool: True if either value is strictly greater than its threshold, False otherwise.

    Raises:
        KeyError: If `thd` lacks the 'entropy' key, or lacks 'varentropy'
            when the entropy check did not already succeed.
    """
    # Strict comparison: a value exactly equal to its threshold does not trip it.
    return entropy > thd["entropy"] or varentropy > thd["varentropy"]
|
||||
|
||||
|
||||
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
    """
    Calculate the entropy and the "varentropy" score from log probabilities.

    Args:
        log_probs (list of float): Log probabilities of the candidate tokens.
            They are exponentiated with torch.exp, so they are treated as
            natural logarithms.

    Returns:
        tuple: A pair of Python floats:
            - entropy (float): -sum(p * log p) divided by ln 2.
            - varentropy (float): the score described in the note below.
    """
    ln2 = math.log(2, math.e)
    logp = torch.tensor(log_probs)
    token_probs = torch.exp(logp)
    entropy = -torch.sum(logp * token_probs, dim=-1) / ln2
    # NOTE(review): as written this sums `p * log2(p) + entropy**2` over the
    # candidates (entropy**2 is broadcast and added once per candidate). That is
    # not the probability-weighted variance sum(p * (log2(p) + entropy)**2);
    # the parentheses look misplaced. The numeric behaviour is kept unchanged
    # here because the thresholds and recorded test cases presumably were tuned
    # against this formula - confirm before correcting it.
    varentropy = torch.sum(
        token_probs * (logp / ln2) + entropy.unsqueeze(-1) ** 2,
        dim=-1,
    )
    return entropy.item(), varentropy.item()
|
||||
|
||||
|
||||
def is_parameter_property(
    function_description: Dict, parameter_name: str, property_name: str
) -> bool:
    """
    Check if a parameter in an API description has a specific property.

    Args:
        function_description (dict): Parameter schema; parameters are read
            from its "properties" entry.
        parameter_name (str): The name of the parameter to check.
        property_name (str): The property to look for (e.g., 'format', 'default').

    Returns:
        bool: True if the parameter has the specified property. False if it
        does not, or if the schema has no "properties" entry or no such parameter.
    """
    # Missing levels fall back to empty dicts so the lookup never raises.
    parameter_info = function_description.get("properties", {}).get(
        parameter_name, {}
    )
    return property_name in parameter_info
|
||||
|
||||
|
||||
class HallucinationStateHandler:
    """
    A class to handle the state of hallucination detection in token processing.

    Tokens are fed in one at a time (directly, or by iterating the handler
    over a streaming response). Each token is labelled with a MaskToken, and
    the "<tool_call>" marker and the first token of parameter values without
    a declared default are checked against entropy thresholds.

    Attributes:
        tokens (list): List of tokens processed.
        logprobs (list): For each token, the list of log probabilities supplied with it.
        state (str): Current parsing state: "function_name", "parameter_name",
            "parameter_value" or None.
        mask (list): One MaskToken per processed token.
        parameter_name_done (bool): Set once the first parameter name has been extracted.
        hallucination (bool): Flag indicating if a threshold check failed.
        error_message (str): Message describing the last detected problem.
        error_type (str): "Hallucination", "function_name" or "parameter_name"; empty if none.
        function_name (str): Extracted function name; empty until one is parsed.
        parameter_name (list): List of extracted parameter names.
        function (list): The function descriptions passed to the constructor.
        function_description (dict): Maps each function name to its parameter names.
        function_properties (dict): Maps each function name to its "parameters" schema.
        token_probs_map (list): (token, entropy, varentropy) for every checked token.
        response_iterator: Optional iterator of streamed response chunks.
    """

    def __init__(self, response_iterator=None, function=None):
        """
        Initializes the HallucinationStateHandler with default values.

        Args:
            response_iterator: Optional iterator of streamed response chunks
                consumed by ``__next__``.
            function (list): Function descriptions, each with a "name" and a
                "parameters" schema containing "properties".

        Raises:
            ValueError: If ``function`` is None.
        """
        self.tokens: List[str] = []
        self.logprobs: List[List[float]] = []
        self.state: str = None
        self.mask: List[MaskToken] = []
        self.parameter_name_done: bool = False
        self.hallucination: bool = False
        self.error_message: str = ""
        self.error_type: str = ""
        # Defined up front so value checks that run before a function name
        # has been parsed do not fail with AttributeError.
        self.function_name: str = ""
        self.parameter_name: List[str] = []
        self.token_probs_map: List[Tuple[str, float, float]] = []
        self.response_iterator = response_iterator
        self._process_function(function)

    def _process_function(self, function):
        """
        Index the function descriptions by function name.

        Raises:
            ValueError: If ``function`` is None.
        """
        self.function = function
        if self.function is None:
            raise ValueError("API descriptions not set.")
        self.function_description = {
            func["name"]: list(func["parameters"]["properties"].keys())
            for func in self.function
        }
        self.function_properties = {
            func["name"]: func["parameters"] for func in self.function
        }

    def append_and_check_token_hallucination(self, token, logprob):
        """
        Record a token with its log probabilities and update the detection state.

        Args:
            token (str): The token to check.
            logprob (list of float): The log probabilities supplied with the token.

        Returns:
            bool: The current hallucination flag (True once any check has failed).
        """
        self.tokens.append(token)
        self.logprobs.append(logprob)
        self._process_token()
        return self.hallucination

    def __iter__(self):
        return self

    def __next__(self):
        """
        Pull the next chunk from the response iterator and process its token.

        Returns:
            The chunk's text content after it has been processed, or None
            when the chunk carries no (non-empty) content.

        Raises:
            StopIteration: When there is no response iterator or it is exhausted.
            ValueError: If the log probabilities cannot be read from the chunk.
        """
        if self.response_iterator is None:
            # Without a source the iteration is empty; previously this
            # returned None forever, which made a `for` loop never finish.
            raise StopIteration
        # StopIteration from the underlying iterator propagates unchanged.
        chunk = next(self.response_iterator)
        token_content = getattr(chunk.choices[0].delta, "content", None)
        if not token_content:
            return None
        try:
            logprobs = [
                p.logprob for p in chunk.choices[0].logprobs.content[0].top_logprobs
            ]
        except Exception as e:
            raise ValueError(f"Error extracting logprobs from response: {e}") from e
        self.append_and_check_token_hallucination(token_content, logprobs)
        return token_content

    def _process_token(self):
        """
        Processes the current token and updates the state and mask accordingly.
        Detects hallucinations based on the token type and log probabilities.
        """
        last_token = self.tokens[-1]
        # Patterns are matched against the whole output with spaces removed.
        content = "".join(self.tokens).replace(" ", "")

        if last_token == TOOL_CALL_TOKEN:
            self.mask.append(MaskToken.TOOL_CALL)
            self._check_logprob()

        # Function name extraction logic
        if self.state == "function_name":
            if last_token not in FUNC_NAME_END_TOKEN:
                self.mask.append(MaskToken.FUNCTION_NAME)
            else:
                # End token reached: the name is complete, validate it.
                self.state = None
                self._is_function_name_hallucinated()

        # Check if the token is a function name start token, change the state
        if content.endswith(FUNC_NAME_START_PATTERN):
            self.state = "function_name"

        # Parameter name extraction logic
        if self.state == "parameter_name":
            if content.endswith(PARAMETER_NAME_END_TOKENS):
                # Name complete: validate it and allow value patterns from now on.
                self.state = None
                self._is_parameter_name_hallucinated()
                self.parameter_name_done = True
            else:
                self.mask.append(MaskToken.PARAMETER_NAME)
        elif self.parameter_name_done and content.endswith(
            PARAMETER_NAME_START_PATTERN
        ):
            self.state = "parameter_name"

        # The first parameter name follows the opening of the arguments object.
        if content.endswith(FIRST_PARAM_NAME_START_PATTERN):
            self.state = "parameter_name"

        # Parameter value extraction logic
        if self.state == "parameter_value":
            if content.endswith(PARAMETER_VALUE_END_TOKEN):
                self.state = None
            elif last_token.strip() not in ['"', ""]:
                self.mask.append(MaskToken.PARAMETER_VALUE)
                # Only the first token of a value is checked, and only when
                # the parameter declares no default. An unknown function
                # yields an empty schema instead of a KeyError.
                if (
                    len(self.mask) > 1
                    and self.mask[-2] != MaskToken.PARAMETER_VALUE
                    and not is_parameter_property(
                        self.function_properties.get(self.function_name, {}),
                        self.parameter_name[-1],
                        "default",
                    )
                ):
                    self._check_logprob()
            else:
                # Bare double quote or whitespace-only token: not a value token.
                self.mask.append(MaskToken.NOT_USED)
        elif self.parameter_name_done and content.endswith(
            PARAMETER_VALUE_START_PATTERN
        ):
            self.state = "parameter_value"

        # Maintain consistency between tokens and mask: a token that received
        # no label above is recorded as not used.
        if len(self.mask) != len(self.tokens):
            self.mask.append(MaskToken.NOT_USED)

    def _check_logprob(self):
        """
        Checks the log probabilities of the current token and updates the token probability map.
        Detects hallucinations based on entropy and variance of entropy.
        """
        entropy, varentropy = calculate_entropy(self.logprobs[-1])
        self.token_probs_map.append((self.tokens[-1], entropy, varentropy))

        # Thresholds are selected by the label just assigned to this token.
        thresholds = HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
        if check_threshold(entropy, varentropy, thresholds):
            self.hallucination = True
            self.error_type = "Hallucination"
            self.error_message = (
                f"Hallucination: token '{self.tokens[-1]}' is uncertain."
            )

    def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
        """
        Counts how many entries at the end of the mask equal the given label.

        Args:
            token (MaskToken): The label to count at the end of the mask.

        Returns:
            int: The length of the trailing run of `token` (0 if the mask is
            empty or does not end with it).
        """
        return sum(
            1 for _ in itertools.takewhile(lambda m: m == token, reversed(self.mask))
        )

    def _is_function_name_hallucinated(self):
        """
        Checks the extracted function name against the function descriptions.
        Records an error type and message if the function name is not found.
        """
        f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
        # The last token is the end marker; the name is the f_len tokens
        # before it. An explicit start index keeps an empty name empty
        # (a `[-0:]` slice would select every token).
        name_end = len(self.tokens) - 1
        self.function_name = "".join(self.tokens[name_end - f_len : name_end])
        if self.function_name not in self.function_description:
            self.error_type = "function_name"
            self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."

    def _is_parameter_name_hallucinated(self):
        """
        Checks the extracted parameter name against the function descriptions.
        Records an error type and message if the parameter name is not found.
        """
        p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
        # Same slicing as for the function name: exclude the end-marker token.
        name_end = len(self.tokens) - 1
        parameter_name = "".join(self.tokens[name_end - p_len : name_end])
        self.parameter_name.append(parameter_name)

        known_parameters = self.function_description.get(self.function_name)
        if known_parameters is None:
            # Unknown function: there is nothing to validate the parameter
            # against, so leave any earlier error information untouched
            # instead of raising KeyError.
            return
        if parameter_name not in known_parameters:
            self.error_type = "parameter_name"
            self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
|
||||
|
|
@ -134,4 +134,4 @@ class ArchFunctionHandler:
|
|||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str
|
||||
return fixed_str.replace("'", '"')
|
||||
|
|
|
|||
794
model_server/app/tests/test_cases.json
Normal file
794
model_server/app/tests/test_cases.json
Normal file
|
|
@ -0,0 +1,794 @@
|
|||
[{
|
||||
"case": "tool_call_halluciation",
|
||||
"tokens" : ["<tool_call>"],
|
||||
"expect": 1,
|
||||
"logprobs": [[-0.3333307206630707,
|
||||
-1.5310522317886353,
|
||||
-3.5098977088928223,
|
||||
-3.9004578590393066,
|
||||
-5.775152683258057,
|
||||
-5.814209461212158,
|
||||
-5.9574151039123535,
|
||||
-6.0094895362854,
|
||||
-6.0094895362854,
|
||||
-6.673445224761963]]
|
||||
},
|
||||
{
|
||||
"case" : "parameter_value_hallucination",
|
||||
"expect" : 0,
|
||||
"tokens" : ["<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Sea",
|
||||
",",
|
||||
" Australia",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"1",
|
||||
"'}}\n",
|
||||
"</tool_call>"],
|
||||
"logprobs": [[-0.008103232830762863,
|
||||
-5.085402488708496,
|
||||
-6.777836799621582,
|
||||
-7.558959007263184,
|
||||
-9.850253105163574,
|
||||
-10.266852378845215,
|
||||
-10.540244102478027,
|
||||
-10.722506523132324,
|
||||
-10.800618171691895,
|
||||
-10.917786598205566],
|
||||
[0.0,
|
||||
-23.25142478942871,
|
||||
-25.139137268066406,
|
||||
-26.2847843170166,
|
||||
-28.992677688598633,
|
||||
-29.070789337158203,
|
||||
-29.55248260498047,
|
||||
-29.91700553894043,
|
||||
-30.20341682434082,
|
||||
-30.307567596435547],
|
||||
[0.0,
|
||||
-21.66313934326172,
|
||||
-23.06916046142578,
|
||||
-23.32953453063965,
|
||||
-25.65988540649414,
|
||||
-25.985353469848633,
|
||||
-26.519121170043945,
|
||||
-27.07892417907715,
|
||||
-27.977216720581055,
|
||||
-28.458908081054688],
|
||||
[0.0,
|
||||
-28.094383239746094,
|
||||
-28.56305694580078,
|
||||
-29.109844207763672,
|
||||
-29.44832992553711,
|
||||
-31.79170036315918,
|
||||
-32.0,
|
||||
-32.05207443237305,
|
||||
-32.31244659423828,
|
||||
-32.364524841308594],
|
||||
[0.0,
|
||||
-30.489830017089844,
|
||||
-31.140766143798828,
|
||||
-31.81774139404297,
|
||||
-34.525634765625,
|
||||
-35.8275032043457,
|
||||
-36.504478454589844,
|
||||
-39.05614471435547,
|
||||
-40.123680114746094,
|
||||
-40.696502685546875],
|
||||
[0.0,
|
||||
-25.646865844726562,
|
||||
-26.66232681274414,
|
||||
-27.781936645507812,
|
||||
-28.979660034179688,
|
||||
-31.140764236450195,
|
||||
-31.92188835144043,
|
||||
-31.973962783813477,
|
||||
-33.04149627685547,
|
||||
-33.58828353881836],
|
||||
[0.0,
|
||||
-23.511798858642578,
|
||||
-24.136695861816406,
|
||||
-25.230268478393555,
|
||||
-25.777053833007812,
|
||||
-25.80309295654297,
|
||||
-26.45402717590332,
|
||||
-26.636289596557617,
|
||||
-26.740440368652344,
|
||||
-26.896663665771484],
|
||||
[0.0,
|
||||
-22.366153717041016,
|
||||
-24.683483123779297,
|
||||
-26.610252380371094,
|
||||
-26.610252380371094,
|
||||
-27.313264846801758,
|
||||
-27.67778778076172,
|
||||
-28.510986328125,
|
||||
-28.615135192871094,
|
||||
-29.13588523864746],
|
||||
[0.0,
|
||||
-22.52237319946289,
|
||||
-24.292919158935547,
|
||||
-24.344993591308594,
|
||||
-24.39706802368164,
|
||||
-24.73555564880371,
|
||||
-29.943042755126953,
|
||||
-29.969079971313477,
|
||||
-30.021154403686523,
|
||||
-30.0341739654541],
|
||||
[0.0,
|
||||
-30.17738151550293,
|
||||
-30.411718368530273,
|
||||
-30.88039207458496,
|
||||
-30.984540939331055,
|
||||
-31.270952224731445,
|
||||
-31.895851135253906,
|
||||
-32.46867370605469,
|
||||
-32.624900817871094,
|
||||
-33.484134674072266],
|
||||
[0.0,
|
||||
-28.146459579467773,
|
||||
-29.396255493164062,
|
||||
-30.099267959594727,
|
||||
-31.127744674682617,
|
||||
-31.179821014404297,
|
||||
-32.807159423828125,
|
||||
-33.7445068359375,
|
||||
-33.770545959472656,
|
||||
-34.069976806640625],
|
||||
[0.0,
|
||||
-26.323841094970703,
|
||||
-26.558177947998047,
|
||||
-30.515867233276367,
|
||||
-30.932466506958008,
|
||||
-31.37510108947754,
|
||||
-31.531326293945312,
|
||||
-31.70056915283203,
|
||||
-32.065093994140625,
|
||||
-32.364524841308594],
|
||||
[0.0,
|
||||
-26.922698974609375,
|
||||
-30.28152847290039,
|
||||
-31.505287170410156,
|
||||
-33.30187225341797,
|
||||
-33.73148727416992,
|
||||
-34.27827453613281,
|
||||
-34.33034896850586,
|
||||
-34.460533142089844,
|
||||
-34.720909118652344],
|
||||
[0.0,
|
||||
-21.532955169677734,
|
||||
-26.94873809814453,
|
||||
-29.109848022460938,
|
||||
-30.80228042602539,
|
||||
-31.55736541748047,
|
||||
-33.484134674072266,
|
||||
-34.681854248046875,
|
||||
-35.384864807128906,
|
||||
-35.853538513183594],
|
||||
[0.0,
|
||||
-19.502033233642578,
|
||||
-20.46541976928711,
|
||||
-24.110658645629883,
|
||||
-24.501218795776367,
|
||||
-25.256305694580078,
|
||||
-25.82912826538086,
|
||||
-25.881202697753906,
|
||||
-26.063465118408203,
|
||||
-26.063465118408203],
|
||||
[0.0,
|
||||
-24.37103271484375,
|
||||
-25.256305694580078,
|
||||
-25.933277130126953,
|
||||
-26.714401245117188,
|
||||
-28.2506103515625,
|
||||
-31.010576248168945,
|
||||
-32.07810974121094,
|
||||
-34.62977981567383,
|
||||
-35.241661071777344],
|
||||
[-1.1920922133867862e-06,
|
||||
-14.398697853088379,
|
||||
-14.424736976623535,
|
||||
-17.158666610717773,
|
||||
-17.41904067993164,
|
||||
-18.200162887573242,
|
||||
-18.434499740600586,
|
||||
-18.66883659362793,
|
||||
-19.71033477783203,
|
||||
-19.71033477783203],
|
||||
[-0.0001445904199499637,
|
||||
-8.98305892944336,
|
||||
-11.35246467590332,
|
||||
-13.1490478515625,
|
||||
-13.669795989990234,
|
||||
-14.073375701904297,
|
||||
-14.516012191772461,
|
||||
-14.555068969726562,
|
||||
-15.622602462768555,
|
||||
-15.635622024536133],
|
||||
[-0.44747352600097656,
|
||||
-1.0202960968017578,
|
||||
-8.467000961303711,
|
||||
-10.914518356323242,
|
||||
-11.25300407409668,
|
||||
-11.435266494750977,
|
||||
-12.346576690673828,
|
||||
-13.075624465942383,
|
||||
-13.12769889831543,
|
||||
-13.231849670410156],
|
||||
[-3.123767137527466,
|
||||
-1.1188862323760986,
|
||||
-1.639634370803833,
|
||||
-2.0562336444854736,
|
||||
-2.8633930683135986,
|
||||
-2.9675419330596924,
|
||||
-3.4882919788360596,
|
||||
-3.69659161567688,
|
||||
-4.217339515686035,
|
||||
-4.243376731872559],
|
||||
[-7.199982064776123e-05,
|
||||
-9.76410961151123,
|
||||
-11.144091606140137,
|
||||
-16.507802963256836,
|
||||
-17.132701873779297,
|
||||
-17.44515037536621,
|
||||
-17.9138240814209,
|
||||
-18.33042335510254,
|
||||
-18.9162654876709,
|
||||
-19.39795684814453],
|
||||
[0.0,
|
||||
-22.991050720214844,
|
||||
-23.824249267578125,
|
||||
-24.969894409179688,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.480066299438477,
|
||||
-26.909683227539062,
|
||||
-27.33930206298828,
|
||||
-27.391376495361328],
|
||||
[-0.21928852796554565,
|
||||
-1.625309705734253,
|
||||
-9.775025367736816,
|
||||
-12.977627754211426,
|
||||
-16.388530731201172,
|
||||
-17.091541290283203,
|
||||
-19.044347763061523,
|
||||
-19.38283348083496,
|
||||
-19.460947036743164,
|
||||
-19.59113311767578],
|
||||
[0.0,
|
||||
-24.006507873535156,
|
||||
-27.443450927734375,
|
||||
-27.729862213134766,
|
||||
-28.12042236328125,
|
||||
-28.276647567749023,
|
||||
-28.927583694458008,
|
||||
-30.099267959594727,
|
||||
-31.479251861572266,
|
||||
-32.07810974121094],
|
||||
[0.0,
|
||||
-18.17412567138672,
|
||||
-18.772987365722656,
|
||||
-21.689178466796875,
|
||||
-21.92351531982422,
|
||||
-23.7200984954834,
|
||||
-23.79821014404297,
|
||||
-23.79821014404297,
|
||||
-24.032546997070312,
|
||||
-25.308382034301758],
|
||||
[-0.12947827577590942,
|
||||
-2.1083219051361084,
|
||||
-12.419143676757812,
|
||||
-15.23118782043457,
|
||||
-15.595710754394531,
|
||||
-15.830047607421875,
|
||||
-17.001731872558594,
|
||||
-17.60059356689453,
|
||||
-18.121341705322266,
|
||||
-18.251529693603516],
|
||||
[0.0,
|
||||
-19.449962615966797,
|
||||
-24.371034622192383,
|
||||
-24.917821884155273,
|
||||
-25.529701232910156,
|
||||
-25.85516929626465,
|
||||
-26.037429809570312,
|
||||
-26.115543365478516,
|
||||
-26.623271942138672,
|
||||
-26.649309158325195],
|
||||
[-0.03332124650478363,
|
||||
-3.4181859493255615,
|
||||
-15.759925842285156,
|
||||
-15.812002182006836,
|
||||
-16.593124389648438,
|
||||
-17.894996643066406,
|
||||
-18.09027671813965,
|
||||
-18.79328727722168,
|
||||
-19.144792556762695,
|
||||
-20.147233963012695],
|
||||
[0.0,
|
||||
-21.142393112182617,
|
||||
-22.157852172851562,
|
||||
-23.511798858642578,
|
||||
-24.657445907592773,
|
||||
-25.021968841552734,
|
||||
-25.5427188873291,
|
||||
-25.59479331970215,
|
||||
-25.75101661682129,
|
||||
-25.95931625366211],
|
||||
[0.0,
|
||||
-23.04312515258789,
|
||||
-24.94385528564453,
|
||||
-26.323841094970703,
|
||||
-27.54759979248047,
|
||||
-28.563060760498047,
|
||||
-29.786819458007812,
|
||||
-30.620018005371094,
|
||||
-30.69812774658203,
|
||||
-31.08869171142578],
|
||||
[0.0,
|
||||
-26.167617797851562,
|
||||
-28.771360397338867,
|
||||
-29.55248260498047,
|
||||
-30.906429290771484,
|
||||
-31.114728927612305,
|
||||
-31.414159774780273,
|
||||
-31.622459411621094,
|
||||
-31.713590621948242,
|
||||
-31.726608276367188],
|
||||
[-0.05012698099017143,
|
||||
-3.018392562866211,
|
||||
-11.740934371948242,
|
||||
-13.146955490112305,
|
||||
-13.797887802124023,
|
||||
-14.943536758422852,
|
||||
-16.037107467651367,
|
||||
-16.375595092773438,
|
||||
-16.714080810546875,
|
||||
-17.36501693725586],
|
||||
[-0.9704352021217346,
|
||||
-0.7360983490943909,
|
||||
-2.1941938400268555,
|
||||
-4.225115776062012,
|
||||
-5.0062360763549805,
|
||||
-5.2666120529174805,
|
||||
-5.839434623718262,
|
||||
-7.2714948654174805,
|
||||
-8.33902645111084,
|
||||
-8.495253562927246],
|
||||
[-0.014467108063399792,
|
||||
-4.258565902709961,
|
||||
-8.789079666137695,
|
||||
-10.429437637329102,
|
||||
-10.793962478637695,
|
||||
-11.835458755493164,
|
||||
-11.939607620239258,
|
||||
-13.31959342956543,
|
||||
-13.866378784179688,
|
||||
-15.038063049316406],
|
||||
[0.0,
|
||||
-20.08787727355957,
|
||||
-21.350692749023438,
|
||||
-21.415786743164062,
|
||||
-21.50691795349121,
|
||||
-21.50691795349121,
|
||||
-22.7176570892334,
|
||||
-24.13669776916504,
|
||||
-24.188772201538086,
|
||||
-24.34499740600586]]
|
||||
},
|
||||
{
|
||||
"case": "fail_case",
|
||||
"expect" : 0,
|
||||
"tokens" : ["<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Seattle",
|
||||
",",
|
||||
" WA",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"7",
|
||||
"'}}\n",
|
||||
"</tool_call>"],
|
||||
"logprobs":[[-0.00013815402053296566,
|
||||
-9.113236427307129,
|
||||
-10.571331977844238,
|
||||
-14.099404335021973,
|
||||
-14.28166675567627,
|
||||
-15.583537101745605,
|
||||
-15.81787395477295,
|
||||
-16.143341064453125,
|
||||
-16.143341064453125,
|
||||
-16.260509490966797],
|
||||
[0.0,
|
||||
-26.896663665771484,
|
||||
-27.32628059387207,
|
||||
-27.41741180419922,
|
||||
-32.07810974121094,
|
||||
-32.07810974121094,
|
||||
-32.28641128540039,
|
||||
-32.29943084716797,
|
||||
-32.44263458251953,
|
||||
-32.520748138427734],
|
||||
[0.0,
|
||||
-22.444263458251953,
|
||||
-24.527257919311523,
|
||||
-27.15703773498535,
|
||||
-28.016273498535156,
|
||||
-28.2506103515625,
|
||||
-28.693246841430664,
|
||||
-29.070789337158203,
|
||||
-29.565500259399414,
|
||||
-29.812854766845703],
|
||||
[0.0,
|
||||
-27.860050201416016,
|
||||
-28.641170501708984,
|
||||
-29.448333740234375,
|
||||
-30.932466506958008,
|
||||
-31.63547706604004,
|
||||
-32.33848571777344,
|
||||
-32.85923767089844,
|
||||
-33.17168426513672,
|
||||
-33.45809555053711],
|
||||
[0.0,
|
||||
-31.81774139404297,
|
||||
-31.895854949951172,
|
||||
-32.05207824707031,
|
||||
-35.43694305419922,
|
||||
-36.3482551574707,
|
||||
-38.61351013183594,
|
||||
-39.26444625854492,
|
||||
-40.61839294433594,
|
||||
-41.71196365356445],
|
||||
[0.0,
|
||||
-27.33930206298828,
|
||||
-27.834014892578125,
|
||||
-28.849472045898438,
|
||||
-30.567943572998047,
|
||||
-32.98942565917969,
|
||||
-33.067535400390625,
|
||||
-33.067535400390625,
|
||||
-35.67127990722656,
|
||||
-35.69731903076172],
|
||||
[0.0,
|
||||
-25.33441925048828,
|
||||
-26.063465118408203,
|
||||
-26.219690322875977,
|
||||
-26.2457275390625,
|
||||
-26.53213882446289,
|
||||
-27.365337371826172,
|
||||
-28.354759216308594,
|
||||
-28.667207717895508,
|
||||
-28.74532127380371],
|
||||
[0.0,
|
||||
-24.423107147216797,
|
||||
-24.579330444335938,
|
||||
-26.81855010986328,
|
||||
-28.12042236328125,
|
||||
-28.32872200012207,
|
||||
-28.61513328552246,
|
||||
-29.16191864013672,
|
||||
-29.187957763671875,
|
||||
-29.240032196044922],
|
||||
[0.0,
|
||||
-22.027664184570312,
|
||||
-23.850284576416016,
|
||||
-23.980472564697266,
|
||||
-24.292922973632812,
|
||||
-24.787633895874023,
|
||||
-29.279088973999023,
|
||||
-29.55248260498047,
|
||||
-29.903987884521484,
|
||||
-30.190399169921875],
|
||||
[0.0,
|
||||
-31.609439849853516,
|
||||
-31.817739486694336,
|
||||
-32.54678726196289,
|
||||
-32.676971435546875,
|
||||
-32.781124114990234,
|
||||
-32.98942565917969,
|
||||
-33.106590270996094,
|
||||
-33.57526397705078,
|
||||
-34.369407653808594],
|
||||
[0.0,
|
||||
-29.34418296813965,
|
||||
-29.63059425354004,
|
||||
-30.021156311035156,
|
||||
-30.984540939331055,
|
||||
-33.21073913574219,
|
||||
-34.30431365966797,
|
||||
-34.56468963623047,
|
||||
-34.70789337158203,
|
||||
-34.79902648925781],
|
||||
[0.0,
|
||||
-25.438566207885742,
|
||||
-25.69894027709961,
|
||||
-30.190397262573242,
|
||||
-30.802276611328125,
|
||||
-31.58340072631836,
|
||||
-31.609437942504883,
|
||||
-31.64849281311035,
|
||||
-31.973960876464844,
|
||||
-32.29943084716797],
|
||||
[0.0,
|
||||
-27.157039642333984,
|
||||
-32.104148864746094,
|
||||
-32.33848571777344,
|
||||
-34.04393768310547,
|
||||
-34.12205505371094,
|
||||
-34.40846252441406,
|
||||
-34.42148208618164,
|
||||
-34.772987365722656,
|
||||
-34.87713623046875],
|
||||
[0.0,
|
||||
-24.813671112060547,
|
||||
-26.974777221679688,
|
||||
-31.010578155517578,
|
||||
-31.08869171142578,
|
||||
-32.1822624206543,
|
||||
-35.33279037475586,
|
||||
-35.489013671875,
|
||||
-36.999183654785156,
|
||||
-37.88446044921875],
|
||||
[0.0,
|
||||
-20.46541976928711,
|
||||
-20.647682189941406,
|
||||
-23.069164276123047,
|
||||
-24.136699676513672,
|
||||
-25.438570022583008,
|
||||
-25.646869659423828,
|
||||
-26.193655014038086,
|
||||
-26.297805786132812,
|
||||
-26.506103515625],
|
||||
[0.0,
|
||||
-27.18307113647461,
|
||||
-28.30268096923828,
|
||||
-28.56305694580078,
|
||||
-29.526439666748047,
|
||||
-32.416595458984375,
|
||||
-35.202598571777344,
|
||||
-36.426361083984375,
|
||||
-39.31651306152344,
|
||||
-39.38160705566406],
|
||||
[0.0,
|
||||
-18.7469482421875,
|
||||
-20.100894927978516,
|
||||
-21.402767181396484,
|
||||
-21.428804397583008,
|
||||
-22.20992660522461,
|
||||
-22.34011459350586,
|
||||
-22.730674743652344,
|
||||
-23.069162368774414,
|
||||
-23.980472564697266],
|
||||
[-3.576278118089249e-07,
|
||||
-15.2579345703125,
|
||||
-16.481693267822266,
|
||||
-17.991863250732422,
|
||||
-19.215621948242188,
|
||||
-20.25712013244629,
|
||||
-21.350692749023438,
|
||||
-22.314077377319336,
|
||||
-22.496337890625,
|
||||
-22.938974380493164],
|
||||
[-0.08506780862808228,
|
||||
-2.506549835205078,
|
||||
-14.848289489746094,
|
||||
-15.473188400268555,
|
||||
-16.33242416381836,
|
||||
-16.358461380004883,
|
||||
-16.566761016845703,
|
||||
-17.03543472290039,
|
||||
-17.686370849609375,
|
||||
-17.816556930541992],
|
||||
[-0.0194891095161438,
|
||||
-4.445854187011719,
|
||||
-5.591499328613281,
|
||||
-5.956024169921875,
|
||||
-6.685070037841797,
|
||||
-13.142353057861328,
|
||||
-13.558952331542969,
|
||||
-15.173273086547852,
|
||||
-15.303461074829102,
|
||||
-15.85024642944336],
|
||||
[-0.0005990855861455202,
|
||||
-7.4212646484375,
|
||||
-15.675132751464844,
|
||||
-15.72720718383789,
|
||||
-16.76870346069336,
|
||||
-16.76870346069336,
|
||||
-17.706050872802734,
|
||||
-18.669435501098633,
|
||||
-19.398483276367188,
|
||||
-19.658857345581055],
|
||||
[0.0,
|
||||
-24.110658645629883,
|
||||
-25.829130172729492,
|
||||
-26.011390686035156,
|
||||
-26.011390686035156,
|
||||
-26.532140731811523,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.75589942932129,
|
||||
-28.055330276489258],
|
||||
[-1.1408883333206177,
|
||||
-0.38580334186553955,
|
||||
-7.494022369384766,
|
||||
-12.519245147705078,
|
||||
-14.576202392578125,
|
||||
-16.034297943115234,
|
||||
-16.945608139038086,
|
||||
-17.908992767333984,
|
||||
-18.664077758789062,
|
||||
-19.34105110168457],
|
||||
[0.0,
|
||||
-26.688365936279297,
|
||||
-29.83889389038086,
|
||||
-30.177383422851562,
|
||||
-30.64605712890625,
|
||||
-31.244916915893555,
|
||||
-31.270954132080078,
|
||||
-32.83319854736328,
|
||||
-34.655818939208984,
|
||||
-34.89015579223633],
|
||||
[0.0,
|
||||
-18.929210662841797,
|
||||
-19.16354751586914,
|
||||
-23.589908599853516,
|
||||
-24.683481216430664,
|
||||
-24.995929718017578,
|
||||
-25.516677856445312,
|
||||
-25.542715072631836,
|
||||
-25.77705192565918,
|
||||
-26.063465118408203],
|
||||
[-0.2519786059856415,
|
||||
-1.5017764568328857,
|
||||
-12.437495231628418,
|
||||
-15.457839012145996,
|
||||
-15.744250297546387,
|
||||
-16.837820053100586,
|
||||
-17.41064453125,
|
||||
-17.56686782836914,
|
||||
-17.61894416809082,
|
||||
-18.035541534423828],
|
||||
[0.0,
|
||||
-20.517494201660156,
|
||||
-24.683483123779297,
|
||||
-25.67290496826172,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.781936645507812,
|
||||
-27.912124633789062,
|
||||
-28.09438705444336,
|
||||
-28.445892333984375],
|
||||
[-3.40932747349143e-05,
|
||||
-10.284820556640625,
|
||||
-18.252273559570312,
|
||||
-20.17904281616211,
|
||||
-21.663175582885742,
|
||||
-22.027700424194336,
|
||||
-22.288074493408203,
|
||||
-22.704673767089844,
|
||||
-23.12127113342285,
|
||||
-23.277496337890625],
|
||||
[0.0,
|
||||
-22.60049057006836,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.063467025756836,
|
||||
-27.287227630615234,
|
||||
-27.391376495361328,
|
||||
-27.4694881439209,
|
||||
-27.67778778076172,
|
||||
-28.055330276489258],
|
||||
[0.0,
|
||||
-23.902362823486328,
|
||||
-28.823436737060547,
|
||||
-29.240036010742188,
|
||||
-29.31814956665039,
|
||||
-29.917007446289062,
|
||||
-30.021160125732422,
|
||||
-31.21887969970703,
|
||||
-32.416603088378906,
|
||||
-32.416603088378906],
|
||||
[0.0,
|
||||
-28.641170501708984,
|
||||
-31.947925567626953,
|
||||
-32.59886169433594,
|
||||
-33.848655700683594,
|
||||
-34.109031677246094,
|
||||
-34.73393249511719,
|
||||
-35.02033996582031,
|
||||
-35.02033996582031,
|
||||
-36.074859619140625],
|
||||
[-0.013183215633034706,
|
||||
-4.335395336151123,
|
||||
-19.619365692138672,
|
||||
-20.035964965820312,
|
||||
-20.244266510009766,
|
||||
-21.311800003051758,
|
||||
-21.441987991333008,
|
||||
-22.561595916748047,
|
||||
-23.108383178710938,
|
||||
-23.264606475830078],
|
||||
[-8.344646857949556e-07,
|
||||
-14.190400123596191,
|
||||
-15.9088716506958,
|
||||
-18.17412567138672,
|
||||
-18.46053695678711,
|
||||
-18.46053695678711,
|
||||
-18.512611389160156,
|
||||
-18.90317153930664,
|
||||
-19.059398651123047,
|
||||
-19.085433959960938],
|
||||
[0.0,
|
||||
-17.70545196533203,
|
||||
-18.903175354003906,
|
||||
-20.829944610595703,
|
||||
-22.574451446533203,
|
||||
-22.860862731933594,
|
||||
-23.069162368774414,
|
||||
-23.32953643798828,
|
||||
-23.694061279296875,
|
||||
-24.188772201538086],
|
||||
[0.0,
|
||||
-20.022781372070312,
|
||||
-21.038240432739258,
|
||||
-21.220502853393555,
|
||||
-22.496337890625,
|
||||
-22.769729614257812,
|
||||
-23.589908599853516,
|
||||
-23.65500259399414,
|
||||
-23.94141387939453,
|
||||
-24.266881942749023]]
|
||||
}
|
||||
]
|
||||
148
model_server/app/tests/test_hallucination.py
Normal file
148
model_server/app/tests/test_hallucination.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
import json
|
||||
from app.function_calling.hallucination_handler import HallucinationStateHandler
|
||||
import pytest
|
||||
import os
|
||||
|
||||
# Resolve the fixture relative to this module so the tests do not depend on
# the working directory pytest is launched from.
current_dir = os.path.dirname(__file__)
json_file_path = os.path.join(current_dir, "test_cases.json")

# JSON text is UTF-8; naming the encoding keeps the read independent of the
# platform's default locale encoding.
with open(json_file_path, encoding="utf-8") as f:
    test_cases = json.load(f)
|
||||
|
||||
# OpenAI-style tool definition shared by the tests in this module.
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            "required": ["location", "days"],
        },
    },
}

# The tests pass the function spec(s) as a list. The "function" entry above is
# always a single dict, so it is wrapped directly instead of being assigned
# and then re-checked with a `type(...) != list` comparison.
function_description = [get_weather_api["function"]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
    """Replay one recorded generation and check the handler's verdict.

    ``case`` is one entry of ``test_cases.json``. Its ``"tokens"`` and
    ``"logprobs"`` lists are walked in lockstep and fed to the handler one
    token at a time; the final ``hallucination`` flag must equal
    ``case["expect"]``.
    """
    handler = HallucinationStateHandler(
        response_iterator=None, function=function_description
    )
    for tok, tok_logprobs in zip(case["tokens"], case["logprobs"]):
        # The closing tag is never handed to the handler.
        if tok == "</tool_call>":
            continue
        handler.append_and_check_token_hallucination(tok, tok_logprobs)
        # Stop replaying as soon as the handler has flagged the generation.
        if handler.hallucination:
            break
    assert handler.hallucination == case["expect"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
def test_hallucination_prompt(is_hallucinate_sample):
    """Run the handler against a live, streamed Arch-Function completion.

    Builds a system prompt that advertises ``get_weather_api``, sends one user
    query to the hosted endpoint with log-probabilities enabled, and lets
    ``HallucinationStateHandler`` consume the stream. The hallucinating sample
    leaves out the day count even though ``days`` is a required parameter of
    the tool; the other sample supplies it.

    This test needs network access to ``https://api.fc.archgw.com/v1`` and the
    ``openai`` package, and it samples with a non-zero temperature.
    """
    # NOTE(review): the prompt bodies are written flush-left because the
    # original indentation inside these triple-quoted strings could not be
    # recovered; confirm against the committed file, since any leading
    # whitespace is part of the prompt text.
    TASK_PROMPT = """
You are a helpful assistant.
""".strip()

    TOOL_PROMPT = """
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()

    FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()

    def convert_tools(tools):
        # One JSON document per line, in the order the tools were given.
        return "\n".join(json.dumps(tool) for tool in tools)

    def format_prompt(tools):
        tool_text = convert_tools(tools)

        # FORMAT_PROMPT is appended verbatim (never passed through .format),
        # so its literal braces are safe.
        return (
            TASK_PROMPT
            + "\n\n"
            + TOOL_PROMPT.format(tool_text=tool_text)
            + "\n\n"
            + FORMAT_PROMPT
            + "\n"
        )

    openai_format_tools = [get_weather_api]
    system_prompt = format_prompt(openai_format_tools)

    # Imported here so that collecting this module does not require the
    # openai package unless this test actually runs.
    from openai import OpenAI

    client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")

    # Sanity-check that the endpoint serves the expected model.
    model = client.models.list().data[0].id
    assert model == "Arch-Function"

    # The two samples differ only in whether the day count is present.
    if is_hallucinate_sample:
        user_query = "How is the weather in Seattle in days?"
    else:
        user_query = "How is the weather in Seattle in 7 days?"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_query},
    ]

    extra_body = {
        "temperature": 0.6,
        "top_p": 1.0,
        "top_k": 50,
        "logprobs": True,
        "top_logprobs": 10,
    }

    resp = client.chat.completions.create(
        model="Arch-Function", messages=messages, extra_body=extra_body, stream=True
    )

    hallu = HallucinationStateHandler(
        response_iterator=resp, function=function_description
    )

    # Drain the stream so the handler sees every token it yields. The previous
    # per-token check `len(hallu.tokens) >= 0` could never fail, so it is
    # dropped; the meaningful assertion is the final verdict below.
    for _ in hallu:
        pass
    assert hallu.hallucination == is_hallucinate_sample
|
||||
Loading…
Add table
Add a link
Reference in a new issue