From a40cdc7b7587fe9d8f9a616fb8fbf34f527616c5 Mon Sep 17 00:00:00 2001 From: CTran Date: Sun, 8 Dec 2024 08:56:35 -0800 Subject: [PATCH] Cotran/intent (#339) * add else * integrate hallucination * remove test --- .../app/model_handler/function_calling.py | 57 ++++++++++++------- .../model_handler/hallucination_handler.py | 57 ++++++++++++++++--- 2 files changed, 86 insertions(+), 28 deletions(-) diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py index 78b058a3..167f52ff 100644 --- a/model_server/app/model_handler/function_calling.py +++ b/model_server/app/model_handler/function_calling.py @@ -12,6 +12,7 @@ from app.model_handler.base_handler import ( ChatCompletionResponse, ArchBaseHandler, ) +from app.function_calling.hallucination_handler import HallucinationStateHandler SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"] @@ -79,8 +80,10 @@ class ArchIntentHandler(ArchBaseHandler): Returns: bool: A boolean value to indicate if any intent match with prompts or not """ - - return content.choices[0].message.content == "Yes" + if hasattr(content.choices[0].message, "content"): + return content.choices[0].message.content == "Yes" + else: + return False @override async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: @@ -322,7 +325,7 @@ class ArchFunctionHandler(ArchBaseHandler): if required_param not in func_args: is_valid = False error_tool_call = tool_call - error_message = f"`{required_param}` is requried by the function `{func_name}` but not found in the tool call!" + error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
return is_valid, error_tool_call, error_message # Verify the data type of each parameter in the tool calls @@ -340,6 +343,36 @@ class ArchFunctionHandler(ArchBaseHandler): return is_valid, error_tool_call, error_message + def _prefill_response(self, messages: List[Dict[str, str]]): + """ + Prefills the response with the tool call prefix. + + Args: + messages (List[Dict[str, str]]): A list of messages; the prefill + prefix message is appended to this list in place. + + Returns: + The chat completion response generated from the prefilled messages. + """ + + messages.append( + { + "role": "assistant", + "content": random.choice(self.prefill_prefix), + } + ) + prefill_response = self.client.chat.completions.create( + messages=messages, + model=self.model_name, + stream=False, + extra_body={ + **self.generation_params, + **self.prefill_params, + }, + ) + + return prefill_response + @override async def chat_completion( self, req: ChatMessage, enable_prefilling=True @@ -390,23 +423,7 @@ class ArchFunctionHandler(ArchBaseHandler): # start parameter gathering if the model is not generating a tool call if has_tool_call is False: - messages.append( - { - "role": "assistant", - "content": random.choice(self.prefill_prefix), - } - ) - - prefill_response = self.client.chat.completions.create( - messages=messages, - model=self.model_name, - stream=False, - extra_body={ - **self.generation_params, - **self.prefill_params, - }, - ) - + prefill_response = self._prefill_response(messages) model_response = prefill_response.choices[0].message.content else: model_response = response.choices[0].message.content diff --git a/model_server/app/model_handler/hallucination_handler.py b/model_server/app/model_handler/hallucination_handler.py index 09607db3..7353312a 100644 --- a/model_server/app/model_handler/hallucination_handler.py +++ b/model_server/app/model_handler/hallucination_handler.py @@ -72,6 +72,25 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]: return entropy.item(),
varentropy.item() +def is_parameter_required( + function_description: Dict, + parameter_name: str, +) -> bool: + """ + Check if a parameter is in the function's required list. + + Args: + function_description (dict): The API description in JSON format. + parameter_name (str): The name of the parameter to check. + + Returns: + bool: True if the parameter is listed under `required`, False otherwise. + """ + required_parameters = function_description.get("required", {}) + + return parameter_name in required_parameters + + class HallucinationStateHandler: """ A class to handle the state of hallucination detection in token processing. @@ -104,6 +123,7 @@ class HallucinationStateHandler: self.parameter_name: List[str] = [] self.token_probs_map: List[Tuple[str, float, float]] = [] self.response_iterator = response_iterator + self.has_tool_call = False def append_and_check_token_hallucination(self, token, logprob): """ @@ -118,7 +138,8 @@ class HallucinationStateHandler: """ self.tokens.append(token) self.logprobs.append(logprob) - self._process_token() + if self.has_tool_call: + self._process_token() return self.hallucination def __iter__(self): @@ -164,7 +185,7 @@ class HallucinationStateHandler: self.mask.append(MaskToken.FUNCTION_NAME) else: self.state = None - self._is_function_name_hallucinated() + self._get_function_name() # Check if the token is a function name start token, change the state if content.endswith(FUNC_NAME_START_PATTERN): @@ -182,8 +203,8 @@ class HallucinationStateHandler: PARAMETER_NAME_END_TOKENS ): self.state = None - self._is_parameter_name_hallucinated() self.parameter_name_done = True + self._get_parameter_name() # if the parameter name is done and the token is a parameter name start token, change the state elif self.parameter_name_done and content.endswith( PARAMETER_NAME_START_PATTERN @@ -208,11 +229,10 @@ class HallucinationStateHandler: if ( len(self.mask) > 1 and self.mask[-2] != MaskToken.PARAMETER_VALUE - # and not is_parameter_property( - #
self.function_properties[self.function_name], - # self.parameter_name[-1], - # "default", - # ) + and is_parameter_required( + self.function_properties[self.function_name], + self.parameter_name[-1], + ) ): self._check_logprob() else: @@ -266,3 +286,24 @@ class HallucinationStateHandler: if self.mask and self.mask[-1] == token else 0 ) + + def _get_parameter_name(self): + """ + Extract the parameter name from the tokens and append it to `self.parameter_name`. + + Returns: + None: The extracted name is stored, not returned. + """ + p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME) + parameter_name = "".join(self.tokens[:-1][-p_len:]) + self.parameter_name.append(parameter_name) + + def _get_function_name(self): + """ + Extract the function name from the tokens and store it in `self.function_name`. + + Returns: + None: The extracted name is stored, not returned. + """ + f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME) + self.function_name = "".join(self.tokens[:-1][-f_len:])