mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
address comments
This commit is contained in:
parent
e44d189d86
commit
e30bbe39e7
3 changed files with 51 additions and 32 deletions
|
|
@ -4,6 +4,7 @@ import app.loader as loader
|
||||||
|
|
||||||
from app.function_calling.model_handler import ArchFunctionHandler
|
from app.function_calling.model_handler import ArchFunctionHandler
|
||||||
from app.prompt_guard.model_handler import ArchGuardHanlder
|
from app.prompt_guard.model_handler import ArchGuardHanlder
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
logger = utils.get_model_server_logger()
|
logger = utils.get_model_server_logger()
|
||||||
|
|
||||||
|
|
@ -38,6 +39,7 @@ arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||||
# Patterns for function name and parameter parsing
|
# Patterns for function name and parameter parsing
|
||||||
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
|
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
|
||||||
FUNC_NAME_END_TOKEN = ('",', "',")
|
FUNC_NAME_END_TOKEN = ('",', "',")
|
||||||
|
TOOL_CALL_TOKEN = "<tool_call>"
|
||||||
|
|
||||||
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
|
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
|
||||||
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
|
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
|
||||||
|
|
@ -45,10 +47,19 @@ PARAMETER_NAME_START_PATTERN = (',"', ",'")
|
||||||
PARAMETER_VALUE_START_PATTERN = ('":', "':")
|
PARAMETER_VALUE_START_PATTERN = ('":', "':")
|
||||||
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
|
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
|
||||||
|
|
||||||
|
|
||||||
# Thresholds
|
# Thresholds
|
||||||
|
class MaskToken(Enum):
|
||||||
|
FUNCTION_NAME = "f"
|
||||||
|
PARAMETER_VALUE = "v"
|
||||||
|
PARAMETER_NAME = "p"
|
||||||
|
NOT_USED = "e"
|
||||||
|
TOOL_CALL = "t"
|
||||||
|
|
||||||
|
|
||||||
HALLUCINATION_THRESHOLD_DICT = {
|
HALLUCINATION_THRESHOLD_DICT = {
|
||||||
"t": {"entropy": 0.1, "varentropy": 0.5},
|
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
|
||||||
"v": {
|
MaskToken.PARAMETER_VALUE.value: {
|
||||||
"entropy": 0.5,
|
"entropy": 0.5,
|
||||||
"varentropy": 2.5,
|
"varentropy": 2.5,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import random
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
import app.commons.constants as const
|
import app.commons.constants as const
|
||||||
import itertools
|
import itertools
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
||||||
|
|
@ -102,9 +103,9 @@ class HallucinationStateHandler:
|
||||||
self.token_probs_map = []
|
self.token_probs_map = []
|
||||||
self.current_token = None
|
self.current_token = None
|
||||||
self.response_iterator = response_iterator
|
self.response_iterator = response_iterator
|
||||||
self.process_function(function)
|
self._process_function(function)
|
||||||
|
|
||||||
def process_function(self, function):
|
def _process_function(self, function):
|
||||||
self.function = function
|
self.function = function
|
||||||
if self.function is None:
|
if self.function is None:
|
||||||
raise ValueError("API descriptions not set.")
|
raise ValueError("API descriptions not set.")
|
||||||
|
|
@ -116,7 +117,7 @@ class HallucinationStateHandler:
|
||||||
self.function_description = parameter_names
|
self.function_description = parameter_names
|
||||||
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
|
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
|
||||||
|
|
||||||
def check_token_hallucination(self, token, logprob):
|
def append_and_check_token_hallucination(self, token, logprob):
|
||||||
"""
|
"""
|
||||||
Check if the given token is hallucinated based on the log probability.
|
Check if the given token is hallucinated based on the log probability.
|
||||||
|
|
||||||
|
|
@ -130,7 +131,7 @@ class HallucinationStateHandler:
|
||||||
self.current_token = token
|
self.current_token = token
|
||||||
self.tokens.append(token)
|
self.tokens.append(token)
|
||||||
self.logprobs.append(logprob)
|
self.logprobs.append(logprob)
|
||||||
self.process_token()
|
self._process_token()
|
||||||
return self.hallucination
|
return self.hallucination
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
@ -143,33 +144,40 @@ class HallucinationStateHandler:
|
||||||
if hasattr(r.choices[0].delta, "content"):
|
if hasattr(r.choices[0].delta, "content"):
|
||||||
token_content = r.choices[0].delta.content
|
token_content = r.choices[0].delta.content
|
||||||
if token_content:
|
if token_content:
|
||||||
logprobs = [
|
try:
|
||||||
p.logprob
|
logprobs = [
|
||||||
for p in r.choices[0].logprobs.content[0].top_logprobs
|
p.logprob
|
||||||
]
|
for p in r.choices[0].logprobs.content[0].top_logprobs
|
||||||
self.check_token_hallucination(token_content, logprobs)
|
]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Error extracting logprobs from response: {e}"
|
||||||
|
)
|
||||||
|
self.append_and_check_token_hallucination(
|
||||||
|
token_content, logprobs
|
||||||
|
)
|
||||||
return token_content
|
return token_content
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
|
||||||
def process_token(self):
|
def _process_token(self):
|
||||||
"""
|
"""
|
||||||
Processes the current token and updates the state and mask accordingly.
|
Processes the current token and updates the state and mask accordingly.
|
||||||
Detects hallucinations based on the token type and log probabilities.
|
Detects hallucinations based on the token type and log probabilities.
|
||||||
"""
|
"""
|
||||||
content = "".join(self.tokens).replace(" ", "")
|
content = "".join(self.tokens).replace(" ", "")
|
||||||
if self.current_token == "<tool_call>":
|
if self.current_token == const.TOOL_CALL_TOKEN:
|
||||||
self.mask.append("t")
|
self.mask.append(const.MaskToken.TOOL_CALL)
|
||||||
self.check_logprob()
|
self._check_logprob()
|
||||||
|
|
||||||
# Function name extraction logic
|
# Function name extraction logic
|
||||||
# If the state is function name and the token is not an end token, add to the mask
|
# If the state is function name and the token is not an end token, add to the mask
|
||||||
if self.state == "function_name":
|
if self.state == "function_name":
|
||||||
if self.current_token not in const.FUNC_NAME_END_TOKEN:
|
if self.current_token not in const.FUNC_NAME_END_TOKEN:
|
||||||
self.mask.append("f")
|
self.mask.append(const.MaskToken.FUNCTION_NAME)
|
||||||
else:
|
else:
|
||||||
self.state = None
|
self.state = None
|
||||||
self.is_function_name_hallucinated()
|
self._is_function_name_hallucinated()
|
||||||
|
|
||||||
# Check if the token is a function name start token, change the state
|
# Check if the token is a function name start token, change the state
|
||||||
if content.endswith(const.FUNC_NAME_START_PATTERN):
|
if content.endswith(const.FUNC_NAME_START_PATTERN):
|
||||||
|
|
@ -181,14 +189,14 @@ class HallucinationStateHandler:
|
||||||
if self.state == "parameter_name" and not content.endswith(
|
if self.state == "parameter_name" and not content.endswith(
|
||||||
const.PARAMETER_NAME_END_TOKENS
|
const.PARAMETER_NAME_END_TOKENS
|
||||||
):
|
):
|
||||||
self.mask.append("p")
|
self.mask.append(const.MaskToken.PARAMETER_NAME)
|
||||||
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
|
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
|
||||||
# The need for parameter name done is to allow the check of parameter value pattern
|
# The need for parameter name done is to allow the check of parameter value pattern
|
||||||
elif self.state == "parameter_name" and content.endswith(
|
elif self.state == "parameter_name" and content.endswith(
|
||||||
const.PARAMETER_NAME_END_TOKENS
|
const.PARAMETER_NAME_END_TOKENS
|
||||||
):
|
):
|
||||||
self.state = None
|
self.state = None
|
||||||
self.is_parameter_name_hallucinated()
|
self._is_parameter_name_hallucinated()
|
||||||
self.parameter_name_done = True
|
self.parameter_name_done = True
|
||||||
# if the parameter name is done and the token is a parameter name start token, change the state
|
# if the parameter name is done and the token is a parameter name start token, change the state
|
||||||
elif self.parameter_name_done and content.endswith(
|
elif self.parameter_name_done and content.endswith(
|
||||||
|
|
@ -207,20 +215,20 @@ class HallucinationStateHandler:
|
||||||
):
|
):
|
||||||
# checking if the token is a value token and is not empty
|
# checking if the token is a value token and is not empty
|
||||||
if self.current_token.strip() not in ['"', ""]:
|
if self.current_token.strip() not in ['"', ""]:
|
||||||
self.mask.append("v")
|
self.mask.append(const.MaskToken.PARAMETER_VALUE)
|
||||||
# checking if the parameter doesn't have default and the token is the first parameter value token
|
# checking if the parameter doesn't have default and the token is the first parameter value token
|
||||||
if (
|
if (
|
||||||
len(self.mask) > 1
|
len(self.mask) > 1
|
||||||
and self.mask[-2] != "v"
|
and self.mask[-2] != const.MaskToken.PARAMETER_VALUE
|
||||||
and not is_parameter_property(
|
and not is_parameter_property(
|
||||||
self.function_properties[self.function_name],
|
self.function_properties[self.function_name],
|
||||||
self.parameter_name[-1],
|
self.parameter_name[-1],
|
||||||
"default",
|
"default",
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
self.check_logprob()
|
self._check_logprob()
|
||||||
else:
|
else:
|
||||||
self.mask.append("e")
|
self.mask.append(const.MaskToken.NOT_USED)
|
||||||
# if the state is parameter value and the token is an end token, change the state
|
# if the state is parameter value and the token is an end token, change the state
|
||||||
elif self.state == "parameter_value" and content.endswith(
|
elif self.state == "parameter_value" and content.endswith(
|
||||||
const.PARAMETER_VALUE_END_TOKEN
|
const.PARAMETER_VALUE_END_TOKEN
|
||||||
|
|
@ -235,9 +243,9 @@ class HallucinationStateHandler:
|
||||||
# Maintain consistency between stack and mask
|
# Maintain consistency between stack and mask
|
||||||
# If the mask length is less than tokens, add an not used (e) token to the mask
|
# If the mask length is less than tokens, add an not used (e) token to the mask
|
||||||
if len(self.mask) != len(self.tokens):
|
if len(self.mask) != len(self.tokens):
|
||||||
self.mask.append("e")
|
self.mask.append(const.MaskToken.NOT_USED)
|
||||||
|
|
||||||
def check_logprob(self):
|
def _check_logprob(self):
|
||||||
"""
|
"""
|
||||||
Checks the log probability of the current token and updates the token probability map.
|
Checks the log probability of the current token and updates the token probability map.
|
||||||
Detects hallucinations based on entropy and variance of entropy.
|
Detects hallucinations based on entropy and variance of entropy.
|
||||||
|
|
@ -247,12 +255,12 @@ class HallucinationStateHandler:
|
||||||
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
|
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
|
||||||
|
|
||||||
if check_threshold(
|
if check_threshold(
|
||||||
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1]]
|
entropy, varentropy, const.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
|
||||||
):
|
):
|
||||||
self.hallucination = True
|
self.hallucination = True
|
||||||
self.hallucination_message = f"Token '{self.current_token}' is uncertain."
|
self.hallucination_message = f"Token '{self.current_token}' is uncertain."
|
||||||
|
|
||||||
def count_consecutive_token(self, token="v") -> int:
|
def _count_consecutive_token(self, token=const.MaskToken.PARAMETER_VALUE) -> int:
|
||||||
"""
|
"""
|
||||||
Counts the number of consecutive occurrences of a given token in the mask.
|
Counts the number of consecutive occurrences of a given token in the mask.
|
||||||
|
|
||||||
|
|
@ -268,23 +276,23 @@ class HallucinationStateHandler:
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_function_name_hallucinated(self):
|
def _is_function_name_hallucinated(self):
|
||||||
"""
|
"""
|
||||||
Checks the extracted function name against the function descriptions.
|
Checks the extracted function name against the function descriptions.
|
||||||
Detects hallucinations if the function name is not found.
|
Detects hallucinations if the function name is not found.
|
||||||
"""
|
"""
|
||||||
f_len = self.count_consecutive_token("f")
|
f_len = self._count_consecutive_token(const.MaskToken.FUNCTION_NAME)
|
||||||
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
||||||
if self.function_name not in self.function_description.keys():
|
if self.function_name not in self.function_description.keys():
|
||||||
self.hallucination = True
|
self.hallucination = True
|
||||||
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
self.hallucination_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
||||||
|
|
||||||
def is_parameter_name_hallucinated(self):
|
def _is_parameter_name_hallucinated(self):
|
||||||
"""
|
"""
|
||||||
Checks the extracted parameter name against the function descriptions.
|
Checks the extracted parameter name against the function descriptions.
|
||||||
Detects hallucinations if the parameter name is not found.
|
Detects hallucinations if the parameter name is not found.
|
||||||
"""
|
"""
|
||||||
p_len = self.count_consecutive_token("p")
|
p_len = self._count_consecutive_token(const.MaskToken.PARAMETER_NAME)
|
||||||
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
||||||
self.parameter_name.append(parameter_name)
|
self.parameter_name.append(parameter_name)
|
||||||
if parameter_name not in self.function_description[self.function_name]:
|
if parameter_name not in self.function_description[self.function_name]:
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ def test_hallucination(case):
|
||||||
)
|
)
|
||||||
for token, logprob in zip(case["tokens"], case["logprobs"]):
|
for token, logprob in zip(case["tokens"], case["logprobs"]):
|
||||||
if token != "</tool_call>":
|
if token != "</tool_call>":
|
||||||
state.check_token_hallucination(token, logprob)
|
state.append_and_check_token_hallucination(token, logprob)
|
||||||
if state.hallucination:
|
if state.hallucination:
|
||||||
break
|
break
|
||||||
assert state.hallucination == case["expect"]
|
assert state.hallucination == case["expect"]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue