diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index cf2adb28..13a77876 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -416,15 +416,14 @@ class ArchFunctionHandler(ArchBaseHandler): has_tool_calls, has_hallucination = None, False for _ in self.hallucination_state: # check if the first token is - if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None: - content = "".join(self.hallucination_state.tokens) - if "tool_calls" in content: - logger.info( - f"[Content]: {content}" - ) - has_tool_calls = True - else: - has_tool_calls = False + content = "".join(self.hallucination_state.tokens) + if "tool_calls" in content: + logger.info( + f"[Content]: {content}" + ) + has_tool_calls = True + else: + has_tool_calls = False # if the model is hallucinating, start parameter gathering diff --git a/model_server/src/core/utils/hallucination_utils.py b/model_server/src/core/utils/hallucination_utils.py index 05432710..ba1b5293 100644 --- a/model_server/src/core/utils/hallucination_utils.py +++ b/model_server/src/core/utils/hallucination_utils.py @@ -36,8 +36,8 @@ class MaskToken(Enum): HALLUCINATION_THRESHOLD_DICT = { - "entropy": 0.28, - "varentropy": 1.4, + "entropy": 0.35, + "varentropy": 1.1, "probability": 0.8, } diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py index 01b9ad95..7aba557b 100644 --- a/model_server/tests/core/test_function_calling.py +++ b/model_server/tests/core/test_function_calling.py @@ -37,24 +37,7 @@ get_weather_api = { # get_data class return request, intent, hallucination, parameter_gathering -def get_hallucination_data_complex(): - # Create instances of the Message class - message1 = Message(role="user", content="How is the weather in Seattle?") - message2 = Message( - role="assistant", content="Can you specify the unit you want the weather in?" - ) - message3 = Message(role="user", content="In celcius please!") - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1, message2, message3], tools=tools) - - return req, True, True, True - - -def get_hallucination_data_medium(): +def get_hallucination_data(): # Create instances of the Message class message1 = Message(role="user", content="How is the weather in?") @@ -65,26 +48,10 @@ def get_hallucination_data_medium(): req = ChatMessage(messages=[message1], tools=tools) # first token will not be tool call - return req, True, True, True + return req, False, True -def get_complete_data_2(): - # Create instances of the Message class - message1 = Message( - role="user", - content="what is the weather forecast for seattle in the next 10 days?", - ) - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1], tools=tools) - - return req, True, False, False - - -def get_complete_data(): +def get_success_tool_call_data(): # Create instances of the Message class message1 = Message(role="user", content="How is the weather in Seattle in 7 days?") @@ -94,7 +61,7 @@ def get_complete_data(): # Create an instance of the ChatMessage class req = ChatMessage(messages=[message1], tools=tools) - return req, True, False, False + return req, True, False def get_irrelevant_data(): @@ -107,7 +74,7 @@ def get_irrelevant_data(): # Create an instance of the ChatMessage class req = ChatMessage(messages=[message1], tools=tools) - return req, False, False, False + return req, False, False def get_greeting_data(): @@ -120,38 +87,29 @@ def get_greeting_data(): # Create an instance of the ChatMessage class req = ChatMessage(messages=[message1], tools=tools) - return req, False, False, False + return req, False, False -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "get_data_func", -# [ -# get_hallucination_data_complex, -# get_complete_data, -# get_irrelevant_data, -# get_complete_data_2, -# ], -# ) -# async def test_function_calling(get_data_func): -# req, intent, hallucination, parameter_gathering = get_data_func() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "get_data_func", + [ + get_hallucination_data, + get_greeting_data, + get_irrelevant_data, + get_success_tool_call_data, + ], +) +async def test_function_calling(get_data_func): + req, intent, hallucination, parameter_gathering = get_data_func() + handler_name = "Arch-Function" + use_agent_orchestrator = False + model_handler: ArchFunctionHandler = handler_map[handler_name] -# intent_response = await handler_map["Arch-Intent"].chat_completion(req) + start_time = time.perf_counter() + final_response = await model_handler.chat_completion(req) + latency = time.perf_counter() - start_time -# assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent + assert intent == (len(final_response.choices[0].message.tool_calls)>=1) -# if intent: -# function_calling_response = await handler_map["Arch-Function"].chat_completion( -# req -# ) -# assert ( -# handler_map["Arch-Function"].hallucination_state.hallucination -# == hallucination -# ) -# response_txt = function_calling_response.choices[0].message.content - -# if parameter_gathering: -# prefill_prefix = handler_map["Arch-Function"].prefill_prefix -# assert any( -# response_txt.startswith(prefix) for prefix in prefill_prefix -# ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}" + assert hallucination == model_handler.hallucination_state.hallucination