add hallucination modification (#455)

* add hallucination modification

* disable test
This commit is contained in:
CTran 2025-03-28 09:49:20 -07:00 committed by GitHub
parent e48918259e
commit a3f2b3cef9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 86 additions and 95 deletions

View file

@ -406,47 +406,47 @@ class ArchFunctionHandler(ArchBaseHandler):
# *********************************************************************************************
# initialize the hallucination handler, which is an iterator
self.hallucination_state = HallucinationState(
response_iterator=response, function=req.tools
)
# self.hallucination_state = HallucinationState(
# response_iterator=response, function=req.tools
# )
has_tool_calls, has_hallucination = None, False
for _ in self.hallucination_state:
# check if the first token is <tool_call>
if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None:
content = ''.join(self.hallucination_state.tokens)
if "tool_calls" in content:
has_tool_calls = True
else:
has_tool_calls = False
break
# has_tool_calls, has_hallucination = None, False
# for _ in self.hallucination_state:
# # check if the first token is <tool_call>
# if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
# if self.hallucination_state.tokens[0] == "<tool_call>":
# has_tool_calls = True
# else:
# has_tool_calls = False
# break
# if the model is hallucinating, start parameter gathering
if self.hallucination_state.hallucination is True:
has_hallucination = True
break
# # if the model is hallucinating, start parameter gathering
# if self.hallucination_state.hallucination is True:
# has_hallucination = True
# break
# if has_tool_calls:
# if has_hallucination:
# # start prompt prefilling if hallcuination is found in tool calls
# logger.info(
# f"[Hallucination]: {self.hallucination_state.error_message}"
# )
# prefill_response = self._engage_parameter_gathering(messages)
# model_response = prefill_response.choices[0].message.content
# else:
# model_response = "".join(self.hallucination_state.tokens)
if has_tool_calls:
if has_hallucination:
# start prompt prefilling if hallcuination is found in tool calls
logger.info(
f"[Hallucination]: {self.hallucination_state.error_message}"
)
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = "".join(self.hallucination_state.tokens)
# else:
# # start parameter gathering if the model is not generating tool calls
# prefill_response = self._engage_parameter_gathering(messages)
# model_response = prefill_response.choices[0].message.content
# *********************************************************************************************\
# TODO: Remove the following for loop after updating hallucination check
# *********************************************************************************************
for chunk in response:
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
model_response += chunk.choices[0].delta.content
# # *********************************************************************************************\
# # TODO: Remove the following for loop after updating hallucination check
# # *********************************************************************************************
# for chunk in response:
# if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
# model_response += chunk.choices[0].delta.content
logger.info(f"[arch-fc]: raw model response: {model_response}")
# Extract tool calls from model response

View file

@ -13,16 +13,15 @@ from src.commons.utils import get_model_server_logger
logger = get_model_server_logger()
# constants
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_START_PATTERN = ('{"name":"', "{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
END_TOOL_CALL_TOKEN = "</tool_call>"
END_TOOL_CALL_TOKEN = "}}"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'", '":"', "':'")
PARAMETER_NAME_START_PATTERN = ('","', "','")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
PARAMETER_VALUE_END_TOKEN = ('",', '"}')
BRACKETS = {"(": ")", "{": "}", "[": "]"}
@ -37,16 +36,9 @@ class MaskToken(Enum):
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {
"entropy": 0.35,
"varentropy": 1.7,
"probability": 0.8,
},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.28,
"varentropy": 1.4,
"probability": 0.8,
},
"entropy": 0.28,
"varentropy": 1.4,
"probability": 0.8,
}
@ -160,6 +152,7 @@ class HallucinationState:
self._process_function(function)
self.open_bracket = False
self.bracket = None
self.function_name = ""
self.check_parameter_name = {}
self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT
@ -218,12 +211,10 @@ class HallucinationState:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
if token_content == END_TOOL_CALL_TOKEN:
self._reset_parameters()
else:
self.append_and_check_token_hallucination(
token_content, logprobs
)
self.append_and_check_token_hallucination(
token_content, logprobs
)
return token_content
except StopIteration:
raise StopIteration
@ -233,13 +224,13 @@ class HallucinationState:
Processes the current token and updates the state and mask accordingly.
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.tokens[-1] == TOOL_CALL_TOKEN:
self.mask.append(MaskToken.TOOL_CALL)
self._check_logprob()
content = "".join(self.tokens).replace(" ", "").replace("Ġ",'')
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if content.endswith(END_TOOL_CALL_TOKEN):
self._reset_parameters()
if self.state == "function_name":
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
self.mask.append(MaskToken.FUNCTION_NAME)
@ -359,7 +350,7 @@ class HallucinationState:
if check_threshold(
entropy,
varentropy,
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
self.HALLUCINATION_THRESHOLD_DICT,
):
self.hallucination = True
self.error_message = f"token '{self.tokens[-1]}' is uncertain. Generated response:\n{''.join(self.tokens)}"

View file

@ -100,10 +100,10 @@ async def function_calling(req: ChatMessage, res: Response):
# *********************************************************************************************
# TODO: Put the following code back when hallucination check is ready
# *********************************************************************************************
# if not use_agent_orchestrator:
# final_response.metadata["hallucination"] = str(
# model_handler.hallucination_state.hallucination
# )
if not use_agent_orchestrator:
final_response.metadata["hallucination"] = str(
model_handler.hallucination_state.hallucination
)
# No intent detected
else:
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
@ -114,9 +114,9 @@ async def function_calling(req: ChatMessage, res: Response):
# *********************************************************************************************
# TODO: Put the following code back when hallucination check is ready
# *********************************************************************************************
# final_response.metadata["hallucination"] = str(
# model_handler.hallucination_state.hallucination
# )
final_response.metadata["hallucination"] = str(
model_handler.hallucination_state.hallucination
)
except ValueError as e:
res.statuscode = 503