Add more tests

This commit is contained in:
cotran 2024-12-09 13:48:30 -08:00
parent 63cc2ef3f3
commit f13947732c
5 changed files with 162 additions and 236 deletions

View file

@ -480,7 +480,7 @@ class ArchFunctionHandler(ArchBaseHandler):
model_response, has_tool_call = "", None
for token in self.hallu_handler:
for _ in self.hallu_handler:
# check if the first token is <tool_call>
if len(self.hallu_handler.tokens) > 0 and has_tool_call == None:
if self.hallu_handler.tokens[0] == "<tool_call>":

View file

@ -27,10 +27,10 @@ class MaskToken(Enum):
# Entropy / varentropy thresholds, keyed by MaskToken value.
# NOTE(review): presumably consulted by the hallucination check — confirm
# against the code that reads this mapping.
#
# Each key appears exactly once. The earlier literal repeated the same keys,
# and in a dict display the last occurrence silently wins, so the values
# below are the ones that were already in effect.
HALLUCINATION_THRESHOLD_DICT = {
    MaskToken.TOOL_CALL.value: {"entropy": 0.05, "varentropy": 0.25},
    MaskToken.PARAMETER_VALUE.value: {
        "entropy": 0.05,
        "varentropy": 0.25,
    },
}

View file

@ -1,106 +1,174 @@
import json
import pytest
import os
from fastapi import Response
from unittest.mock import AsyncMock, MagicMock, patch
from src.commons.globals import handler_map
from src.core.model_utils import (
Message,
ChatMessage,
ChatCompletionResponse,
)
from src.core.model_utils import ChatMessage, Message
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch
from src.main import app
from src.commons.globals import handler_map
def sample_messages():
    """Return a three-message conversation: user, assistant tool call, tool reply.

    Every ``Message`` field is passed explicitly (valid data or an empty
    value) rather than relying on defaults.
    """
    user_turn = Message(role="user", content="Hello!", tool_calls=[], tool_call_id="")
    assistant_turn = Message(
        role="assistant",
        content="",
        tool_calls=[{"function": {"name": "sample_tool"}}],
        tool_call_id="sample_id",
    )
    tool_turn = Message(
        role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
    )
    return [user_turn, assistant_turn, tool_turn]
def sample_request(sample_messages):
return ChatMessage(
messages=sample_messages,
tools=[
{
"type": "function",
"function": {
"name": "sample_tool",
"description": "A sample tool",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
# Tool definition handed to every ChatMessage built in this module
# (as the single entry of ``tools``).
# NOTE(review): property types are spelled "str"; JSON Schema would use
# "string". Left untouched because the value is sent at runtime — confirm
# which spelling the model prompt expects.
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            # "unit" is optional and falls back to its declared default.
            "required": ["location", "days"],
        },
    },
}
# Each get_*_data helper returns (request, intent, hallucination, parameter_gathering)
def get_hallucination_data_complex():
    """Build a three-turn weather conversation with the weather tool attached.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)``; all three
        expectation flags are ``True`` for this case.
    """
    conversation = [
        Message(role="user", content="How is the weather in Seattle?"),
        Message(
            role="assistant",
            content="Can you specify the unit you want the weather in?",
        ),
        Message(role="user", content="In celcius please!"),
    ]
    request = ChatMessage(messages=conversation, tools=[get_weather_api])
    return request, True, True, True
def get_hallucination_data_easy():
    """Build a single-turn weather question with the weather tool attached.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)``; all three
        expectation flags are ``True`` — the model is expected to hallucinate.
    """
    question = Message(role="user", content="How is the weather in Seattle?")
    request = ChatMessage(messages=[question], tools=[get_weather_api])
    return request, True, True, True
def get_hallucination_data_medium():
    """Build a single-turn weather question that names no location.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)`` as
        ``(request, True, False, True)``; the first generated token is not
        expected to be a tool call.
    """
    question = Message(role="user", content="How is the weather in?")
    request = ChatMessage(messages=[question], tools=[get_weather_api])
    return request, True, False, True
def get_complete_data_2():
# Create instances of the Message class
message1 = Message(
role="user",
content="what is the weather forcast for seattle in the next 10 days?",
)
# Create a list of tools
tools = [get_weather_api]
@patch("src.commons.globals.handler_map")
def test_process_messages(mock_hanlder):
messages = sample_messages()
processed = handler_map["Arch-Function"]._process_messages(messages)
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
assert len(processed) == 3
assert processed[0] == {"role": "user", "content": "Hello!"}
assert processed[1] == {
"role": "assistant",
"content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
}
assert processed[2] == {
"role": "user",
"content": f"<tool_response>\n{json.dumps('Response from tool')}\n</tool_response>",
}
return req, True, False, False
# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated.
def get_complete_data():
    """Build a single-turn weather question that states both city and day count.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)`` as
        ``(request, True, False, False)``.
    """
    question = Message(
        role="user", content="How is the weather in Seattle in 7 days?"
    )
    request = ChatMessage(messages=[question], tools=[get_weather_api])
    return request, True, False, False
# [TODO] Review: Update the following test
# @patch("src.commons.globals.ARCH_CLIENT")
# @patch("src.commons.constants.handler_map")
# @pytest.mark.asyncio
# async def test_chat_completion(mock_hanlder, mock_client):
# # Mock the model list return for client
# mock_client.models.list.return_value = MagicMock(
# data=[MagicMock(id="sample_model")]
# )
# request = sample_request(sample_messages())
# # Simulate stream response as list of tokens
# mock_response = AsyncMock()
# mock_response.__aiter__.return_value = [
# MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
# MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
# ]
# mock_client.chat.completions.create.return_value = mock_response
def get_irrelevant_data():
# Create instances of the Message class
message1 = Message(role="user", content="What is 1+1?")
# # Mock the tool formatter
# mock_hanlder._format_system_prompt.return_value = "<formatted_tools>"
# Create a list of tools
tools = [get_weather_api]
# response = Response()
# chat_response = await chat_completion(request, response)
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
# assert isinstance(chat_response, ChatCompletionResponse)
# assert chat_response.choices[0].message.content is not None
return req, False, False, False
# first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
# assert first_call_args["stream"] == True
# assert "model" in first_call_args
# assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
# # Check that the arguments for the second call to 'create' include the pre-fill completion
# second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
# assert second_call_args["stream"] == False
# assert "model" in second_call_args
# assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
def get_greeting_data():
    """Build a single-turn greeting with the weather tool attached.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)`` with every
        expectation flag ``False``.
    """
    greeting = Message(role="user", content="Hello how are you?")
    request = ChatMessage(messages=[greeting], tools=[get_weather_api])
    return request, False, False, False
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "get_data_func",
    [
        get_hallucination_data_complex,
        get_hallucination_data_easy,
        get_hallucination_data_medium,
        get_complete_data,
        get_irrelevant_data,
        get_complete_data_2,
        # Greeting-only input: no tool intent is expected.
        get_greeting_data,
    ],
)
async def test_function_calling(get_data_func):
    """Check intent detection and function calling against one scenario.

    Args:
        get_data_func: Zero-argument factory returning
            ``(request, intent, hallucination, parameter_gathering)``.

    The intent handler is always queried. The function handler is queried
    only when intent is expected, and the prefill-prefix check runs only
    when parameter gathering is expected.
    """
    req, intent, hallucination, parameter_gathering = get_data_func()

    intent_handler = handler_map["Arch-Intent"]
    intent_response = await intent_handler.chat_completion(req)
    assert intent_handler.detect_intent(intent_response) == intent

    if intent:
        function_handler = handler_map["Arch-Function"]
        function_calling_response = await function_handler.chat_completion(req)
        assert function_handler.hallu_handler.hallucination == hallucination

        response_txt = function_calling_response.choices[0].message.content
        if parameter_gathering:
            # A parameter-gathering reply must open with one of the
            # handler's prefill prefixes.
            prefill_prefix = function_handler.prefill_prefix
            assert any(
                response_txt.startswith(prefix) for prefix in prefill_prefix
            ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"

View file

@ -1,142 +0,0 @@
import os
from src.commons.globals import handler_map
from src.core.model_utils import ChatMessage, Message
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch
from src.main import app
from src.commons.globals import handler_map
# Tool definition handed to every ChatMessage built in this module
# (as the single entry of ``tools``).
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            # "unit" is optional and falls back to its declared default.
            "required": ["location", "days"],
        },
    },
}
def get_hallucination_data_complex():
    """Build a three-turn weather conversation with the weather tool attached.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)``; all three
        expectation flags are ``True`` for this case.
    """
    turns = [
        Message(role="user", content="How is the weather in Seattle?"),
        Message(
            role="assistant",
            content="Can you specify the unit you want the weather in?",
        ),
        Message(role="user", content="In celcius please!"),
    ]
    chat = ChatMessage(messages=turns, tools=[get_weather_api])
    return chat, True, True, True
def get_hallucination_data_easy():
    """Build a single-turn weather question with the weather tool attached.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)``; all three
        expectation flags are ``True`` — the model is expected to hallucinate.
    """
    user_turn = Message(role="user", content="How is the weather in Seattle?")
    chat = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    return chat, True, True, True
def get_hallucination_data_medium():
    """Build a single-turn weather question that names no location.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)`` as
        ``(request, True, False, True)``; the first generated token is not
        expected to be a tool call.
    """
    user_turn = Message(role="user", content="How is the weather in?")
    chat = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    return chat, True, False, True
def get_complete_data():
    """Build a single-turn weather question that states both city and day count.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)`` as
        ``(request, True, False, False)``.
    """
    user_turn = Message(
        role="user", content="How is the weather in Seattle in 7 days?"
    )
    chat = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    return chat, True, False, False
def get_irrelevant_data():
    """Build a single-turn arithmetic question with the weather tool attached.

    Returns:
        ``(request, intent, hallucination, parameter_gathering)`` with every
        expectation flag ``False``.
    """
    user_turn = Message(role="user", content="What is 1+1?")
    chat = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    return chat, False, False, False
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "get_data_func",
    [
        get_hallucination_data_complex,
        get_hallucination_data_easy,
        get_hallucination_data_medium,
        get_complete_data,
        get_irrelevant_data,
    ],
)
async def test_function_calling(get_data_func):
    """Check intent detection and function calling against one scenario.

    Args:
        get_data_func: Zero-argument factory returning
            ``(request, intent, hallucination, parameter_gathering)``.
    """
    req, intent, hallucination, parameter_gathering = get_data_func()

    intent_response = await handler_map["Arch-Intent"].chat_completion(req)
    detected = handler_map["Arch-Intent"].detect_intent(intent_response)
    assert detected == intent

    # Nothing further to verify when no tool intent is expected.
    if not intent:
        return

    function_calling_response = await handler_map["Arch-Function"].chat_completion(
        req
    )
    assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination
    response_txt = function_calling_response.choices[0].message.content

    if not parameter_gathering:
        return

    # A parameter-gathering reply must open with one of the prefill prefixes.
    prefill_prefix = handler_map["Arch-Function"].prefill_prefix
    assert any(
        response_txt.startswith(prefix) for prefix in prefill_prefix
    ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"

View file

@ -1,7 +1,7 @@
import unittest
from unittest.mock import patch, MagicMock
from src.cli import kill_process
from src.core.cli import kill_process
class TestStopServer(unittest.TestCase):