mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
fix
This commit is contained in:
parent
673b187eb5
commit
075c94fd39
1 changed files with 21 additions and 26 deletions
|
|
@ -1,14 +1,10 @@
|
||||||
import json
|
import json
|
||||||
import ast
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
import random
|
import random
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
import app.commons.constants as const
|
import app.commons.constants as const
|
||||||
import itertools
|
import itertools
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
||||||
|
|
@ -84,24 +80,22 @@ class HallucinationStateHandler:
|
||||||
parameter_name (list): List of extracted parameter names.
|
parameter_name (list): List of extracted parameter names.
|
||||||
function_description (dict): Description of functions and their parameters.
|
function_description (dict): Description of functions and their parameters.
|
||||||
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
|
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
|
||||||
current_token (str): The current token being processed.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, response_iterator=None, function=None):
|
def __init__(self, response_iterator=None, function=None):
|
||||||
"""
|
"""
|
||||||
Initializes the HallucinationStateHandler with default values.
|
Initializes the HallucinationStateHandler with default values.
|
||||||
"""
|
"""
|
||||||
self.tokens = []
|
self.tokens: List[str] = []
|
||||||
self.logprobs = []
|
self.logprobs: List[float] = []
|
||||||
self.state = None
|
self.state: str = None
|
||||||
self.mask = []
|
self.mask: List[str] = []
|
||||||
self.parameter_name_done = False
|
self.parameter_name_done: bool = False
|
||||||
self.hallucination = False
|
self.hallucination: bool = False
|
||||||
self.hallucination_message = ""
|
self.error_message: str = ""
|
||||||
self.parameter_name = []
|
self.error_type: str = ""
|
||||||
|
self.parameter_name: List[str] = []
|
||||||
self.token_probs_map = []
|
self.token_probs_map: List[Tuple[str, float, float]] = []
|
||||||
self.current_token = None
|
|
||||||
self.response_iterator = response_iterator
|
self.response_iterator = response_iterator
|
||||||
self._process_function(function)
|
self._process_function(function)
|
||||||
|
|
||||||
|
|
@ -128,7 +122,6 @@ class HallucinationStateHandler:
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the token is hallucinated, False otherwise.
|
bool: True if the token is hallucinated, False otherwise.
|
||||||
"""
|
"""
|
||||||
self.current_token = token
|
|
||||||
self.tokens.append(token)
|
self.tokens.append(token)
|
||||||
self.logprobs.append(logprob)
|
self.logprobs.append(logprob)
|
||||||
self._process_token()
|
self._process_token()
|
||||||
|
|
@ -166,14 +159,14 @@ class HallucinationStateHandler:
|
||||||
Detects hallucinations based on the token type and log probabilities.
|
Detects hallucinations based on the token type and log probabilities.
|
||||||
"""
|
"""
|
||||||
content = "".join(self.tokens).replace(" ", "")
|
content = "".join(self.tokens).replace(" ", "")
|
||||||
if self.current_token == const.TOOL_CALL_TOKEN:
|
if self.tokens[-1] == const.TOOL_CALL_TOKEN:
|
||||||
self.mask.append(const.MaskToken.TOOL_CALL)
|
self.mask.append(const.MaskToken.TOOL_CALL)
|
||||||
self._check_logprob()
|
self._check_logprob()
|
||||||
|
|
||||||
# Function name extraction logic
|
# Function name extraction logic
|
||||||
# If the state is function name and the token is not an end token, add to the mask
|
# If the state is function name and the token is not an end token, add to the mask
|
||||||
if self.state == "function_name":
|
if self.state == "function_name":
|
||||||
if self.current_token not in const.FUNC_NAME_END_TOKEN:
|
if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN:
|
||||||
self.mask.append(const.MaskToken.FUNCTION_NAME)
|
self.mask.append(const.MaskToken.FUNCTION_NAME)
|
||||||
else:
|
else:
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
@ -181,7 +174,6 @@ class HallucinationStateHandler:
|
||||||
|
|
||||||
# Check if the token is a function name start token, change the state
|
# Check if the token is a function name start token, change the state
|
||||||
if content.endswith(const.FUNC_NAME_START_PATTERN):
|
if content.endswith(const.FUNC_NAME_START_PATTERN):
|
||||||
print("function name entered")
|
|
||||||
self.state = "function_name"
|
self.state = "function_name"
|
||||||
|
|
||||||
# Parameter name extraction logic
|
# Parameter name extraction logic
|
||||||
|
|
@ -214,7 +206,7 @@ class HallucinationStateHandler:
|
||||||
const.PARAMETER_VALUE_END_TOKEN
|
const.PARAMETER_VALUE_END_TOKEN
|
||||||
):
|
):
|
||||||
# checking if the token is a value token and is not empty
|
# checking if the token is a value token and is not empty
|
||||||
if self.current_token.strip() not in ['"', ""]:
|
if self.tokens[-1].strip() not in ['"', ""]:
|
||||||
self.mask.append(const.MaskToken.PARAMETER_VALUE)
|
self.mask.append(const.MaskToken.PARAMETER_VALUE)
|
||||||
# checking if the parameter doesn't have default and the token is the first parameter value token
|
# checking if the parameter doesn't have default and the token is the first parameter value token
|
||||||
if (
|
if (
|
||||||
|
|
@ -258,7 +250,10 @@ class HallucinationStateHandler:
|
||||||
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
|
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
|
||||||
):
|
):
|
||||||
self.hallucination = True
|
self.hallucination = True
|
||||||
self.hallucination_message = f"Token '{self.current_token}' is uncertain."
|
self.error_type = "Hallucination"
|
||||||
|
self.error_message = (
|
||||||
|
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
|
||||||
|
)
|
||||||
|
|
||||||
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
|
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
@ -284,8 +279,8 @@ class HallucinationStateHandler:
|
||||||
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
|
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
|
||||||
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
||||||
if self.function_name not in self.function_description.keys():
|
if self.function_name not in self.function_description.keys():
|
||||||
self.hallucination = True
|
self.error_type = "function_name"
|
||||||
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
||||||
|
|
||||||
def _is_parameter_name_hallucinated(self):
|
def _is_parameter_name_hallucinated(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -296,5 +291,5 @@ class HallucinationStateHandler:
|
||||||
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
||||||
self.parameter_name.append(parameter_name)
|
self.parameter_name.append(parameter_name)
|
||||||
if parameter_name not in self.function_description[self.function_name]:
|
if parameter_name not in self.function_description[self.function_name]:
|
||||||
self.hallucination = True
|
self.error_type = "parameter_name"
|
||||||
self.hallucination_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
|
self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue