mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
new implemenetation
This commit is contained in:
parent
665dbc2d4e
commit
abfc81b0e7
3 changed files with 224 additions and 340 deletions
|
|
@ -1,164 +1,31 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import math
|
||||
import app.commons.constants as const
|
||||
import random
|
||||
from typing import List, Dict, Any, Tuple
|
||||
import json
|
||||
import ast
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import torch
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import app.commons.constants as const
|
||||
import itertools
|
||||
|
||||
|
||||
def filter_tokens_and_probs(
|
||||
tokens: List[str], probs: List[float]
|
||||
) -> Tuple[List[str], List[float]]:
|
||||
def check_threshold(entropy, varentropy, thd):
|
||||
"""
|
||||
Filters out special tokens from the list of tokens and their corresponding probabilities.
|
||||
Check if the given entropy or variance of entropy exceeds the specified thresholds.
|
||||
|
||||
Args:
|
||||
tokens (list): List of tokens.
|
||||
probs (list): List of probabilities corresponding to the tokens.
|
||||
entropy (float): The entropy value to check.
|
||||
varentropy (float): The variance of entropy value to check.
|
||||
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing two lists - filtered tokens and their corresponding probabilities.
|
||||
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
|
||||
"""
|
||||
# Use regex to identify tokens without special characters
|
||||
special_tokens = ["\\n", '{"', '":', ' "', '",', ' {"', '"}}\\n', " ", '"}}\n']
|
||||
filtered_tokens = [token for token in tokens if token not in special_tokens]
|
||||
filtered_probs = [
|
||||
prob for token, prob in zip(tokens, probs) if token not in special_tokens
|
||||
]
|
||||
return filtered_tokens, filtered_probs
|
||||
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
|
||||
|
||||
|
||||
def get_all_parameter_values(
|
||||
tokens: List[str], probs: List[float], parameter_names: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
Extracts parameter values and their corresponding probabilities from the tokens.
|
||||
|
||||
Args:
|
||||
tokens (list): List of tokens.
|
||||
probs (list): List of probabilities corresponding to the tokens.
|
||||
parameter_names (dict): Dictionary of parameter names for each function.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing two dictionaries - parameter values and their corresponding probabilities.
|
||||
"""
|
||||
parameter_values = {}
|
||||
probs_values = {}
|
||||
i = 0
|
||||
|
||||
while i < len(tokens):
|
||||
# Try to form parameter names by combining tokens
|
||||
combined_token = ""
|
||||
start = i
|
||||
found_param = False
|
||||
|
||||
# Incrementally combine tokens to find a full match with any parameter name
|
||||
while i < len(tokens):
|
||||
if combined_token:
|
||||
combined_token += tokens[
|
||||
i
|
||||
] # Append next token to the current combination
|
||||
else:
|
||||
combined_token = tokens[i] # Start a new combination
|
||||
|
||||
# Check if the combined token matches any parameter name
|
||||
for func, params in parameter_names.items():
|
||||
if combined_token in params:
|
||||
# Collect values associated with this parameter
|
||||
values = []
|
||||
prob_values = []
|
||||
i += 1 # Move past the parameter name
|
||||
|
||||
# Collect tokens as values until the next parameter or end marker
|
||||
while (
|
||||
i < len(tokens)
|
||||
and tokens[i] not in params
|
||||
and tokens[i] != "</tool_call>"
|
||||
):
|
||||
values.append(tokens[i])
|
||||
prob_values.append(probs[i])
|
||||
i += 1
|
||||
|
||||
# Store the parameter values and probabilities
|
||||
parameter_values[combined_token] = values
|
||||
probs_values[combined_token] = prob_values
|
||||
|
||||
found_param = True
|
||||
break # Stop combining further once a parameter is matched
|
||||
|
||||
if found_param:
|
||||
break # Exit the outer loop if parameter was matched
|
||||
i += 1 # Move to the next token if no match was found yet
|
||||
|
||||
# Reset to the next token if no parameter match was found
|
||||
if not found_param:
|
||||
i = start + 1
|
||||
|
||||
return parameter_values, probs_values
|
||||
|
||||
|
||||
def calculate_stats(
|
||||
data: Dict[str, Any], function_description: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates statistical metrics for the given data.
|
||||
|
||||
Args:
|
||||
data (dict): Dictionary containing parameter values and their corresponding probabilities.
|
||||
function_description (dict): Description of the function containing parameter properties.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing statistical metrics for each parameter.
|
||||
"""
|
||||
stats = {}
|
||||
try:
|
||||
for key, values in data.items():
|
||||
if len(data[key]) >= 1:
|
||||
first = values[0]
|
||||
max_value = max(values)
|
||||
min_value = min(values)
|
||||
avg_value = sum(values) / len(values)
|
||||
has_format = check_parameter_property(
|
||||
function_description, key, "format"
|
||||
)
|
||||
has_default = check_parameter_property(
|
||||
function_description, key, "default"
|
||||
)
|
||||
stats[key] = {
|
||||
"first": first,
|
||||
"max": max_value,
|
||||
"min": min_value,
|
||||
"avg": avg_value,
|
||||
"has_format": has_format,
|
||||
"has_default": has_default,
|
||||
}
|
||||
except Exception as e:
|
||||
print(data)
|
||||
return stats
|
||||
|
||||
|
||||
def check_parameter_property(
|
||||
api_description: Dict[str, Any], parameter_name: str, property_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a parameter in an API description has a specific property.
|
||||
|
||||
Args:
|
||||
api_description (dict): The API description in JSON format.
|
||||
parameter_name (str): The name of the parameter to check.
|
||||
property_name (str): The property to look for (e.g., 'format', 'default').
|
||||
|
||||
Returns:
|
||||
bool: True if the parameter has the specified property, False otherwise.
|
||||
"""
|
||||
parameters = api_description.get("parameters", {}).get("properties", {})
|
||||
parameter_info = parameters.get(parameter_name, {})
|
||||
|
||||
return property_name in parameter_info
|
||||
|
||||
|
||||
def calculate_entropy(log_probs):
|
||||
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
|
||||
|
||||
|
|
@ -178,185 +45,196 @@ def calculate_entropy(log_probs):
|
|||
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
|
||||
dim=-1,
|
||||
)
|
||||
return log_probs.tolist(), entropy.item(), varentropy.item()
|
||||
return entropy.item(), varentropy.item()
|
||||
|
||||
|
||||
def hallucination_detect(
|
||||
token: str,
|
||||
log_probs: List[float],
|
||||
current_state: Dict[str, Any],
|
||||
entropy_thd: float = 0.7,
|
||||
varentropy_thd: float = 4.0,
|
||||
) -> bool:
|
||||
def check_parameter_property(api_description, parameter_name, property_name):
|
||||
"""
|
||||
Detects hallucinations in the token sequence based on entropy and varentropy thresholds.
|
||||
Check if a parameter in an API description has a specific property.
|
||||
|
||||
Args:
|
||||
token (str): The current token.
|
||||
log_probs (list): List of log probabilities for the current token.
|
||||
current_state (dict): The current state of the detection process.
|
||||
entropy_thd (float): Entropy threshold for detecting hallucinations.
|
||||
varentropy_thd (float): Variance of entropy threshold for detecting hallucinations.
|
||||
api_description (dict): The API description in JSON format.
|
||||
parameter_name (str): The name of the parameter to check.
|
||||
property_name (str): The property to look for (e.g., 'format', 'default').
|
||||
|
||||
Returns:
|
||||
bool: True if a hallucination is detected, False otherwise.
|
||||
bool: True if the parameter has the specified property, False otherwise.
|
||||
"""
|
||||
parameters = api_description.get("properties", {})
|
||||
parameter_info = parameters.get(parameter_name, {})
|
||||
|
||||
return property_name in parameter_info
|
||||
|
||||
|
||||
class HallucinationStateHandler:
|
||||
"""
|
||||
A class to handle the state of hallucination detection in token processing.
|
||||
|
||||
Attributes:
|
||||
tokens (list): List of tokens processed.
|
||||
logprobs (list): List of log probabilities for each token.
|
||||
state (str): Current state of the handler.
|
||||
mask (list): List of masks indicating the type of each token.
|
||||
parameter_name_done (bool): Flag indicating if parameter name extraction is done.
|
||||
hallucination (bool): Flag indicating if a hallucination is detected.
|
||||
hallucination_message (str): Message describing the hallucination.
|
||||
parameter_name (list): List of extracted parameter names.
|
||||
function_description (dict): Description of functions and their parameters.
|
||||
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
|
||||
current_token (str): The current token being processed.
|
||||
"""
|
||||
|
||||
if token:
|
||||
# check if there is content in token
|
||||
current_state["tokens"].append(token)
|
||||
current_state["content"] += token
|
||||
current_state["logprobs"].append(log_probs)
|
||||
# keep track of entropy and varentropy
|
||||
_, entropy, varentropy = calculate_entropy(log_probs)
|
||||
current_state["entropy"].append(entropy)
|
||||
current_state["varentropy"].append(varentropy)
|
||||
# first check if tool call token is certain
|
||||
if token == "<tool_call>":
|
||||
if entropy > entropy_thd or varentropy > varentropy_thd:
|
||||
current_state["hallucination"] = True
|
||||
current_state[
|
||||
"hallucination_message"
|
||||
] = f"{token} with entropy {entropy}, varentropy {varentropy} doesn't pass the threshold {entropy_thd} | {varentropy_thd}"
|
||||
return True
|
||||
elif token == "</tool_call>":
|
||||
current_state["state"] = "tool_call_end"
|
||||
# try to extract tool call, else raise error
|
||||
try:
|
||||
current_state[
|
||||
"tool_call"
|
||||
] = const.arch_function_hanlder.extract_tool_calls(
|
||||
current_state["content"]
|
||||
)[
|
||||
0
|
||||
]
|
||||
current_state["tool_call_process"] = True
|
||||
except:
|
||||
current_state["tool_call_process"] = False
|
||||
print(f"cant process tool")
|
||||
return True
|
||||
# check if function name is valid
|
||||
if (
|
||||
current_state["tool_call"]["function"]["name"]
|
||||
not in current_state["parameter_names"].keys()
|
||||
):
|
||||
current_state["hallucination"] = True
|
||||
current_state[
|
||||
"hallucination_message"
|
||||
] = f"function name {current_state['tool_call']['name']} not found"
|
||||
return True
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the HallucinationStateHandler with default values.
|
||||
"""
|
||||
self.tokens = []
|
||||
self.logprobs = []
|
||||
self.state = None
|
||||
self.mask = []
|
||||
self.parameter_name_done = False
|
||||
self.hallucination = False
|
||||
self.hallucination_message = ""
|
||||
self.parameter_name = []
|
||||
|
||||
# check if parameter names are from the given function tools
|
||||
current_parameter_names = current_state["tool_call"]["function"][
|
||||
"arguments"
|
||||
].keys()
|
||||
given_parameter_names = current_state["parameter_names"][
|
||||
current_state["tool_call"]["function"]["name"]
|
||||
]
|
||||
if not set(current_parameter_names).issubset(given_parameter_names):
|
||||
missing_keys = set(current_parameter_names) - set(given_parameter_names)
|
||||
self.token_probs_map = []
|
||||
self.current_token = None
|
||||
|
||||
current_state["hallucination"] = True
|
||||
current_state[
|
||||
"hallucination_message"
|
||||
] = f"parameter names {missing_keys} not found"
|
||||
return True
|
||||
def process_function(self, apis):
|
||||
self.apis = apis
|
||||
if self.apis is None:
|
||||
raise ValueError("API descriptions not set.")
|
||||
parameter_names = {}
|
||||
for func in self.apis:
|
||||
func_name = func["name"]
|
||||
parameters = func["parameters"]["properties"]
|
||||
parameter_names[func_name] = list(parameters.keys())
|
||||
self.function_description = parameter_names
|
||||
self.function_properties = {x["name"]: x["parameters"] for x in self.apis}
|
||||
|
||||
# filtered special tokens that are not needed in the hallucination check for parameter values
|
||||
(
|
||||
current_state["filtered_tokens"],
|
||||
current_state["filtered_entropy"],
|
||||
) = filter_tokens_and_probs(
|
||||
current_state["tokens"], current_state["entropy"]
|
||||
)
|
||||
(
|
||||
current_state["filtered_tokens"],
|
||||
current_state["filtered_varentropy"],
|
||||
) = filter_tokens_and_probs(
|
||||
current_state["tokens"], current_state["varentropy"]
|
||||
)
|
||||
parameter_values, entropy_values = get_all_parameter_values(
|
||||
current_state["filtered_tokens"],
|
||||
current_state["filtered_entropy"],
|
||||
current_state["parameter_names"],
|
||||
)
|
||||
parameter_values, varentropy_values = get_all_parameter_values(
|
||||
current_state["filtered_tokens"],
|
||||
current_state["filtered_varentropy"],
|
||||
current_state["parameter_names"],
|
||||
)
|
||||
def process_token(self):
|
||||
"""
|
||||
Processes the current token and updates the state and mask accordingly.
|
||||
Detects hallucinations based on the token type and log probabilities.
|
||||
"""
|
||||
content = "".join(self.tokens).replace(" ", "")
|
||||
if self.current_token == "<tool_call>":
|
||||
self.mask.append("t")
|
||||
self.check_logprob()
|
||||
|
||||
current_state["parameter_values"] = parameter_values
|
||||
current_state["parameter_values_entropy"] = entropy_values
|
||||
current_state["parameter_values_varentropy"] = varentropy_values
|
||||
# calculate the max, first, avg of sub tokens for parameter value
|
||||
current_state["parameter_value_entropy_stat"] = calculate_stats(
|
||||
current_state["parameter_values_entropy"],
|
||||
current_state["function_description"][0],
|
||||
)
|
||||
current_state["parameter_value_varentropy_stat"] = calculate_stats(
|
||||
current_state["parameter_values_varentropy"],
|
||||
current_state["function_description"][0],
|
||||
)
|
||||
# get map for debugging
|
||||
current_state["token_entropy_map"] = {
|
||||
x: y for x, y in zip(current_state["tokens"], current_state["entropy"])
|
||||
}
|
||||
current_state["token_varentropy_map"] = {
|
||||
x: y
|
||||
for x, y in zip(current_state["tokens"], current_state["varentropy"])
|
||||
}
|
||||
# Function name extraction logic
|
||||
if self.state == "function_name":
|
||||
if self.current_token not in const.FUNC_NAME_END_TOKEN:
|
||||
self.mask.append("f")
|
||||
else:
|
||||
self.state = None
|
||||
self.check_function_name()
|
||||
|
||||
# checking hallucination for parameter value
|
||||
current_state["parameter_value_check"] = {
|
||||
x: {"hallucination": False, "message": ""}
|
||||
for x in current_state["parameter_values"].keys()
|
||||
}
|
||||
for key in current_state["parameter_value_check"].keys():
|
||||
# if parameter is given a format, check the first token
|
||||
if current_state["parameter_value_entropy_stat"][key]["has_format"]:
|
||||
if (
|
||||
current_state["parameter_value_entropy_stat"][key]["first"]
|
||||
> entropy_thd
|
||||
or current_state["parameter_value_varentropy_stat"][key][
|
||||
"first"
|
||||
]
|
||||
> varentropy_thd
|
||||
):
|
||||
current_state["parameter_value_check"][key][
|
||||
"hallucination"
|
||||
] = True
|
||||
current_state["hallucination"] = True
|
||||
current_state["parameter_value_check"][key][
|
||||
"message"
|
||||
] = f"parameter {key} with formatting doesn't pass threshold"
|
||||
# if parameter gis given a default value, we can always use default
|
||||
elif current_state["parameter_value_entropy_stat"][key]["has_default"]:
|
||||
current_state["parameter_value_check"][key]["hallucination"] = False
|
||||
current_state["parameter_value_check"][key][
|
||||
"message"
|
||||
] = f"parameter {key} with default"
|
||||
# check if max sub token is > thresholds
|
||||
else:
|
||||
if (
|
||||
current_state["parameter_value_entropy_stat"][key]["max"]
|
||||
> entropy_thd
|
||||
or current_state["parameter_value_varentropy_stat"][key]["max"]
|
||||
> varentropy_thd
|
||||
):
|
||||
current_state["parameter_value_check"][key][
|
||||
"hallucination"
|
||||
] = True
|
||||
current_state["parameter_value_check"][key][
|
||||
"message"
|
||||
] = f"parameter {key} with {current_state['parameter_value_entropy_stat'][key]['max']} and {current_state['parameter_value_varentropy_stat'][key]['max']} doesnt pass threshold"
|
||||
current_state["hallucination"] = True
|
||||
if current_state["hallucination"] == True:
|
||||
current_state["hallucination_message"] = "\n".join(
|
||||
[
|
||||
current_state["parameter_value_check"][key]["message"]
|
||||
for key in current_state["parameter_value_check"].keys()
|
||||
]
|
||||
)
|
||||
return True
|
||||
return False
|
||||
if content.endswith(const.FUNC_NAME_START_PATTERN):
|
||||
print("function name entered")
|
||||
self.state = "function_name"
|
||||
|
||||
# Parameter name extraction logic
|
||||
if self.state == "parameter_name" and not content.endswith(
|
||||
const.PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.mask.append("p")
|
||||
elif self.state == "parameter_name" and content.endswith(
|
||||
const.PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.state = None
|
||||
self.check_parameter_name()
|
||||
self.parameter_name_done = True
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
const.PARAMETER_NAME_START_PATTERN
|
||||
):
|
||||
self.state = "parameter_name"
|
||||
|
||||
if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN):
|
||||
self.state = "parameter_name"
|
||||
|
||||
# Parameter value extraction logic
|
||||
if self.state == "parameter_value" and not content.endswith(
|
||||
const.PARAMETER_VALUE_END_TOKEN
|
||||
):
|
||||
if self.current_token.strip() not in ['"', ""]:
|
||||
self.mask.append("v")
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] == "v"
|
||||
and not check_parameter_property(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
"default",
|
||||
)
|
||||
):
|
||||
self.check_logprob()
|
||||
else:
|
||||
self.mask.append("e")
|
||||
|
||||
elif self.state == "parameter_value" and content.endswith(
|
||||
const.PARAMETER_VALUE_END_TOKEN
|
||||
):
|
||||
self.state = None
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
const.PARAMETER_VALUE_START_PATTERN
|
||||
):
|
||||
self.state = "parameter_value"
|
||||
|
||||
# Maintain consistency between stack and mask
|
||||
if len(self.mask) != len(self.tokens):
|
||||
self.mask.append("e")
|
||||
|
||||
def check_logprob(self):
|
||||
"""
|
||||
Checks the log probability of the current token and updates the token probability map.
|
||||
Detects hallucinations based on entropy and variance of entropy.
|
||||
"""
|
||||
probs = self.logprobs[-1]
|
||||
entropy, varentropy = calculate_entropy(probs)
|
||||
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
|
||||
|
||||
if check_threshold(
|
||||
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]]
|
||||
):
|
||||
self.hallucination = True
|
||||
self.hallucination_message = f"Token '{self.current_token}' is uncertain."
|
||||
|
||||
def count_consecutive_token(self, token="v") -> int:
|
||||
"""
|
||||
Counts the number of consecutive occurrences of a given token in the mask.
|
||||
|
||||
Args:
|
||||
token (str): The token to count in the mask.
|
||||
|
||||
Returns:
|
||||
int: The number of consecutive occurrences of the token.
|
||||
"""
|
||||
return (
|
||||
len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask))))
|
||||
if self.mask and self.mask[-1] == token
|
||||
else 0
|
||||
)
|
||||
|
||||
def check_function_name(self):
|
||||
"""
|
||||
Checks the extracted function name against the function descriptions.
|
||||
Detects hallucinations if the function name is not found.
|
||||
"""
|
||||
f_len = self.count_consecutive_token("f")
|
||||
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
||||
if self.function_name not in self.function_description.keys():
|
||||
self.hallucination = True
|
||||
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
||||
|
||||
def check_parameter_name(self):
|
||||
"""
|
||||
Checks the extracted parameter name against the function descriptions.
|
||||
Detects hallucinations if the parameter name is not found.
|
||||
"""
|
||||
p_len = self.count_consecutive_token("p")
|
||||
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
||||
self.parameter_name.append(parameter_name)
|
||||
if parameter_name not in self.function_description[self.function_name]:
|
||||
self.hallucination = True
|
||||
self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue