Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

View file

@ -0,0 +1,173 @@
import os
from src.commons.globals import handler_map
from src.core.model_utils import ChatMessage, Message
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch
from src.main import app
from src.commons.globals import handler_map
# define function
get_weather_api = {
    # Tool definition handed to ChatMessage(tools=[...]) by the data builders below.
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            # NOTE(review): parameter types are spelled "str" rather than the
            # JSON Schema "string" -- presumably the spelling the model under
            # test expects; confirm before normalizing.
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            # "unit" is optional; it carries a default of "celsius" above.
            "required": ["location", "days"],
        },
    },
}
# Each get_*_data helper returns a tuple: (request, intent, hallucination, parameter_gathering)
def get_hallucination_data_complex():
    """Build a three-turn weather conversation with the weather tool attached.

    The conversation is: user asks about Seattle weather, assistant asks for
    the unit, user answers with the unit.

    Returns:
        tuple: ``(request, intent, hallucination, parameter_gathering)`` as
        unpacked by ``test_function_calling``; here ``(req, True, True, True)``.
    """
    conversation = [
        Message(role="user", content="How is the weather in Seattle?"),
        Message(
            role="assistant",
            content="Can you specify the unit you want the weather in?",
        ),
        Message(role="user", content="In celcius please!"),
    ]
    request = ChatMessage(messages=conversation, tools=[get_weather_api])
    return request, True, True, True
def get_hallucination_data_easy():
    """Build a single-turn Seattle weather request with the weather tool attached.

    The user message names a location but no number of days, although
    ``get_weather_api`` lists ``days`` as required.

    Returns:
        tuple: ``(request, intent, hallucination, parameter_gathering)`` as
        unpacked by ``test_function_calling``; here ``(req, True, True, True)``.
    """
    user_turn = Message(role="user", content="How is the weather in Seattle?")
    request = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    # Original author's note: the model will hallucinate on this input.
    return request, True, True, True
def get_hallucination_data_medium():
    """Build a single-turn, truncated weather question with the weather tool attached.

    The user message stops before naming a location.

    NOTE(review): this builder is not listed in the ``parametrize`` set of
    ``test_function_calling`` -- confirm whether that is intentional.

    Returns:
        tuple: ``(request, intent, hallucination, parameter_gathering)``;
        here ``(req, True, True, True)``.
    """
    user_turn = Message(role="user", content="How is the weather in?")
    request = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    # Original author's note: the first token will not be a tool call.
    return request, True, True, True
def get_complete_data_2():
    """Build a single-turn request that names both a location and a day count.

    Returns:
        tuple: ``(request, intent, hallucination, parameter_gathering)`` as
        unpacked by ``test_function_calling``; here ``(req, True, False, False)``.
    """
    question = "what is the weather forecast for seattle in the next 10 days?"
    request = ChatMessage(
        messages=[Message(role="user", content=question)],
        tools=[get_weather_api],
    )
    return request, True, False, False
def get_complete_data():
    """Build a single-turn Seattle weather request that includes a day count.

    Returns:
        tuple: ``(request, intent, hallucination, parameter_gathering)`` as
        unpacked by ``test_function_calling``; here ``(req, True, False, False)``.
    """
    user_turn = Message(
        role="user",
        content="How is the weather in Seattle in 7 days?",
    )
    request = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    return request, True, False, False
def get_irrelevant_data():
    """Build a single-turn arithmetic question with the weather tool attached.

    Returns:
        tuple: ``(request, intent, hallucination, parameter_gathering)`` as
        unpacked by ``test_function_calling``; here ``(req, False, False, False)``.
    """
    user_turn = Message(role="user", content="What is 1+1?")
    request = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    return request, False, False, False
def get_greeting_data():
    """Build a single-turn greeting with the weather tool attached.

    NOTE(review): this builder is not listed in the ``parametrize`` set of
    ``test_function_calling`` -- confirm whether that is intentional.

    Returns:
        tuple: ``(request, intent, hallucination, parameter_gathering)``;
        here ``(req, False, False, False)``.
    """
    user_turn = Message(role="user", content="Hello how are you?")
    request = ChatMessage(messages=[user_turn], tools=[get_weather_api])
    return request, False, False, False
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "get_data_func",
    [
        get_hallucination_data_complex,
        get_hallucination_data_easy,
        get_complete_data,
        get_irrelevant_data,
        get_complete_data_2,
    ],
)
async def test_function_calling(get_data_func):
    """Run one scenario through the intent and function-calling handlers.

    Args:
        get_data_func: Builder returning
            ``(request, intent, hallucination, parameter_gathering)``.

    Checks, in order:
        1. ``detect_intent`` on the "Arch-Intent" completion equals ``intent``.
        2. If intent is expected, the "Arch-Function" handler's
           ``hallu_handler.hallucination`` flag equals ``hallucination``.
        3. If parameter gathering is expected, the first choice's message
           content starts with one of the handler's ``prefill_prefix`` entries.
    """
    req, intent, hallucination, parameter_gathering = get_data_func()

    intent_handler = handler_map["Arch-Intent"]
    intent_response = await intent_handler.chat_completion(req)
    assert intent_handler.detect_intent(intent_response) == intent

    # Without a detected intent the function-calling model is never invoked.
    if not intent:
        return

    function_handler = handler_map["Arch-Function"]
    function_calling_response = await function_handler.chat_completion(req)
    assert function_handler.hallu_handler.hallucination == hallucination

    response_txt = function_calling_response.choices[0].message.content
    if not parameter_gathering:
        return

    prefill_prefix = function_handler.prefill_prefix
    matches_prefix = any(
        response_txt.startswith(prefix) for prefix in prefill_prefix
    )
    assert (
        matches_prefix
    ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"

View file

@ -0,0 +1,69 @@
from unittest.mock import patch, MagicMock
from src.core.guardrails import get_guardrail_handler
# Mock constants
# Device name -> guard model identifier.
# NOTE(review): not referenced by the guardrail tests visible in this file.
arch_guard_model_type = dict(
    cpu="katanemo/Arch-Guard-cpu",
    cuda="katanemo/Arch-Guard",
    mps="katanemo/Arch-Guard",
)
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_tokenizer):
    """Verify ``get_guardrail_handler`` loads tokenizer and model for ``cpu``.

    Both ``from_pretrained`` loaders are patched. Decorators apply bottom-up,
    so the model loader mock is the first argument and the tokenizer loader
    mock is the second.

    Args:
        mock_auto_model: Patched ``AutoModelForSequenceClassification.from_pretrained``.
        mock_tokenizer: Patched ``AutoTokenizer.from_pretrained``.
    """
    device = "cpu"

    # Stub both loaders explicitly, matching the cuda and mps variants.
    mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)

    mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
    mock_auto_model.assert_called_once_with(
        guardrail.model_name,
        device_map=device,
        low_cpu_mem_usage=True,
    )
# Test for `get_guardrail_handler()` function on `cuda`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_tokenizer):
    """Verify ``get_guardrail_handler`` loads tokenizer and model for ``cuda``.

    Args:
        mock_auto_model: Patched ``AutoModelForSequenceClassification.from_pretrained``.
        mock_tokenizer: Patched ``AutoTokenizer.from_pretrained``.
    """
    # Give each patched loader an explicit stand-in return value.
    for loader in (mock_auto_model, mock_tokenizer):
        loader.return_value = MagicMock()

    handler = get_guardrail_handler(device="cuda")

    # The tokenizer must be requested exactly once, for the handler's model.
    mock_tokenizer.assert_called_once_with(handler.model_name, trust_remote_code=True)

    # The model must be requested exactly once, placed on the requested device.
    expected_model_kwargs = {"device_map": "cuda", "low_cpu_mem_usage": True}
    mock_auto_model.assert_called_once_with(handler.model_name, **expected_model_kwargs)
# Test for `get_guardrail_handler()` function on `mps`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_tokenizer):
    """Verify ``get_guardrail_handler`` loads tokenizer and model for ``mps``.

    Args:
        mock_auto_model: Patched ``AutoModelForSequenceClassification.from_pretrained``.
        mock_tokenizer: Patched ``AutoTokenizer.from_pretrained``.
    """
    target_device = "mps"

    # Explicit stand-ins for whatever the two loaders would return.
    model_stub = MagicMock()
    tokenizer_stub = MagicMock()
    mock_auto_model.return_value = model_stub
    mock_tokenizer.return_value = tokenizer_stub

    handler = get_guardrail_handler(device=target_device)

    mock_tokenizer.assert_called_once_with(
        handler.model_name,
        trust_remote_code=True,
    )
    mock_auto_model.assert_called_once_with(
        handler.model_name,
        low_cpu_mem_usage=True,
        device_map=target_device,
    )

View file

@ -0,0 +1,50 @@
from src.commons.globals import handler_map
from src.core.function_calling import Message
# Seven-entry conversation fed to test_update_fc_history: two rounds of
# user question -> assistant tool call -> tool result, with one plain
# assistant reply between the rounds.
test_input_history = [
    # Round 1: the assistant tool call carries a "model" field.
    {"role": "user", "content": "how is the weather in chicago for next 5 days?"},
    {
        "role": "assistant",
        "model": "Arch-Function",
        "tool_calls": [
            {
                "id": "call_3394",
                "type": "function",
                "function": {
                    "name": "weather_forecast",
                    "arguments": {"city": "Chicago", "days": 5},
                },
            }
        ],
    },
    {"role": "tool", "content": "--", "tool_call_id": "call_3394"},
    {"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
    # Round 2: same question; the assistant tool call has no "model" field.
    {"role": "user", "content": "how is the weather in chicago for next 5 days?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_5306",
                "type": "function",
                "function": {
                    "name": "weather_forecast",
                    "arguments": {"city": "Chicago", "days": 5},
                },
            }
        ],
    },
    {"role": "tool", "content": "--", "tool_call_id": "call_5306"},
]
def test_update_fc_history():
    """Check that ``_process_messages`` returns a history free of tool-role entries.

    Converts every dict in ``test_input_history`` into a ``Message``, passes
    the list to the "Arch-Function" handler's ``_process_messages``, and
    asserts that the result still has seven entries, none with role "tool".
    """
    message_history = [Message(**entry) for entry in test_input_history]

    updated_history = handler_map["Arch-Function"]._process_messages(message_history)

    assert len(updated_history) == 7
    # ensure that tool role does not exist anymore
    assert all(entry["role"] != "tool" for entry in updated_history)