mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
address issue
This commit is contained in:
parent
b5b56a9468
commit
4da2184d56
1 changed files with 67 additions and 15 deletions
|
|
@ -10,7 +10,7 @@ import app.commons.constants as const
|
|||
import itertools
|
||||
|
||||
|
||||
def check_threshold(entropy, varentropy, thd):
|
||||
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
||||
"""
|
||||
Check if the given entropy or variance of entropy exceeds the specified thresholds.
|
||||
|
||||
|
|
@ -48,19 +48,21 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
|
|||
return entropy.item(), varentropy.item()
|
||||
|
||||
|
||||
def check_parameter_property(api_description, parameter_name, property_name):
|
||||
def is_parameter_property(
|
||||
function_description: Dict, parameter_name: str, property_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a parameter in an API description has a specific property.
|
||||
|
||||
Args:
|
||||
api_description (dict): The API description in JSON format.
|
||||
function_description (dict): The API description in JSON format.
|
||||
parameter_name (str): The name of the parameter to check.
|
||||
property_name (str): The property to look for (e.g., 'format', 'default').
|
||||
|
||||
Returns:
|
||||
bool: True if the parameter has the specified property, False otherwise.
|
||||
"""
|
||||
parameters = api_description.get("properties", {})
|
||||
parameters = function_description.get("properties", {})
|
||||
parameter_info = parameters.get(parameter_name, {})
|
||||
|
||||
return property_name in parameter_info
|
||||
|
|
@ -84,7 +86,7 @@ class HallucinationStateHandler:
|
|||
current_token (str): The current token being processed.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, response_iterator=None, function=None):
|
||||
"""
|
||||
Initializes the HallucinationStateHandler with default values.
|
||||
"""
|
||||
|
|
@ -99,18 +101,56 @@ class HallucinationStateHandler:
|
|||
|
||||
self.token_probs_map = []
|
||||
self.current_token = None
|
||||
self.response_iterator = response_iterator
|
||||
self.process_function(function)
|
||||
|
||||
def process_function(self, apis):
|
||||
self.apis = apis
|
||||
if self.apis is None:
|
||||
def process_function(self, function):
|
||||
self.function = function
|
||||
if self.function is None:
|
||||
raise ValueError("API descriptions not set.")
|
||||
parameter_names = {}
|
||||
for func in self.apis:
|
||||
for func in self.function:
|
||||
func_name = func["name"]
|
||||
parameters = func["parameters"]["properties"]
|
||||
parameter_names[func_name] = list(parameters.keys())
|
||||
self.function_description = parameter_names
|
||||
self.function_properties = {x["name"]: x["parameters"] for x in self.apis}
|
||||
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
|
||||
|
||||
def check_token_hallucination(self, token, logprob):
|
||||
"""
|
||||
Check if the given token is hallucinated based on the log probability.
|
||||
|
||||
Args:
|
||||
token (str): The token to check.
|
||||
logprob (float): The log probability of the token.
|
||||
|
||||
Returns:
|
||||
bool: True if the token is hallucinated, False otherwise.
|
||||
"""
|
||||
self.current_token = token
|
||||
self.tokens.append(token)
|
||||
self.logprobs.append(logprob)
|
||||
self.process_token()
|
||||
return self.hallucination
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.response_iterator is not None:
|
||||
try:
|
||||
r = next(self.response_iterator)
|
||||
if hasattr(r.choices[0].delta, "content"):
|
||||
token_content = r.choices[0].delta.content
|
||||
if token_content:
|
||||
logprobs = [
|
||||
p.logprob
|
||||
for p in r.choices[0].logprobs.content[0].top_logprobs
|
||||
]
|
||||
self.check_token_hallucination(token_content, logprobs)
|
||||
return token_content
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
||||
def process_token(self):
|
||||
"""
|
||||
|
|
@ -123,42 +163,52 @@ class HallucinationStateHandler:
|
|||
self.check_logprob()
|
||||
|
||||
# Function name extraction logic
|
||||
# If the state is function name and the token is not an end token, add to the mask
|
||||
if self.state == "function_name":
|
||||
if self.current_token not in const.FUNC_NAME_END_TOKEN:
|
||||
self.mask.append("f")
|
||||
else:
|
||||
self.state = None
|
||||
self.check_function_name()
|
||||
self.is_function_name_hallucinated()
|
||||
|
||||
# Check if the token is a function name start token, change the state
|
||||
if content.endswith(const.FUNC_NAME_START_PATTERN):
|
||||
print("function name entered")
|
||||
self.state = "function_name"
|
||||
|
||||
# Parameter name extraction logic
|
||||
# if the state is parameter name and the token is not an end token, add to the mask
|
||||
if self.state == "parameter_name" and not content.endswith(
|
||||
const.PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.mask.append("p")
|
||||
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
|
||||
# The need for parameter name done is to allow the check of parameter value pattern
|
||||
elif self.state == "parameter_name" and content.endswith(
|
||||
const.PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.state = None
|
||||
self.check_parameter_name()
|
||||
self.is_parameter_name_hallucinated()
|
||||
self.parameter_name_done = True
|
||||
# if the parameter name is done and the token is a parameter name start token, change the state
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
const.PARAMETER_NAME_START_PATTERN
|
||||
):
|
||||
self.state = "parameter_name"
|
||||
|
||||
# if token is a first parameter value start token, change the state
|
||||
if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN):
|
||||
self.state = "parameter_name"
|
||||
|
||||
# Parameter value extraction logic
|
||||
# if the state is parameter value and the token is not an end token, add to the mask
|
||||
if self.state == "parameter_value" and not content.endswith(
|
||||
const.PARAMETER_VALUE_END_TOKEN
|
||||
):
|
||||
# checking if the token is a value token and is not empty
|
||||
if self.current_token.strip() not in ['"', ""]:
|
||||
self.mask.append("v")
|
||||
# checking if the parameter doesn't have default and the token is the first parameter value token
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != "v"
|
||||
|
|
@ -171,17 +221,19 @@ class HallucinationStateHandler:
|
|||
self.check_logprob()
|
||||
else:
|
||||
self.mask.append("e")
|
||||
|
||||
# if the state is parameter value and the token is an end token, change the state
|
||||
elif self.state == "parameter_value" and content.endswith(
|
||||
const.PARAMETER_VALUE_END_TOKEN
|
||||
):
|
||||
self.state = None
|
||||
# if the parameter name is done and the token is a parameter value start token, change the state
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
const.PARAMETER_VALUE_START_PATTERN
|
||||
):
|
||||
self.state = "parameter_value"
|
||||
|
||||
# Maintain consistency between stack and mask
|
||||
# If the mask length is less than tokens, add an not used (e) token to the mask
|
||||
if len(self.mask) != len(self.tokens):
|
||||
self.mask.append("e")
|
||||
|
||||
|
|
@ -216,7 +268,7 @@ class HallucinationStateHandler:
|
|||
else 0
|
||||
)
|
||||
|
||||
def check_function_name(self):
|
||||
def is_function_name_hallucinated(self):
|
||||
"""
|
||||
Checks the extracted function name against the function descriptions.
|
||||
Detects hallucinations if the function name is not found.
|
||||
|
|
@ -227,7 +279,7 @@ class HallucinationStateHandler:
|
|||
self.hallucination = True
|
||||
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
||||
|
||||
def check_parameter_name(self):
|
||||
def is_parameter_name_hallucinated(self):
|
||||
"""
|
||||
Checks the extracted parameter name against the function descriptions.
|
||||
Detects hallucinations if the parameter name is not found.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue