diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 1aa9ad46..cf2adb28 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -260,8 +260,7 @@ class ArchFunctionHandler(ArchBaseHandler): functions = {} for tool in tools: - if tool["type"] == "function": - functions[tool["function"]["name"]] = tool["function"]["parameters"] + functions[tool["function"]["name"]] = tool["function"]["parameters"] for tool_call in tool_calls: if not verification_dict["is_valid"]: @@ -420,26 +419,28 @@ class ArchFunctionHandler(ArchBaseHandler): if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None: content = "".join(self.hallucination_state.tokens) if "tool_calls" in content: + logger.info( + f"[Content]: {content}" + ) has_tool_calls = True else: has_tool_calls = False - break + # if the model is hallucinating, start parameter gathering if self.hallucination_state.hallucination is True: has_hallucination = True break - if has_tool_calls: - if has_hallucination: - # start prompt prefilling if hallcuination is found in tool calls - logger.info( - f"[Hallucination]: {self.hallucination_state.error_message}" - ) - prefill_response = self._engage_parameter_gathering(messages) - model_response = prefill_response.choices[0].message.content - else: - model_response = "".join(self.hallucination_state.tokens) + if has_tool_calls and has_hallucination: + # start prompt prefilling if hallcuination is found in tool calls + logger.info( + f"[Hallucination]: {self.hallucination_state.error_message}" + ) + prefill_response = self._engage_parameter_gathering(messages) + model_response = prefill_response.choices[0].message.content + else: + model_response = "".join(self.hallucination_state.tokens) # else: # # start parameter gathering if the model is not generating tool calls # prefill_response = self._engage_parameter_gathering(messages)