mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add more test
This commit is contained in:
parent
63cc2ef3f3
commit
f13947732c
5 changed files with 162 additions and 236 deletions
|
|
@ -480,7 +480,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
model_response, has_tool_call = "", None
|
||||
|
||||
for token in self.hallu_handler:
|
||||
for _ in self.hallu_handler:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallu_handler.tokens) > 0 and has_tool_call == None:
|
||||
if self.hallu_handler.tokens[0] == "<tool_call>":
|
||||
|
|
|
|||
|
|
@ -27,10 +27,10 @@ class MaskToken(Enum):
|
|||
|
||||
|
||||
HALLUCINATION_THRESHOLD_DICT = {
|
||||
MaskToken.TOOL_CALL.value: {"entropy": 0.001, "varentropy": 0.005},
|
||||
MaskToken.TOOL_CALL.value: {"entropy": 0.05, "varentropy": 0.25},
|
||||
MaskToken.PARAMETER_VALUE.value: {
|
||||
"entropy": 0.001,
|
||||
"varentropy": 0.005,
|
||||
"entropy": 0.05,
|
||||
"varentropy": 0.25,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,106 +1,174 @@
|
|||
import json
|
||||
import pytest
|
||||
import os
|
||||
|
||||
from fastapi import Response
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.model_utils import (
|
||||
Message,
|
||||
ChatMessage,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
from src.core.model_utils import ChatMessage, Message
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from src.main import app
|
||||
from src.commons.globals import handler_map
|
||||
|
||||
|
||||
def sample_messages():
|
||||
# Ensure fields are explicitly set with valid data or empty values
|
||||
return [
|
||||
Message(role="user", content="Hello!", tool_calls=[], tool_call_id=""),
|
||||
Message(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"function": {"name": "sample_tool"}}],
|
||||
tool_call_id="sample_id",
|
||||
),
|
||||
Message(
|
||||
role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def sample_request(sample_messages):
|
||||
return ChatMessage(
|
||||
messages=sample_messages,
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sample_tool",
|
||||
"description": "A sample tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
# define function
|
||||
get_weather_api = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get current weather at a location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "str",
|
||||
"description": "The location to get the weather for",
|
||||
"format": "City, State",
|
||||
},
|
||||
}
|
||||
],
|
||||
"unit": {
|
||||
"type": "str",
|
||||
"description": "The unit to return the weather in.",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
},
|
||||
"days": {
|
||||
"type": "str",
|
||||
"description": "the number of days for the request.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# get_data class return request, intent, hallucination, parameter_gathering
|
||||
|
||||
|
||||
def get_hallucination_data_complex():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
message2 = Message(
|
||||
role="assistant", content="Can you specify the unit you want the weather in?"
|
||||
)
|
||||
message3 = Message(role="user", content="In celcius please!")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1, message2, message3], tools=tools)
|
||||
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_easy():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# model will hallucinate
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# first token will not be tool call
|
||||
return req, True, False, True
|
||||
|
||||
|
||||
def get_complete_data_2():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(
|
||||
role="user",
|
||||
content="what is the weather forcast for seattle in the next 10 days?",
|
||||
)
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
@patch("src.commons.globals.handler_map")
|
||||
def test_process_messages(mock_hanlder):
|
||||
messages = sample_messages()
|
||||
processed = handler_map["Arch-Function"]._process_messages(messages)
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
assert len(processed) == 3
|
||||
assert processed[0] == {"role": "user", "content": "Hello!"}
|
||||
assert processed[1] == {
|
||||
"role": "assistant",
|
||||
"content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
|
||||
}
|
||||
assert processed[2] == {
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{json.dumps('Response from tool')}\n</tool_response>",
|
||||
}
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated.
|
||||
def get_complete_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
# [TODO] Review: Update the following test
|
||||
# @patch("src.commons.globals.ARCH_CLIENT")
|
||||
# @patch("src.commons.constants.handler_map")
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_chat_completion(mock_hanlder, mock_client):
|
||||
# # Mock the model list return for client
|
||||
# mock_client.models.list.return_value = MagicMock(
|
||||
# data=[MagicMock(id="sample_model")]
|
||||
# )
|
||||
# request = sample_request(sample_messages())
|
||||
# # Simulate stream response as list of tokens
|
||||
# mock_response = AsyncMock()
|
||||
# mock_response.__aiter__.return_value = [
|
||||
# MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
|
||||
# MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
|
||||
# ]
|
||||
# mock_client.chat.completions.create.return_value = mock_response
|
||||
def get_irrelevant_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="What is 1+1?")
|
||||
|
||||
# # Mock the tool formatter
|
||||
# mock_hanlder._format_system_prompt.return_value = "<formatted_tools>"
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# response = Response()
|
||||
# chat_response = await chat_completion(request, response)
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# assert isinstance(chat_response, ChatCompletionResponse)
|
||||
# assert chat_response.choices[0].message.content is not None
|
||||
return req, False, False, False
|
||||
|
||||
# first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
|
||||
# assert first_call_args["stream"] == True
|
||||
# assert "model" in first_call_args
|
||||
# assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
|
||||
|
||||
# # Check that the arguments for the second call to 'create' include the pre-fill completion
|
||||
# second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
|
||||
# assert second_call_args["stream"] == False
|
||||
# assert "model" in second_call_args
|
||||
# assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
|
||||
def get_greeting_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="Hello how are you?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data_complex,
|
||||
get_hallucination_data_easy,
|
||||
get_hallucination_data_medium,
|
||||
get_complete_data,
|
||||
get_irrelevant_data,
|
||||
get_complete_data_2,
|
||||
],
|
||||
)
|
||||
async def test_function_calling(get_data_func):
|
||||
req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
|
||||
assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
|
||||
if intent:
|
||||
function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination
|
||||
response_txt = function_calling_response.choices[0].message.content
|
||||
|
||||
if parameter_gathering:
|
||||
prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
assert any(
|
||||
response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
|
|
|
|||
|
|
@ -1,142 +0,0 @@
|
|||
import os
|
||||
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.model_utils import ChatMessage, Message
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from src.main import app
|
||||
from src.commons.globals import handler_map
|
||||
|
||||
# define function
|
||||
get_weather_api = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get current weather at a location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "str",
|
||||
"description": "The location to get the weather for",
|
||||
"format": "City, State",
|
||||
},
|
||||
"unit": {
|
||||
"type": "str",
|
||||
"description": "The unit to return the weather in.",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
},
|
||||
"days": {
|
||||
"type": "str",
|
||||
"description": "the number of days for the request.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_hallucination_data_complex():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
message2 = Message(
|
||||
role="assistant", content="Can you specify the unit you want the weather in?"
|
||||
)
|
||||
message3 = Message(role="user", content="In celcius please!")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1, message2, message3], tools=tools)
|
||||
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_easy():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# model will hallucinate
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# first token will not be tool call
|
||||
return req, True, False, True
|
||||
|
||||
|
||||
def get_complete_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
def get_irrelevant_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="What is 1+1?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data_complex,
|
||||
get_hallucination_data_easy,
|
||||
get_hallucination_data_medium,
|
||||
get_complete_data,
|
||||
get_irrelevant_data,
|
||||
],
|
||||
)
|
||||
async def test_function_calling(get_data_func):
|
||||
req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
|
||||
assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
|
||||
if intent:
|
||||
function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination
|
||||
response_txt = function_calling_response.choices[0].message.content
|
||||
|
||||
if parameter_gathering:
|
||||
prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
assert any(
|
||||
response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.cli import kill_process
|
||||
from src.core.cli import kill_process
|
||||
|
||||
|
||||
class TestStopServer(unittest.TestCase):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue