From 770ebbdd4e2bd7e8cf920e5604242cf4c9efcdd9 Mon Sep 17 00:00:00 2001 From: cotran Date: Wed, 11 Dec 2024 13:33:38 -0800 Subject: [PATCH] add type check and length checl --- model_server/src/core/function_calling.py | 74 ++++++++++++++++++++--- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index ab3c7a56..b0e6414e 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -356,6 +356,17 @@ class ArchFunctionHandler(ArchBaseHandler): return {"result": tool_calls, "status": is_valid, "message": error_message} + def _correcting_type(value, target_type): + try: + if target_type == float and isinstance(value, int): + return float(value) + elif target_type == list and isinstance(value, str): + return ast.literal_eval(value) + # Add more conversion rules as needed + except (ValueError, TypeError, json.JSONDecodeError): + pass + return value + def _verify_tool_calls( self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]] ) -> Dict[str, any]: @@ -410,7 +421,8 @@ class ArchFunctionHandler(ArchBaseHandler): if data_type in self.support_data_types: if not isinstance( - param_value, self.support_data_types[data_type] + self._correcting_type(param_value), + self.support_data_types[data_type], ): is_valid = False invalid_tool_call = tool_call @@ -457,6 +469,48 @@ class ArchFunctionHandler(ArchBaseHandler): ) return prefill_response + def _check_length_and_pop_messages(messages, max_tokens=4096): + """ + Trims the `messages` list to ensure the total token count does not exceed `max_tokens`. + + Args: + messages (list): List of message dictionaries. + max_tokens (int): Maximum allowed token count. + + Returns: + list: Trimmed list of messages. + """ + + def estimate_token_length(messages): + """Estimate the total token length of the messages.""" + total_tokens = 0 + for message in messages: + # Approximate token length: assuming ~4 characters per token on average + total_tokens += len(message["content"]) // 4 + return total_tokens + + # Calculate initial token length + total_tokens = estimate_token_length(messages) + + # Trim messages if token count exceeds the limit + while total_tokens > max_tokens: + # Find the first non-system message pair + for i in range(len(messages)): + if messages[i]["role"] != "system": + # Remove the 'user'/'assistant' pair + if i + 1 < len(messages) and messages[i + 1]["role"] in [ + "user", + "assistant", + ]: + del messages[i : i + 2] + else: + del messages[i] + break + # Recalculate token length + total_tokens = estimate_token_length(messages) + + return messages + @override async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: """ @@ -465,7 +519,6 @@ class ArchFunctionHandler(ArchBaseHandler): Args: req (ChatMessage): A chat message request object. enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True. - Returns: ChatCompletionResponse: The model's response to the chat request. @@ -474,6 +527,7 @@ class ArchFunctionHandler(ArchBaseHandler): """ messages = self._process_messages(req.messages, req.tools) + messages = self._check_length_and_pop_messages(messages) # always enable `stream=True` to collect model responses response = self.client.chat.completions.create( @@ -488,15 +542,15 @@ class ArchFunctionHandler(ArchBaseHandler): response_iterator=response, function=req.tools ) - model_response, has_tool_call = "", None + model_response, self.has_tool_call = "", None for _ in self.hallu_handler: # check if the first token is - if len(self.hallu_handler.tokens) > 0 and has_tool_call is None: + if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None: if self.hallu_handler.tokens[0] == "": - has_tool_call = True + self.has_tool_call = True else: - has_tool_call = False + self.has_tool_call = False break # if the model is hallucinating, start parameter gathering @@ -512,13 +566,13 @@ class ArchFunctionHandler(ArchBaseHandler): model_response = prefill_response.choices[0].message.content break - if has_tool_call and self.hallu_handler.hallucination is False: + if self.has_tool_call and self.hallu_handler.hallucination is False: # [TODO] - Review: remove the following code print("Tool call found, no hallucination detected!") model_response = "".join(self.hallu_handler.tokens) # start parameter gathering if the model is not generating tool calls - if has_tool_call is False: + if self.has_tool_call is False: # [TODO] - Review: remove the following code print("No tool call found, start parameter gathering") print(f"Token entropy/varentropy map: {self.hallu_handler.token_probs_map}") @@ -528,7 +582,7 @@ class ArchFunctionHandler(ArchBaseHandler): # Extract tool calls from model response extracted = self._extract_tool_calls(model_response) # [TODO] - Review: remvoe the following code - print(f"[Extracted] - {extracted}") + # print(f"[Extracted] - {extracted}") if len(extracted["result"]) and extracted["status"]: # [TODO] Review: define the behavior in the case that tool call extraction fails @@ -538,7 +592,7 @@ class ArchFunctionHandler(ArchBaseHandler): tools=req.tools, tool_calls=extracted["result"] ) # [TODO] - Review: remvoe the following code - print(f"[Verified] - {verified}") + # print(f"[Verified] - {verified}") # [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately if verified["status"]: