mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
fix test
This commit is contained in:
parent
d776ed0117
commit
665dbc2d4e
3 changed files with 203 additions and 61 deletions
|
|
@ -1,8 +1,15 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import List, Dict
|
import math
|
||||||
|
import app.commons.constants as const
|
||||||
|
import random
|
||||||
|
from typing import List, Dict, Any, Tuple
|
||||||
|
import json
|
||||||
|
|
||||||
def filter_tokens_and_probs(tokens: List[str], probs: List[float]) -> Tuple[List[], List[float]]:
|
|
||||||
|
def filter_tokens_and_probs(
|
||||||
|
tokens: List[str], probs: List[float]
|
||||||
|
) -> Tuple[List[str], List[float]]:
|
||||||
"""
|
"""
|
||||||
Filters out special tokens from the list of tokens and their corresponding probabilities.
|
Filters out special tokens from the list of tokens and their corresponding probabilities.
|
||||||
|
|
||||||
|
|
@ -14,17 +21,17 @@ def filter_tokens_and_probs(tokens: List[str], probs: List[float]) -> Tuple[List
|
||||||
tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities.
|
tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities.
|
||||||
"""
|
"""
|
||||||
# Use regex to identify tokens without special characters
|
# Use regex to identify tokens without special characters
|
||||||
special_tokens = ['\\n', '{"', '":', ' "', '",', ' {"', '"}}\\n', ' ', '"}}\n']
|
special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n']
|
||||||
filtered_tokens = [
|
filtered_tokens = [token for token in tokens if token not in special_tokens]
|
||||||
token for token in tokens
|
|
||||||
if token not in special_tokens
|
|
||||||
]
|
|
||||||
filtered_probs = [
|
filtered_probs = [
|
||||||
prob for token, prob in zip(tokens, probs)
|
prob for token, prob in zip(tokens, probs) if token not in special_tokens
|
||||||
if token not in special_tokens
|
|
||||||
]
|
]
|
||||||
return filtered_tokens, filtered_probs
|
return filtered_tokens, filtered_probs
|
||||||
def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_names: Dict[str, List[str]]) -> Tuple[Dict[str, List[str]], Dict[str, List[float]]]:
|
|
||||||
|
|
||||||
|
def get_all_parameter_values(
|
||||||
|
tokens: List[str], probs: List[float], parameter_names: Dict[str, Any]
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Extracts parameter values and their corresponding probabilities from the tokens.
|
Extracts parameter values and their corresponding probabilities from the tokens.
|
||||||
|
|
||||||
|
|
@ -49,7 +56,9 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na
|
||||||
# Incrementally combine tokens to find a full match with any parameter name
|
# Incrementally combine tokens to find a full match with any parameter name
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
if combined_token:
|
if combined_token:
|
||||||
combined_token += tokens[i] # Append next token to the current combination
|
combined_token += tokens[
|
||||||
|
i
|
||||||
|
] # Append next token to the current combination
|
||||||
else:
|
else:
|
||||||
combined_token = tokens[i] # Start a new combination
|
combined_token = tokens[i] # Start a new combination
|
||||||
|
|
||||||
|
|
@ -62,7 +71,11 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na
|
||||||
i += 1 # Move past the parameter name
|
i += 1 # Move past the parameter name
|
||||||
|
|
||||||
# Collect tokens as values until the next parameter or end marker
|
# Collect tokens as values until the next parameter or end marker
|
||||||
while i < len(tokens) and tokens[i] not in params and tokens[i] != '</tool_call>':
|
while (
|
||||||
|
i < len(tokens)
|
||||||
|
and tokens[i] not in params
|
||||||
|
and tokens[i] != "</tool_call>"
|
||||||
|
):
|
||||||
values.append(tokens[i])
|
values.append(tokens[i])
|
||||||
prob_values.append(probs[i])
|
prob_values.append(probs[i])
|
||||||
i += 1
|
i += 1
|
||||||
|
|
@ -83,7 +96,11 @@ def get_all_parameter_values(tokens: List[str], probs: List[float], parameter_na
|
||||||
i = start + 1
|
i = start + 1
|
||||||
|
|
||||||
return parameter_values, probs_values
|
return parameter_values, probs_values
|
||||||
def calculate_stats(data: Dict, function_description: Dict) -> Dict:
|
|
||||||
|
|
||||||
|
def calculate_stats(
|
||||||
|
data: Dict[str, Any], function_description: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Calculates statistical metrics for the given data.
|
Calculates statistical metrics for the given data.
|
||||||
|
|
||||||
|
|
@ -97,19 +114,33 @@ def calculate_stats(data: Dict, function_description: Dict) -> Dict:
|
||||||
stats = {}
|
stats = {}
|
||||||
try:
|
try:
|
||||||
for key, values in data.items():
|
for key, values in data.items():
|
||||||
if len(data[key])>=1:
|
if len(data[key]) >= 1:
|
||||||
first = values[0]
|
first = values[0]
|
||||||
max_value = max(values)
|
max_value = max(values)
|
||||||
min_value = min(values)
|
min_value = min(values)
|
||||||
avg_value = sum(values) / len(values)
|
avg_value = sum(values) / len(values)
|
||||||
has_format = check_parameter_property(function_description, key, "format")
|
has_format = check_parameter_property(
|
||||||
has_default = check_parameter_property(function_description, key , "default")
|
function_description, key, "format"
|
||||||
stats[key] = {'first':first, 'max': max_value, 'min': min_value, 'avg': avg_value, 'has_format': has_format, 'has_default': has_default}
|
)
|
||||||
|
has_default = check_parameter_property(
|
||||||
|
function_description, key, "default"
|
||||||
|
)
|
||||||
|
stats[key] = {
|
||||||
|
"first": first,
|
||||||
|
"max": max_value,
|
||||||
|
"min": min_value,
|
||||||
|
"avg": avg_value,
|
||||||
|
"has_format": has_format,
|
||||||
|
"has_default": has_default,
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(data)
|
print(data)
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def check_parameter_property(api_description: Dict, parameter_name: str, property_name: str)-> bool:
|
|
||||||
|
def check_parameter_property(
|
||||||
|
api_description: Dict[str, Any], parameter_name: str, property_name: str
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a parameter in an API description has a specific property.
|
Check if a parameter in an API description has a specific property.
|
||||||
|
|
||||||
|
|
@ -127,8 +158,36 @@ def check_parameter_property(api_description: Dict, parameter_name: str, propert
|
||||||
return property_name in parameter_info
|
return property_name in parameter_info
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_entropy(log_probs):
|
||||||
|
"""
|
||||||
|
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
|
||||||
|
|
||||||
def hallucination_detect(token:str, log_probs:List[float], current_state: Dict, entropy_thd : float= 0.7, varentropy_thd :float = 4.0) -> bool:
|
Args:
|
||||||
|
log_probs (list of float): A list of log probabilities.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: A tuple containing:
|
||||||
|
- log_probs (list of float): The input log probabilities as a list.
|
||||||
|
- entropy (float): The calculated entropy.
|
||||||
|
- varentropy (float): The calculated variance of entropy.
|
||||||
|
"""
|
||||||
|
log_probs = torch.tensor(log_probs)
|
||||||
|
token_probs = torch.exp(log_probs)
|
||||||
|
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
|
||||||
|
varentropy = torch.sum(
|
||||||
|
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return log_probs.tolist(), entropy.item(), varentropy.item()
|
||||||
|
|
||||||
|
|
||||||
|
def hallucination_detect(
|
||||||
|
token: str,
|
||||||
|
log_probs: List[float],
|
||||||
|
current_state: Dict[str, Any],
|
||||||
|
entropy_thd: float = 0.7,
|
||||||
|
varentropy_thd: float = 4.0,
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Detects hallucinations in the token sequence based on entropy and varentropy thresholds.
|
Detects hallucinations in the token sequence based on entropy and varentropy thresholds.
|
||||||
|
|
||||||
|
|
@ -146,8 +205,8 @@ def hallucination_detect(token:str, log_probs:List[float], current_state: Dict,
|
||||||
if token:
|
if token:
|
||||||
# check if there is content in token
|
# check if there is content in token
|
||||||
current_state["tokens"].append(token)
|
current_state["tokens"].append(token)
|
||||||
current_state['content'] += token
|
current_state["content"] += token
|
||||||
current_state['logprobs'].append(log_probs)
|
current_state["logprobs"].append(log_probs)
|
||||||
# keep track of entropy and varentropy
|
# keep track of entropy and varentropy
|
||||||
_, entropy, varentropy = calculate_entropy(log_probs)
|
_, entropy, varentropy = calculate_entropy(log_probs)
|
||||||
current_state["entropy"].append(entropy)
|
current_state["entropy"].append(entropy)
|
||||||
|
|
@ -156,71 +215,148 @@ def hallucination_detect(token:str, log_probs:List[float], current_state: Dict,
|
||||||
if token == "<tool_call>":
|
if token == "<tool_call>":
|
||||||
if entropy > entropy_thd or varentropy > varentropy_thd:
|
if entropy > entropy_thd or varentropy > varentropy_thd:
|
||||||
current_state["hallucination"] = True
|
current_state["hallucination"] = True
|
||||||
current_state["hallucination_message"] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}"
|
current_state[
|
||||||
|
"hallucination_message"
|
||||||
|
] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}"
|
||||||
return True
|
return True
|
||||||
elif token == "</tool_call>":
|
elif token == "</tool_call>":
|
||||||
current_state["state"] = "tool_call_end"
|
current_state["state"] = "tool_call_end"
|
||||||
# try to extract tool call, else raise error
|
# try to extract tool call, else raise error
|
||||||
try:
|
try:
|
||||||
current_state['tool_call'] = extract_tool_calls(current_state["content"])[0]
|
current_state[
|
||||||
current_state['tool_call_process'] = True
|
"tool_call"
|
||||||
|
] = const.arch_function_hanlder.extract_tool_calls(
|
||||||
|
current_state["content"]
|
||||||
|
)[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
current_state["tool_call_process"] = True
|
||||||
except:
|
except:
|
||||||
current_state['tool_call_process'] = False
|
current_state["tool_call_process"] = False
|
||||||
print(f"cant process tool")
|
print(f"cant process tool")
|
||||||
return True
|
return True
|
||||||
# check if function name is valid
|
# check if function name is valid
|
||||||
if current_state['tool_call']['function']['name'] not in current_state['parameter_names'].keys():
|
if (
|
||||||
|
current_state["tool_call"]["function"]["name"]
|
||||||
|
not in current_state["parameter_names"].keys()
|
||||||
|
):
|
||||||
current_state["hallucination"] = True
|
current_state["hallucination"] = True
|
||||||
current_state["hallucination_message"] = f"function name {current_state['tool_call']['name']} not found"
|
current_state[
|
||||||
|
"hallucination_message"
|
||||||
|
] = f"function name {current_state['tool_call']['name']} not found"
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# check if parameter names are from the given function tools
|
# check if parameter names are from the given function tools
|
||||||
current_parameter_names = current_state['tool_call']['function']['arguments'].keys()
|
current_parameter_names = current_state["tool_call"]["function"][
|
||||||
given_parameter_names = current_state['parameter_names'][current_state['tool_call']['function']['name']]
|
"arguments"
|
||||||
|
].keys()
|
||||||
|
given_parameter_names = current_state["parameter_names"][
|
||||||
|
current_state["tool_call"]["function"]["name"]
|
||||||
|
]
|
||||||
if not set(current_parameter_names).issubset(given_parameter_names):
|
if not set(current_parameter_names).issubset(given_parameter_names):
|
||||||
|
|
||||||
missing_keys = set(current_parameter_names) - set(given_parameter_names)
|
missing_keys = set(current_parameter_names) - set(given_parameter_names)
|
||||||
|
|
||||||
current_state["hallucination"] = True
|
current_state["hallucination"] = True
|
||||||
current_state["hallucination_message"] = f"parameter names {missing_keys} not found"
|
current_state[
|
||||||
|
"hallucination_message"
|
||||||
|
] = f"parameter names {missing_keys} not found"
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# filtered special tokens that are not needed in the hallucination check for parameter values
|
# filtered special tokens that are not needed in the hallucination check for parameter values
|
||||||
current_state["filtered_tokens"], current_state["filtered_entropy"] = filter_tokens_and_probs(current_state["tokens"], current_state["entropy"])
|
(
|
||||||
current_state["filtered_tokens"], current_state["filtered_varentropy"] = filter_tokens_and_probs(current_state["tokens"], current_state["varentropy"])
|
current_state["filtered_tokens"],
|
||||||
parameter_values, entropy_values = get_all_parameter_values(current_state["filtered_tokens"], current_state["filtered_entropy"], current_state['parameter_names'])
|
current_state["filtered_entropy"],
|
||||||
parameter_values, varentropy_values = get_all_parameter_values(current_state["filtered_tokens"], current_state["filtered_varentropy"], current_state['parameter_names'])
|
) = filter_tokens_and_probs(
|
||||||
|
current_state["tokens"], current_state["entropy"]
|
||||||
|
)
|
||||||
|
(
|
||||||
|
current_state["filtered_tokens"],
|
||||||
|
current_state["filtered_varentropy"],
|
||||||
|
) = filter_tokens_and_probs(
|
||||||
|
current_state["tokens"], current_state["varentropy"]
|
||||||
|
)
|
||||||
|
parameter_values, entropy_values = get_all_parameter_values(
|
||||||
|
current_state["filtered_tokens"],
|
||||||
|
current_state["filtered_entropy"],
|
||||||
|
current_state["parameter_names"],
|
||||||
|
)
|
||||||
|
parameter_values, varentropy_values = get_all_parameter_values(
|
||||||
|
current_state["filtered_tokens"],
|
||||||
|
current_state["filtered_varentropy"],
|
||||||
|
current_state["parameter_names"],
|
||||||
|
)
|
||||||
|
|
||||||
current_state['parameter_values'] = parameter_values
|
current_state["parameter_values"] = parameter_values
|
||||||
current_state['parameter_values_entropy'] = entropy_values
|
current_state["parameter_values_entropy"] = entropy_values
|
||||||
current_state['parameter_values_varentropy'] = varentropy_values
|
current_state["parameter_values_varentropy"] = varentropy_values
|
||||||
# calculate the max, first, avg of sub tokens for parameter value
|
# calculate the max, first, avg of sub tokens for parameter value
|
||||||
current_state['parameter_value_entropy_stat'] = calculate_stats(current_state['parameter_values_entropy'], current_state['function_description'][0])
|
current_state["parameter_value_entropy_stat"] = calculate_stats(
|
||||||
current_state['parameter_value_varentropy_stat'] = calculate_stats(current_state['parameter_values_varentropy'], current_state['function_description'][0])
|
current_state["parameter_values_entropy"],
|
||||||
|
current_state["function_description"][0],
|
||||||
|
)
|
||||||
|
current_state["parameter_value_varentropy_stat"] = calculate_stats(
|
||||||
|
current_state["parameter_values_varentropy"],
|
||||||
|
current_state["function_description"][0],
|
||||||
|
)
|
||||||
# get map for debugging
|
# get map for debugging
|
||||||
current_state['token_entropy_map'] = {x : y for x,y in zip(current_state['tokens'], current_state['entropy'])}
|
current_state["token_entropy_map"] = {
|
||||||
current_state['token_varentropy_map'] = {x : y for x,y in zip(current_state['tokens'], current_state['varentropy'])}
|
x: y for x, y in zip(current_state["tokens"], current_state["entropy"])
|
||||||
|
}
|
||||||
|
current_state["token_varentropy_map"] = {
|
||||||
|
x: y
|
||||||
|
for x, y in zip(current_state["tokens"], current_state["varentropy"])
|
||||||
|
}
|
||||||
|
|
||||||
# checking hallucination for parameter value
|
# checking hallucination for parameter value
|
||||||
current_state['parameter_value_check'] = {x : {'hallucination': False, 'message': ''} for x in current_state['parameter_values'].keys()}
|
current_state["parameter_value_check"] = {
|
||||||
for key in current_state['parameter_value_check'].keys():
|
x: {"hallucination": False, "message": ""}
|
||||||
|
for x in current_state["parameter_values"].keys()
|
||||||
|
}
|
||||||
|
for key in current_state["parameter_value_check"].keys():
|
||||||
# if parameter is given a format, check the first token
|
# if parameter is given a format, check the first token
|
||||||
if current_state['parameter_value_entropy_stat'][key]['has_format']:
|
if current_state["parameter_value_entropy_stat"][key]["has_format"]:
|
||||||
if current_state['parameter_value_entropy_stat'][key]['first'] > entropy_thd or current_state['parameter_value_varentropy_stat'][key]['first'] > varentropy_thd:
|
if (
|
||||||
current_state['parameter_value_check'][key]['hallucination'] = True
|
current_state["parameter_value_entropy_stat"][key]["first"]
|
||||||
|
> entropy_thd
|
||||||
|
or current_state["parameter_value_varentropy_stat"][key][
|
||||||
|
"first"
|
||||||
|
]
|
||||||
|
> varentropy_thd
|
||||||
|
):
|
||||||
|
current_state["parameter_value_check"][key][
|
||||||
|
"hallucination"
|
||||||
|
] = True
|
||||||
current_state["hallucination"] = True
|
current_state["hallucination"] = True
|
||||||
current_state['parameter_value_check'][key]['message'] = f"parameter {key} with formatting doesn't pass threshold"
|
current_state["parameter_value_check"][key][
|
||||||
|
"message"
|
||||||
|
] = f"parameter {key} with formatting doesn't pass threshold"
|
||||||
# if parameter gis given a default value, we can always use default
|
# if parameter gis given a default value, we can always use default
|
||||||
elif current_state['parameter_value_entropy_stat'][key]['has_default']:
|
elif current_state["parameter_value_entropy_stat"][key]["has_default"]:
|
||||||
current_state['parameter_value_check'][key]['hallucination'] = False
|
current_state["parameter_value_check"][key]["hallucination"] = False
|
||||||
current_state['parameter_value_check'][key]['message'] = f"parameter {key} with default"
|
current_state["parameter_value_check"][key][
|
||||||
|
"message"
|
||||||
|
] = f"parameter {key} with default"
|
||||||
# check if max sub token is > thresholds
|
# check if max sub token is > thresholds
|
||||||
else:
|
else:
|
||||||
if current_state['parameter_value_entropy_stat'][key]['max'] > entropy_thd or current_state['parameter_value_varentropy_stat'][key]['max'] > varentropy_thd:
|
if (
|
||||||
current_state['parameter_value_check'][key]['hallucination'] = True
|
current_state["parameter_value_entropy_stat"][key]["max"]
|
||||||
current_state['parameter_value_check'][key]['message'] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold"
|
> entropy_thd
|
||||||
|
or current_state["parameter_value_varentropy_stat"][key]["max"]
|
||||||
|
> varentropy_thd
|
||||||
|
):
|
||||||
|
current_state["parameter_value_check"][key][
|
||||||
|
"hallucination"
|
||||||
|
] = True
|
||||||
|
current_state["parameter_value_check"][key][
|
||||||
|
"message"
|
||||||
|
] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold"
|
||||||
current_state["hallucination"] = True
|
current_state["hallucination"] = True
|
||||||
if current_state["hallucination"] == True:
|
if current_state["hallucination"] == True:
|
||||||
current_state["hallucination_message"] = "\n".join([current_state['parameter_value_check'][key]['message'] for key in current_state['parameter_value_check'].keys()])
|
current_state["hallucination_message"] = "\n".join(
|
||||||
|
[
|
||||||
|
current_state["parameter_value_check"][key]["message"]
|
||||||
|
for key in current_state["parameter_value_check"].keys()
|
||||||
|
]
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
@ -134,4 +134,4 @@ class ArchFunctionHandler:
|
||||||
fixed_str += opening_bracket[unmatched_opening]
|
fixed_str += opening_bracket[unmatched_opening]
|
||||||
|
|
||||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||||
return fixed_str
|
return fixed_str.replace("'", '"')
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,16 @@
|
||||||
import json
|
import json
|
||||||
from app.function_calling.hallucination_handler import hallucination_detect
|
from app.function_calling.hallucination_handler import hallucination_detect
|
||||||
import pytest
|
import pytest
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Get the directory of the current file
|
||||||
|
current_dir = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
# Construct the full path to the JSON file
|
||||||
|
json_file_path = os.path.join(current_dir, "test_cases.json")
|
||||||
|
|
||||||
|
with open(json_file_path) as f:
|
||||||
|
test_cases = json.load(f)
|
||||||
|
|
||||||
get_weather_api = {
|
get_weather_api = {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
|
|
@ -41,9 +50,6 @@ for func in function_description:
|
||||||
parameters = func["parameters"]["properties"]
|
parameters = func["parameters"]["properties"]
|
||||||
parameter_names[func_name] = list(parameters.keys())
|
parameter_names[func_name] = list(parameters.keys())
|
||||||
|
|
||||||
with open("test_cases.json") as f:
|
|
||||||
test_cases = json.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("case", test_cases)
|
@pytest.mark.parametrize("case", test_cases)
|
||||||
def test_hallucination(case):
|
def test_hallucination(case):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue