From f13947732c4c84d1a7f840c66081aac72570a6a6 Mon Sep 17 00:00:00 2001 From: cotran Date: Mon, 9 Dec 2024 13:48:30 -0800 Subject: [PATCH] add more test --- model_server/src/core/function_calling.py | 2 +- model_server/src/core/hallucination.py | 6 +- .../tests/core/test_function_calling.py | 246 +++++++++++------- model_server/tests/core/test_hallucination.py | 142 ---------- model_server/tests/test_cli_stop_server.py | 2 +- 5 files changed, 162 insertions(+), 236 deletions(-) delete mode 100644 model_server/tests/core/test_hallucination.py diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 23eb441b..ec21f0f4 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -480,7 +480,7 @@ class ArchFunctionHandler(ArchBaseHandler): model_response, has_tool_call = "", None - for token in self.hallu_handler: + for _ in self.hallu_handler: # check if the first token is if len(self.hallu_handler.tokens) > 0 and has_tool_call == None: if self.hallu_handler.tokens[0] == "": diff --git a/model_server/src/core/hallucination.py b/model_server/src/core/hallucination.py index 2e04fe1f..56b84713 100644 --- a/model_server/src/core/hallucination.py +++ b/model_server/src/core/hallucination.py @@ -27,10 +27,10 @@ class MaskToken(Enum): HALLUCINATION_THRESHOLD_DICT = { - MaskToken.TOOL_CALL.value: {"entropy": 0.001, "varentropy": 0.005}, + MaskToken.TOOL_CALL.value: {"entropy": 0.05, "varentropy": 0.25}, MaskToken.PARAMETER_VALUE.value: { - "entropy": 0.001, - "varentropy": 0.005, + "entropy": 0.05, + "varentropy": 0.25, }, } diff --git a/model_server/tests/core/test_function_calling.py b/model_server/tests/core/test_function_calling.py index 2c786809..01aab8f7 100644 --- a/model_server/tests/core/test_function_calling.py +++ b/model_server/tests/core/test_function_calling.py @@ -1,106 +1,174 @@ -import json -import pytest +import os -from fastapi import Response -from unittest.mock import AsyncMock, MagicMock, patch from src.commons.globals import handler_map -from src.core.model_utils import ( - Message, - ChatMessage, - ChatCompletionResponse, -) +from src.core.model_utils import ChatMessage, Message +import pytest +from fastapi.testclient import TestClient +from unittest.mock import AsyncMock, patch +from src.main import app +from src.commons.globals import handler_map - -def sample_messages(): - # Ensure fields are explicitly set with valid data or empty values - return [ - Message(role="user", content="Hello!", tool_calls=[], tool_call_id=""), - Message( - role="assistant", - content="", - tool_calls=[{"function": {"name": "sample_tool"}}], - tool_call_id="sample_id", - ), - Message( - role="tool", content="Response from tool", tool_calls=[], tool_call_id="" - ), - ] - - -def sample_request(sample_messages): - return ChatMessage( - messages=sample_messages, - tools=[ - { - "type": "function", - "function": { - "name": "sample_tool", - "description": "A sample tool", - "parameters": { - "type": "object", - "properties": {}, - "required": [], - }, +# define function +get_weather_api = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get current weather at a location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "str", + "description": "The location to get the weather for", + "format": "City, State", }, - } - ], + "unit": { + "type": "str", + "description": "The unit to return the weather in.", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + }, + "days": { + "type": "str", + "description": "the number of days for the request.", + }, + }, + "required": ["location", "days"], + }, + }, +} + +# get_data class return request, intent, hallucination, parameter_gathering + + +def get_hallucination_data_complex(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in Seattle?") + message2 = Message( + role="assistant", content="Can you specify the unit you want the weather in?" + ) + message3 = Message(role="user", content="In celcius please!") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1, message2, message3], tools=tools) + + return req, True, True, True + + +def get_hallucination_data_easy(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in Seattle?") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) + + # model will hallucinate + return req, True, True, True + + +def get_hallucination_data_medium(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in?") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) + + # first token will not be tool call + return req, True, False, True + + +def get_complete_data_2(): + # Create instances of the Message class + message1 = Message( + role="user", + content="what is the weather forcast for seattle in the next 10 days?", ) + # Create a list of tools + tools = [get_weather_api] -@patch("src.commons.globals.handler_map") -def test_process_messages(mock_hanlder): - messages = sample_messages() - processed = handler_map["Arch-Function"]._process_messages(messages) + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) - assert len(processed) == 3 - assert processed[0] == {"role": "user", "content": "Hello!"} - assert processed[1] == { - "role": "assistant", - "content": '\n{"name": "sample_tool"}\n', - } - assert processed[2] == { - "role": "user", - "content": f"\n{json.dumps('Response from tool')}\n", - } + return req, True, False, False -# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated. +def get_complete_data(): + # Create instances of the Message class + message1 = Message(role="user", content="How is the weather in Seattle in 7 days?") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) + + return req, True, False, False -# [TODO] Review: Update the following test -# @patch("src.commons.globals.ARCH_CLIENT") -# @patch("src.commons.constants.handler_map") -# @pytest.mark.asyncio -# async def test_chat_completion(mock_hanlder, mock_client): -# # Mock the model list return for client -# mock_client.models.list.return_value = MagicMock( -# data=[MagicMock(id="sample_model")] -# ) -# request = sample_request(sample_messages()) -# # Simulate stream response as list of tokens -# mock_response = AsyncMock() -# mock_response.__aiter__.return_value = [ -# MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]), -# MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream -# ] -# mock_client.chat.completions.create.return_value = mock_response +def get_irrelevant_data(): + # Create instances of the Message class + message1 = Message(role="user", content="What is 1+1?") -# # Mock the tool formatter -# mock_hanlder._format_system_prompt.return_value = "" + # Create a list of tools + tools = [get_weather_api] -# response = Response() -# chat_response = await chat_completion(request, response) + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) -# assert isinstance(chat_response, ChatCompletionResponse) -# assert chat_response.choices[0].message.content is not None + return req, False, False, False -# first_call_args = mock_client.chat.completions.create.call_args_list[0][1] -# assert first_call_args["stream"] == True -# assert "model" in first_call_args -# assert first_call_args["messages"][0]["content"] == "" -# # Check that the arguments for the second call to 'create' include the pre-fill completion -# second_call_args = mock_client.chat.completions.create.call_args_list[1][1] -# assert second_call_args["stream"] == False -# assert "model" in second_call_args -# assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST +def get_greeting_data(): + # Create instances of the Message class + message1 = Message(role="user", content="Hello how are you?") + + # Create a list of tools + tools = [get_weather_api] + + # Create an instance of the ChatMessage class + req = ChatMessage(messages=[message1], tools=tools) + + return req, False, False, False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "get_data_func", + [ + get_hallucination_data_complex, + get_hallucination_data_easy, + get_hallucination_data_medium, + get_complete_data, + get_irrelevant_data, + get_complete_data_2, + ], +) +async def test_function_calling(get_data_func): + req, intent, hallucination, parameter_gathering = get_data_func() + + intent_response = await handler_map["Arch-Intent"].chat_completion(req) + + assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent + + if intent: + function_calling_response = await handler_map["Arch-Function"].chat_completion( + req + ) + assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination + response_txt = function_calling_response.choices[0].message.content + + if parameter_gathering: + prefill_prefix = handler_map["Arch-Function"].prefill_prefix + assert any( + response_txt.startswith(prefix) for prefix in prefill_prefix + ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}" diff --git a/model_server/tests/core/test_hallucination.py b/model_server/tests/core/test_hallucination.py deleted file mode 100644 index fcbbd962..00000000 --- a/model_server/tests/core/test_hallucination.py +++ /dev/null @@ -1,142 +0,0 @@ -import os - -from src.commons.globals import handler_map -from src.core.model_utils import ChatMessage, Message -import pytest -from fastapi.testclient import TestClient -from unittest.mock import AsyncMock, patch -from src.main import app -from src.commons.globals import handler_map - -# define function -get_weather_api = { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get current weather at a location.", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "str", - "description": "The location to get the weather for", - "format": "City, State", - }, - "unit": { - "type": "str", - "description": "The unit to return the weather in.", - "enum": ["celsius", "fahrenheit"], - "default": "celsius", - }, - "days": { - "type": "str", - "description": "the number of days for the request.", - }, - }, - "required": ["location", "days"], - }, - }, -} - - -def get_hallucination_data_complex(): - # Create instances of the Message class - message1 = Message(role="user", content="How is the weather in Seattle?") - message2 = Message( - role="assistant", content="Can you specify the unit you want the weather in?" - ) - message3 = Message(role="user", content="In celcius please!") - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1, message2, message3], tools=tools) - - return req, True, True, True - - -def get_hallucination_data_easy(): - # Create instances of the Message class - message1 = Message(role="user", content="How is the weather in Seattle?") - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1], tools=tools) - - # model will hallucinate - return req, True, True, True - - -def get_hallucination_data_medium(): - # Create instances of the Message class - message1 = Message(role="user", content="How is the weather in?") - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1], tools=tools) - - # first token will not be tool call - return req, True, False, True - - -def get_complete_data(): - # Create instances of the Message class - message1 = Message(role="user", content="How is the weather in Seattle in 7 days?") - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1], tools=tools) - - return req, True, False, False - - -def get_irrelevant_data(): - # Create instances of the Message class - message1 = Message(role="user", content="What is 1+1?") - - # Create a list of tools - tools = [get_weather_api] - - # Create an instance of the ChatMessage class - req = ChatMessage(messages=[message1], tools=tools) - - return req, False, False, False - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "get_data_func", - [ - get_hallucination_data_complex, - get_hallucination_data_easy, - get_hallucination_data_medium, - get_complete_data, - get_irrelevant_data, - ], -) -async def test_function_calling(get_data_func): - req, intent, hallucination, parameter_gathering = get_data_func() - - intent_response = await handler_map["Arch-Intent"].chat_completion(req) - - assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent - - if intent: - function_calling_response = await handler_map["Arch-Function"].chat_completion( - req - ) - assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination - response_txt = function_calling_response.choices[0].message.content - - if parameter_gathering: - prefill_prefix = handler_map["Arch-Function"].prefill_prefix - assert any( - response_txt.startswith(prefix) for prefix in prefill_prefix - ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}" diff --git a/model_server/tests/test_cli_stop_server.py b/model_server/tests/test_cli_stop_server.py index 5c16475a..d9d380d5 100644 --- a/model_server/tests/test_cli_stop_server.py +++ b/model_server/tests/test_cli_stop_server.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch, MagicMock -from src.cli import kill_process +from src.core.cli import kill_process class TestStopServer(unittest.TestCase):