mirror of
https://github.com/katanemo/plano.git
synced 2026-04-26 17:26:26 +02:00
Use intent model from archfc to pick prompt gateway (#328)
This commit is contained in:
parent
67b8fd635e
commit
ba7279becb
151 changed files with 8642 additions and 10932 deletions
0
model_server/tests/core/__init__.py
Normal file
0
model_server/tests/core/__init__.py
Normal file
173
model_server/tests/core/test_function_calling.py
Normal file
173
model_server/tests/core/test_function_calling.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
import os
|
||||
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.model_utils import ChatMessage, Message
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from src.main import app
|
||||
from src.commons.globals import handler_map
|
||||
|
||||
# define function
|
||||
get_weather_api = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get current weather at a location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "str",
|
||||
"description": "The location to get the weather for",
|
||||
"format": "City, State",
|
||||
},
|
||||
"unit": {
|
||||
"type": "str",
|
||||
"description": "The unit to return the weather in.",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
},
|
||||
"days": {
|
||||
"type": "str",
|
||||
"description": "the number of days for the request.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# get_data class return request, intent, hallucination, parameter_gathering
|
||||
|
||||
|
||||
def get_hallucination_data_complex():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
message2 = Message(
|
||||
role="assistant", content="Can you specify the unit you want the weather in?"
|
||||
)
|
||||
message3 = Message(role="user", content="In celcius please!")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1, message2, message3], tools=tools)
|
||||
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_easy():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# model will hallucinate
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# first token will not be tool call
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_complete_data_2():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(
|
||||
role="user",
|
||||
content="what is the weather forecast for seattle in the next 10 days?",
|
||||
)
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
def get_complete_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
def get_irrelevant_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="What is 1+1?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
|
||||
|
||||
def get_greeting_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="Hello how are you?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data_complex,
|
||||
get_hallucination_data_easy,
|
||||
get_complete_data,
|
||||
get_irrelevant_data,
|
||||
get_complete_data_2,
|
||||
],
|
||||
)
|
||||
async def test_function_calling(get_data_func):
|
||||
req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
|
||||
assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
|
||||
if intent:
|
||||
function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination
|
||||
response_txt = function_calling_response.choices[0].message.content
|
||||
|
||||
if parameter_gathering:
|
||||
prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
assert any(
|
||||
response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
69
model_server/tests/core/test_guardrails.py
Normal file
69
model_server/tests/core/test_guardrails.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from unittest.mock import patch, MagicMock
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
|
||||
# Mock constants
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for `get_guardrail_handler()` function on `cpu`
|
||||
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cpu(mock_auto_model, mock_tokenizer):
|
||||
device = "cpu"
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail.model_name,
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `cuda`
|
||||
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cuda(mock_auto_model, mock_tokenizer):
|
||||
device = "cuda"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail.model_name,
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `mps`
|
||||
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_mps(mock_auto_model, mock_tokenizer):
|
||||
device = "mps"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail.model_name,
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
50
model_server/tests/core/test_state.py
Normal file
50
model_server/tests/core/test_state.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from src.commons.globals import handler_map
|
||||
from src.core.function_calling import Message
|
||||
|
||||
|
||||
test_input_history = [
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_3394"},
|
||||
{"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_5306"},
|
||||
]
|
||||
|
||||
|
||||
def test_update_fc_history():
|
||||
message_history = []
|
||||
|
||||
for h in test_input_history:
|
||||
message_history.append(Message(**h))
|
||||
|
||||
updated_history = handler_map["Arch-Function"]._process_messages(message_history)
|
||||
assert len(updated_history) == 7
|
||||
# ensure that tool role does not exist anymore
|
||||
assert all([h["role"] != "tool" for h in updated_history])
|
||||
Loading…
Add table
Add a link
Reference in a new issue