diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 59ba57b9..ffb15351 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -272,9 +272,9 @@ class ArchFunctionHandler(ArchBaseHandler): if required_param not in func_args: verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict[ - "error_message" - ] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!" + verification_dict["error_message"] = ( + f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!" + ) break # Verify the data type of each parameter in the tool calls @@ -286,9 +286,9 @@ class ArchFunctionHandler(ArchBaseHandler): if param_name not in function_properties: verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict[ - "error_message" - ] = f"Parameter `{param_name}` is not defined in the function `{func_name}`." + verification_dict["error_message"] = ( + f"Parameter `{param_name}` is not defined in the function `{func_name}`." + ) break else: param_value = func_args[param_name] @@ -304,16 +304,16 @@ class ArchFunctionHandler(ArchBaseHandler): if not isinstance(param_value, data_type): verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict[ - "error_message" - ] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`." + verification_dict["error_message"] = ( + f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`." + ) break else: verification_dict["is_valid"] = False verification_dict["invalid_tool_call"] = tool_call - verification_dict[ - "error_message" - ] = f"Data type `{target_type}` is not supported." 
+ verification_dict["error_message"] = ( + f"Data type `{target_type}` is not supported." + ) return verification_dict @@ -376,8 +376,8 @@ class ArchFunctionHandler(ArchBaseHandler): has_tool_calls, has_hallucination = None, False for _ in self.hallucination_state: - # check if moodel response starts with tool calls - if len(self.hallucination_state.tokens)>5 and has_tool_calls is None: + # check if model response starts with tool calls; we do it after 5 tokens because we only check the first part of the response. + if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None: content = "".join(self.hallucination_state.tokens) if "tool_calls" in content: has_tool_calls = True diff --git a/model_server/src/core/utils/hallucination_utils.py b/model_server/src/core/utils/hallucination_utils.py index 51216c12..992f8aa4 100644 --- a/model_server/src/core/utils/hallucination_utils.py +++ b/model_server/src/core/utils/hallucination_utils.py @@ -201,7 +201,7 @@ class HallucinationState: r = next(self.response_iterator) if hasattr(r.choices[0].delta, "content"): token_content = r.choices[0].delta.content - if token_content != '': + if token_content != "": try: logprobs = [ p.logprob @@ -214,7 +214,7 @@ class HallucinationState: self.append_and_check_token_hallucination( token_content, [None] ) - + return token_content except StopIteration: raise StopIteration diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py index e0cfbd82..6005d5e6 100644 --- a/model_server/tests/core/test_function_calling.py +++ b/model_server/tests/core/test_function_calling.py @@ -110,6 +110,6 @@ async def test_function_calling(get_data_func): final_response = await model_handler.chat_completion(req) latency = time.perf_counter() - start_time - assert intent == (len(final_response.choices[0].message.tool_calls)>=1) + assert intent == (len(final_response.choices[0].message.tool_calls) >= 1) assert hallucination == 
model_handler.hallucination_state.hallucination