mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
move constatns
This commit is contained in:
parent
075c94fd39
commit
4e7572f501
1 changed files with 51 additions and 22 deletions
|
|
@ -3,8 +3,37 @@ import math
|
|||
import torch
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import app.commons.constants as const
|
||||
import itertools
|
||||
from enum import Enum
|
||||
|
||||
# constants
|
||||
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
|
||||
FUNC_NAME_END_TOKEN = ('",', "',")
|
||||
TOOL_CALL_TOKEN = "<tool_call>"
|
||||
|
||||
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
|
||||
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
|
||||
PARAMETER_NAME_START_PATTERN = (',"', ",'")
|
||||
PARAMETER_VALUE_START_PATTERN = ('":', "':")
|
||||
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
|
||||
|
||||
|
||||
# Thresholds
|
||||
class MaskToken(Enum):
|
||||
FUNCTION_NAME = "f"
|
||||
PARAMETER_VALUE = "v"
|
||||
PARAMETER_NAME = "p"
|
||||
NOT_USED = "e"
|
||||
TOOL_CALL = "t"
|
||||
|
||||
|
||||
HALLUCINATION_THRESHOLD_DICT = {
|
||||
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
|
||||
MaskToken.PARAMETER_VALUE.value: {
|
||||
"entropy": 0.5,
|
||||
"varentropy": 2.5,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
||||
|
|
@ -159,59 +188,59 @@ class HallucinationStateHandler:
|
|||
Detects hallucinations based on the token type and log probabilities.
|
||||
"""
|
||||
content = "".join(self.tokens).replace(" ", "")
|
||||
if self.tokens[-1] == const.TOOL_CALL_TOKEN:
|
||||
self.mask.append(const.MaskToken.TOOL_CALL)
|
||||
if self.tokens[-1] == TOOL_CALL_TOKEN:
|
||||
self.mask.append(MaskToken.TOOL_CALL)
|
||||
self._check_logprob()
|
||||
|
||||
# Function name extraction logic
|
||||
# If the state is function name and the token is not an end token, add to the mask
|
||||
if self.state == "function_name":
|
||||
if self.tokens[-1] not in const.FUNC_NAME_END_TOKEN:
|
||||
self.mask.append(const.MaskToken.FUNCTION_NAME)
|
||||
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
|
||||
self.mask.append(MaskToken.FUNCTION_NAME)
|
||||
else:
|
||||
self.state = None
|
||||
self._is_function_name_hallucinated()
|
||||
|
||||
# Check if the token is a function name start token, change the state
|
||||
if content.endswith(const.FUNC_NAME_START_PATTERN):
|
||||
if content.endswith(FUNC_NAME_START_PATTERN):
|
||||
self.state = "function_name"
|
||||
|
||||
# Parameter name extraction logic
|
||||
# if the state is parameter name and the token is not an end token, add to the mask
|
||||
if self.state == "parameter_name" and not content.endswith(
|
||||
const.PARAMETER_NAME_END_TOKENS
|
||||
PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.mask.append(const.MaskToken.PARAMETER_NAME)
|
||||
self.mask.append(MaskToken.PARAMETER_NAME)
|
||||
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
|
||||
# The need for parameter name done is to allow the check of parameter value pattern
|
||||
elif self.state == "parameter_name" and content.endswith(
|
||||
const.PARAMETER_NAME_END_TOKENS
|
||||
PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.state = None
|
||||
self._is_parameter_name_hallucinated()
|
||||
self.parameter_name_done = True
|
||||
# if the parameter name is done and the token is a parameter name start token, change the state
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
const.PARAMETER_NAME_START_PATTERN
|
||||
PARAMETER_NAME_START_PATTERN
|
||||
):
|
||||
self.state = "parameter_name"
|
||||
|
||||
# if token is a first parameter value start token, change the state
|
||||
if content.endswith(const.FIRST_PARAM_NAME_START_PATTERN):
|
||||
if content.endswith(FIRST_PARAM_NAME_START_PATTERN):
|
||||
self.state = "parameter_name"
|
||||
|
||||
# Parameter value extraction logic
|
||||
# if the state is parameter value and the token is not an end token, add to the mask
|
||||
if self.state == "parameter_value" and not content.endswith(
|
||||
const.PARAMETER_VALUE_END_TOKEN
|
||||
PARAMETER_VALUE_END_TOKEN
|
||||
):
|
||||
# checking if the token is a value token and is not empty
|
||||
if self.tokens[-1].strip() not in ['"', ""]:
|
||||
self.mask.append(const.MaskToken.PARAMETER_VALUE)
|
||||
self.mask.append(MaskToken.PARAMETER_VALUE)
|
||||
# checking if the parameter doesn't have default and the token is the first parameter value token
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != const.MaskToken.PARAMETER_VALUE
|
||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||
and not is_parameter_property(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
|
|
@ -220,22 +249,22 @@ class HallucinationStateHandler:
|
|||
):
|
||||
self._check_logprob()
|
||||
else:
|
||||
self.mask.append(const.MaskToken.NOT_USED)
|
||||
self.mask.append(MaskToken.NOT_USED)
|
||||
# if the state is parameter value and the token is an end token, change the state
|
||||
elif self.state == "parameter_value" and content.endswith(
|
||||
const.PARAMETER_VALUE_END_TOKEN
|
||||
PARAMETER_VALUE_END_TOKEN
|
||||
):
|
||||
self.state = None
|
||||
# if the parameter name is done and the token is a parameter value start token, change the state
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
const.PARAMETER_VALUE_START_PATTERN
|
||||
PARAMETER_VALUE_START_PATTERN
|
||||
):
|
||||
self.state = "parameter_value"
|
||||
|
||||
# Maintain consistency between stack and mask
|
||||
# If the mask length is less than tokens, add an not used (e) token to the mask
|
||||
if len(self.mask) != len(self.tokens):
|
||||
self.mask.append(const.MaskToken.NOT_USED)
|
||||
self.mask.append(MaskToken.NOT_USED)
|
||||
|
||||
def _check_logprob(self):
|
||||
"""
|
||||
|
|
@ -247,7 +276,7 @@ class HallucinationStateHandler:
|
|||
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
|
||||
|
||||
if check_threshold(
|
||||
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
|
||||
entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
|
||||
):
|
||||
self.hallucination = True
|
||||
self.error_type = "Hallucination"
|
||||
|
|
@ -255,7 +284,7 @@ class HallucinationStateHandler:
|
|||
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
|
||||
)
|
||||
|
||||
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
|
||||
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
|
||||
"""
|
||||
Counts the number of consecutive occurrences of a given token in the mask.
|
||||
|
||||
|
|
@ -276,7 +305,7 @@ class HallucinationStateHandler:
|
|||
Checks the extracted function name against the function descriptions.
|
||||
Detects hallucinations if the function name is not found.
|
||||
"""
|
||||
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
|
||||
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
|
||||
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
||||
if self.function_name not in self.function_description.keys():
|
||||
self.error_type = "function_name"
|
||||
|
|
@ -287,7 +316,7 @@ class HallucinationStateHandler:
|
|||
Checks the extracted parameter name against the function descriptions.
|
||||
Detects hallucinations if the parameter name is not found.
|
||||
"""
|
||||
p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME)
|
||||
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
|
||||
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
||||
self.parameter_name.append(parameter_name)
|
||||
if parameter_name not in self.function_description[self.function_name]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue