mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add hallucination modification
This commit is contained in:
parent
e5949c584f
commit
1f9e147999
3 changed files with 58 additions and 67 deletions
|
|
@ -406,47 +406,47 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
# *********************************************************************************************
|
||||
|
||||
# initialize the hallucination handler, which is an iterator
|
||||
self.hallucination_state = HallucinationState(
|
||||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
# self.hallucination_state = HallucinationState(
|
||||
# response_iterator=response, function=req.tools
|
||||
# )
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# once more than two tokens have been generated, check whether the output so far contains "tool_calls"
|
||||
if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None:
|
||||
content = ''.join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
break
|
||||
|
||||
# has_tool_calls, has_hallucination = None, False
|
||||
# for _ in self.hallucination_state:
|
||||
# # check if the first token is <tool_call>
|
||||
# if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
# if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
# has_tool_calls = True
|
||||
# else:
|
||||
# has_tool_calls = False
|
||||
# break
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
|
||||
# # if the model is hallucinating, start parameter gathering
|
||||
# if self.hallucination_state.hallucination is True:
|
||||
# has_hallucination = True
|
||||
# break
|
||||
|
||||
# if has_tool_calls:
|
||||
# if has_hallucination:
|
||||
# # start prompt prefilling if hallucination is found in tool calls
|
||||
# logger.info(
|
||||
# f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
# )
|
||||
# prefill_response = self._engage_parameter_gathering(messages)
|
||||
# model_response = prefill_response.choices[0].message.content
|
||||
# else:
|
||||
# model_response = "".join(self.hallucination_state.tokens)
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallucination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
# else:
|
||||
# # start parameter gathering if the model is not generating tool calls
|
||||
# prefill_response = self._engage_parameter_gathering(messages)
|
||||
# model_response = prefill_response.choices[0].message.content
|
||||
|
||||
# *********************************************************************************************
|
||||
# TODO: Remove the following for loop after updating hallucination check
|
||||
# *********************************************************************************************
|
||||
for chunk in response:
|
||||
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
|
||||
model_response += chunk.choices[0].delta.content
|
||||
# # *********************************************************************************************
|
||||
# # TODO: Remove the following for loop after updating hallucination check
|
||||
# # *********************************************************************************************
|
||||
# for chunk in response:
|
||||
# if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
|
||||
# model_response += chunk.choices[0].delta.content
|
||||
|
||||
# Extract tool calls from model response
|
||||
response_dict = self._parse_model_resonse(model_response)
|
||||
|
|
|
|||
|
|
@ -13,16 +13,15 @@ from src.commons.utils import get_model_server_logger
|
|||
logger = get_model_server_logger()
|
||||
|
||||
# constants
|
||||
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
|
||||
FUNC_NAME_START_PATTERN = ('{"name":"', "{'name':'")
|
||||
FUNC_NAME_END_TOKEN = ('",', "',")
|
||||
TOOL_CALL_TOKEN = "<tool_call>"
|
||||
END_TOOL_CALL_TOKEN = "</tool_call>"
|
||||
END_TOOL_CALL_TOKEN = "}}"
|
||||
|
||||
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
|
||||
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
|
||||
PARAMETER_NAME_START_PATTERN = (',"', ",'")
|
||||
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'", '":"', "':'")
|
||||
PARAMETER_NAME_START_PATTERN = ('","', "','")
|
||||
PARAMETER_VALUE_START_PATTERN = ('":', "':")
|
||||
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
|
||||
PARAMETER_VALUE_END_TOKEN = ('",', '"}')
|
||||
|
||||
BRACKETS = {"(": ")", "{": "}", "[": "]"}
|
||||
|
||||
|
|
@ -37,16 +36,9 @@ class MaskToken(Enum):
|
|||
|
||||
|
||||
HALLUCINATION_THRESHOLD_DICT = {
|
||||
MaskToken.TOOL_CALL.value: {
|
||||
"entropy": 0.35,
|
||||
"varentropy": 1.7,
|
||||
"probability": 0.8,
|
||||
},
|
||||
MaskToken.PARAMETER_VALUE.value: {
|
||||
"entropy": 0.28,
|
||||
"varentropy": 1.4,
|
||||
"probability": 0.8,
|
||||
},
|
||||
"entropy": 0.28,
|
||||
"varentropy": 1.4,
|
||||
"probability": 0.8,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -160,6 +152,7 @@ class HallucinationState:
|
|||
self._process_function(function)
|
||||
self.open_bracket = False
|
||||
self.bracket = None
|
||||
self.function_name = ""
|
||||
self.check_parameter_name = {}
|
||||
self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT
|
||||
|
||||
|
|
@ -218,12 +211,10 @@ class HallucinationState:
|
|||
raise ValueError(
|
||||
f"Error extracting logprobs from response: {e}"
|
||||
)
|
||||
if token_content == END_TOOL_CALL_TOKEN:
|
||||
self._reset_parameters()
|
||||
else:
|
||||
self.append_and_check_token_hallucination(
|
||||
token_content, logprobs
|
||||
)
|
||||
|
||||
self.append_and_check_token_hallucination(
|
||||
token_content, logprobs
|
||||
)
|
||||
return token_content
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
|
@ -233,13 +224,13 @@ class HallucinationState:
|
|||
Processes the current token and updates the state and mask accordingly.
|
||||
Detects hallucinations based on the token type and log probabilities.
|
||||
"""
|
||||
content = "".join(self.tokens).replace(" ", "")
|
||||
if self.tokens[-1] == TOOL_CALL_TOKEN:
|
||||
self.mask.append(MaskToken.TOOL_CALL)
|
||||
self._check_logprob()
|
||||
content = "".join(self.tokens).replace(" ", "").replace("Ġ",'')
|
||||
|
||||
# Function name extraction logic
|
||||
# If the state is function name and the token is not an end token, add to the mask
|
||||
if content.endswith(END_TOOL_CALL_TOKEN):
|
||||
self._reset_parameters()
|
||||
|
||||
if self.state == "function_name":
|
||||
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
|
||||
self.mask.append(MaskToken.FUNCTION_NAME)
|
||||
|
|
@ -359,7 +350,7 @@ class HallucinationState:
|
|||
if check_threshold(
|
||||
entropy,
|
||||
varentropy,
|
||||
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
|
||||
self.HALLUCINATION_THRESHOLD_DICT,
|
||||
):
|
||||
self.hallucination = True
|
||||
self.error_message = f"token '{self.tokens[-1]}' is uncertain. Generated response:\n{''.join(self.tokens)}"
|
||||
|
|
|
|||
|
|
@ -101,10 +101,10 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
# *********************************************************************************************
|
||||
# TODO: Put the following code back when hallucination check is ready
|
||||
# *********************************************************************************************
|
||||
# if not use_agent_orchestrator:
|
||||
# final_response.metadata["hallucination"] = str(
|
||||
# model_handler.hallucination_state.hallucination
|
||||
# )
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["hallucination"] = str(
|
||||
model_handler.hallucination_state.hallucination
|
||||
)
|
||||
# No intent detected
|
||||
else:
|
||||
final_response.metadata = {
|
||||
|
|
@ -117,9 +117,9 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
# *********************************************************************************************
|
||||
# TODO: Put the following code back when hallucination check is ready
|
||||
# *********************************************************************************************
|
||||
# final_response.metadata["hallucination"] = str(
|
||||
# model_handler.hallucination_state.hallucination
|
||||
# )
|
||||
final_response.metadata["hallucination"] = str(
|
||||
model_handler.hallucination_state.hallucination
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue