Reorganize model_server

This commit is contained in:
Shuguang Chen 2024-12-08 09:21:53 -08:00
parent a40cdc7b75
commit b4f4695f16
20 changed files with 20 additions and 20 deletions

View file

View file

@ -0,0 +1,106 @@
import json
import pytest
from fastapi import Response
from unittest.mock import AsyncMock, MagicMock, patch
from src.commons.globals import handler_map
from src.core.base_handler import (
Message,
ChatMessage,
ChatCompletionResponse,
)
def sample_messages():
    """Build a three-turn conversation: user prompt, assistant tool call, tool result.

    Every ``Message`` field is passed explicitly, using an empty list or an
    empty string where the turn has nothing to carry.
    """
    user_turn = Message(role="user", content="Hello!", tool_calls=[], tool_call_id="")
    assistant_turn = Message(
        role="assistant",
        content="",
        tool_calls=[{"function": {"name": "sample_tool"}}],
        tool_call_id="sample_id",
    )
    tool_turn = Message(
        role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
    )
    return [user_turn, assistant_turn, tool_turn]
def sample_request(sample_messages):
    """Wrap ``sample_messages`` in a ``ChatMessage`` that offers a single tool.

    The tool is named ``sample_tool`` and takes an object with no properties
    and no required fields.
    """
    no_argument_schema = {"type": "object", "properties": {}, "required": []}
    sample_tool = {
        "type": "function",
        "function": {
            "name": "sample_tool",
            "description": "A sample tool",
            "parameters": no_argument_schema,
        },
    }
    return ChatMessage(messages=sample_messages, tools=[sample_tool])
# The patch target must name the module this file actually imports from
# (``src.commons.globals``); ``patch`` imports its target when the test is
# entered, so a stale package path fails before the body runs.
@patch("src.commons.globals.handler_map")
def test_process_messages(mock_handler_map):
    """Check how the ``Arch-Function`` handler rewrites a tool-calling conversation.

    The handler is looked up through the ``handler_map`` name imported at the
    top of this module, so the assertions run against the real handler; the
    injected mock is not used by the body.
    """
    messages = sample_messages()
    processed = handler_map["Arch-Function"]._process_messages(messages)

    assert len(processed) == 3
    # The user turn passes through with only role and content.
    assert processed[0] == {"role": "user", "content": "Hello!"}
    # The assistant tool call is rendered inside <tool_call> tags.
    assert processed[1] == {
        "role": "assistant",
        "content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
    }
    # The tool result comes back as a user turn wrapped in <tool_response> tags.
    assert processed[2] == {
        "role": "user",
        "content": f"<tool_response>\n{json.dumps('Response from tool')}\n</tool_response>",
    }
# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated.
# [TODO] Review: Update the following test
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")
@pytest.mark.asyncio
async def test_chat_completion(mock_hanlder, mock_client):
    """Run ``chat_completion`` against a mocked client and inspect both ``create`` calls.

    NOTE(review): ``chat_completion`` and ``const`` are not imported in the
    visible part of this file, and both patch targets use an ``app.`` prefix
    while the imports use ``src.`` -- confirm the intended modules before
    relying on this test.
    """

    def stream_chunk(text):
        # One streamed chunk whose first choice carries ``text`` as delta content.
        return MagicMock(choices=[MagicMock(delta=MagicMock(content=text))])

    # The mocked client advertises exactly one model.
    mock_client.models.list.return_value = MagicMock(
        data=[MagicMock(id="sample_model")]
    )

    # Streamed reply: one content chunk followed by an empty end-of-stream chunk.
    streamed_reply = AsyncMock()
    streamed_reply.__aiter__.return_value = [
        stream_chunk("Hi there!"),
        stream_chunk(""),
    ]
    mock_client.chat.completions.create.return_value = streamed_reply

    # The tool formatter is mocked so the system prompt is recognisable below.
    mock_hanlder._format_system_prompt.return_value = "<formatted_tools>"

    request = sample_request(sample_messages())
    chat_response = await chat_completion(request, Response())

    assert isinstance(chat_response, ChatCompletionResponse)
    assert chat_response.choices[0].message.content is not None

    create_calls = mock_client.chat.completions.create.call_args_list

    # First call: streaming request that starts with the formatted system prompt.
    streaming_kwargs = create_calls[0][1]
    assert streaming_kwargs["stream"] == True
    assert "model" in streaming_kwargs
    assert streaming_kwargs["messages"][0]["content"] == "<formatted_tools>"

    # Second call: non-streaming request ending with a pre-fill completion.
    prefill_kwargs = create_calls[1][1]
    assert prefill_kwargs["stream"] == False
    assert "model" in prefill_kwargs
    assert prefill_kwargs["messages"][-1]["content"] in const.PREFILL_LIST

View file

@ -0,0 +1,79 @@
from unittest.mock import patch, MagicMock
from src.core.guardrails import get_guardrail_handler
# Mock constants
arch_guard_model_type = {
    "cpu": "katanemo/Arch-Guard-cpu",
    # ``cuda`` and ``mps`` map to the same model name.
    **dict.fromkeys(("cuda", "mps"), "katanemo/Arch-Guard"),
}
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
    """On ``cpu`` the tokenizer and the ``OVModelForSequenceClassification`` loader each run once.

    NOTE(review): the patch targets point at ``app.model_handler.guardrail``
    while ``get_guardrail_handler`` is imported from ``src.core.guardrails``
    -- confirm the targets still resolve after the reorganisation.
    """
    mock_tokenizer.return_value = MagicMock()
    mock_ov_model.return_value = MagicMock()

    handler = get_guardrail_handler(device="cpu")
    model_name = handler["model_name"]

    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
    mock_ov_model.assert_called_once_with(
        model_name,
        device_map="cpu",
        low_cpu_mem_usage=True,
    )
# Test for `get_guardrail_handler()` function on `cuda`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
    """On ``cuda`` the tokenizer and the ``AutoModelForSequenceClassification`` loader each run once.

    NOTE(review): the patch targets point at ``app.model_handler.guardrail``
    while ``get_guardrail_handler`` is imported from ``src.core.guardrails``
    -- confirm the targets still resolve after the reorganisation.
    """
    mock_tokenizer.return_value = MagicMock()
    mock_auto_model.return_value = MagicMock()

    handler = get_guardrail_handler(device="cuda")
    model_name = handler["model_name"]

    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
    mock_auto_model.assert_called_once_with(
        model_name,
        device_map="cuda",
        low_cpu_mem_usage=True,
    )
# Test for `get_guardrail_handler()` function on `mps`
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
    """On ``mps`` the tokenizer and the ``AutoModelForSequenceClassification`` loader each run once.

    NOTE(review): the patch targets point at ``app.model_handler.guardrail``
    while ``get_guardrail_handler`` is imported from ``src.core.guardrails``
    -- confirm the targets still resolve after the reorganisation.
    """
    mock_tokenizer.return_value = MagicMock()
    mock_auto_model.return_value = MagicMock()

    handler = get_guardrail_handler(device="mps")
    model_name = handler["model_name"]

    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
    mock_auto_model.assert_called_once_with(
        model_name,
        device_map="mps",
        low_cpu_mem_usage=True,
    )

View file

@ -0,0 +1,153 @@
import json
import pytest
import os
from src.core.hallucination_handler import HallucinationStateHandler
# Resolve the fixture relative to this module so loading does not depend on
# the directory pytest is started from.
current_dir = os.path.dirname(__file__)
json_file_path = os.path.join(current_dir, "test_cases.json")
# Decode explicitly as UTF-8 instead of relying on the platform's locale default.
with open(json_file_path, encoding="utf-8") as f:
    test_cases = json.load(f)
# Tool definition used by the hallucination tests: a weather lookup that
# requires ``location`` and ``days`` and optionally accepts ``unit``.
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            "required": ["location", "days"],
        },
    },
}

# The tests pass the function descriptions as a list; wrap a single
# description so ``function_description`` is always a list.
function_description = get_weather_api["function"]
if not isinstance(function_description, list):
    function_description = [function_description]
# [TODO] Review: update the following code
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
    """Replay a recorded token/logprob sequence and compare the final flag with ``case["expect"]``."""
    handler = HallucinationStateHandler(
        response_iterator=None, function=function_description
    )
    for tok, tok_logprob in zip(case["tokens"], case["logprobs"]):
        # The closing tag is never fed to the handler.
        if tok == "</tool_call>":
            continue
        handler.append_and_check_token_hallucination(tok, tok_logprob)
        if handler.hallucination:
            # Stop replaying once a hallucination has been flagged.
            break
    assert handler.hallucination == case["expect"]
# [TODO] Review: update the following code
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
def test_hallucination_prompt(is_hallucinate_sample):
    """Stream a completion for a weather question and check the handler's hallucination flag.

    The ``True`` case asks about the weather "in days" (no number given); the
    ``False`` case asks about "7 days". The handler's final ``hallucination``
    flag must equal ``is_hallucinate_sample``.

    NOTE(review): this test calls the endpoint at ``https://api.fc.archgw.com/v1``,
    so it needs network access. The assertion inside the loop compares a
    length with ``>= 0`` and so cannot fail for a sized container -- confirm
    what it was meant to check.
    """
    TASK_PROMPT = """
You are a helpful assistant.
""".strip()
    TOOL_PROMPT = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()
    FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()

    # One JSON-encoded tool per line, substituted into the <tools> section.
    openai_format_tools = [get_weather_api]
    tool_text = "\n".join(json.dumps(tool) for tool in openai_format_tools)
    system_prompt = (
        "\n\n".join(
            [TASK_PROMPT, TOOL_PROMPT.format(tool_text=tool_text), FORMAT_PROMPT]
        )
        + "\n"
    )

    from openai import OpenAI

    client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
    # The first model listed by the endpoint must be the expected one.
    served_model = client.models.list().data[0].id
    assert served_model == "Arch-Function"

    # Only the user question differs between the two parametrized cases.
    if is_hallucinate_sample:
        user_question = "How is the weather in Seattle in days?"
    else:
        user_question = "How is the weather in Seattle in 7 days?"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_question},
    ]

    sampling_options = {
        "temperature": 0.6,
        "top_p": 1.0,
        "top_k": 50,
        "logprobs": True,
        "top_logprobs": 10,
    }
    stream = client.chat.completions.create(
        model="Arch-Function",
        messages=messages,
        extra_body=sampling_options,
        stream=True,
    )

    detector = HallucinationStateHandler(
        response_iterator=stream, function=function_description
    )
    # Iterate the handler to the end of the stream.
    for _ in detector:
        assert len(detector.tokens) >= 0
    assert detector.hallucination == is_hallucinate_sample

View file

@ -0,0 +1,50 @@
from src.commons.globals import handler_map
from src.core.function_calling import Message
def _forecast_tool_call(call_id):
    """Return a ``tool_calls`` entry asking ``weather_forecast`` for 5 days in Chicago."""
    return {
        "id": call_id,
        "type": "function",
        "function": {
            "name": "weather_forecast",
            "arguments": {"city": "Chicago", "days": 5},
        },
    }


# The same question is asked in both rounds of the conversation.
_WEATHER_QUESTION = "how is the weather in chicago for next 5 days?"

# Two rounds of question -> tool call -> tool result; only the first round
# carries "model" keys and a trailing assistant reply.
test_input_history = [
    {"role": "user", "content": _WEATHER_QUESTION},
    {
        "role": "assistant",
        "model": "Arch-Function",
        "tool_calls": [_forecast_tool_call("call_3394")],
    },
    {"role": "tool", "content": "--", "tool_call_id": "call_3394"},
    {"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
    {"role": "user", "content": _WEATHER_QUESTION},
    {
        "role": "assistant",
        "tool_calls": [_forecast_tool_call("call_5306")],
    },
    {"role": "tool", "content": "--", "tool_call_id": "call_5306"},
]
def test_update_fc_history():
    """Processing the recorded history keeps seven entries and leaves no ``tool`` role."""
    history = [Message(**entry) for entry in test_input_history]
    updated = handler_map["Arch-Function"]._process_messages(history)

    assert len(updated) == 7
    # No entry may still carry the "tool" role after processing.
    assert all(entry["role"] != "tool" for entry in updated)

View file

@ -0,0 +1,53 @@
import pytest
import httpx
from fastapi.testclient import TestClient
from src.main import app
# Module-level synchronous test client wrapping ``app`` from ``src.main``;
# used by the endpoint tests below that call ``client.get`` / ``client.post``.
client = TestClient(app)
# [TODO] Review: check the following code
# Unit tests for the health check endpoint
@pytest.mark.asyncio
async def test_healthz():
    """``GET /healthz`` answers 200 with the body ``{"status": "ok"}``."""
    expected_body = {"status": "ok"}
    health = client.get("/healthz")
    assert health.status_code == 200
    assert health.json() == expected_body
# [TODO] Review: check the following code
# Unit test for the models endpoint
@pytest.mark.asyncio
async def test_models():
    """``GET /models`` answers 200 with a non-empty ``data`` list tagged ``"list"``."""
    listing = client.get("/models")
    assert listing.status_code == 200
    payload = listing.json()
    assert payload["object"] == "list"
    assert len(payload["data"]) > 0
# [TODO] Review: check the following code
# Unit test for the guardrail endpoint
@pytest.mark.asyncio
async def test_guardrail_endpoint():
    """``POST /guardrails`` with a jailbreak task answers 200 and includes ``jailbreak_verdict``."""
    verdict = client.post(
        "/guardrails",
        json={"input": "Test for jailbreak and toxicity", "task": "jailbreak"},
    )
    assert verdict.status_code == 200
    assert "jailbreak_verdict" in verdict.json()
# [TODO] Review: check the following code
# Unit test for the function calling endpoint
@pytest.mark.asyncio
async def test_function_calling_endpoint():
    """``POST /function_calling`` with a plain user message answers 200 and includes ``choices``."""
    request_data = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "model": "Arch-Function",
        "tools": [],
        "metadata": {"x-arch-state": "[]"},
    }
    # The ``app=`` shortcut on ``AsyncClient`` was deprecated in httpx 0.27 and
    # removed in 0.28; an explicit ASGI transport works on both old and new
    # versions.
    transport = httpx.ASGITransport(app=app)
    # Named ``async_client`` so it does not shadow the module-level ``client``.
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test"
    ) as async_client:
        response = await async_client.post("/function_calling", json=request_data)
        assert response.status_code == 200
        assert "choices" in response.json()

View file

@ -0,0 +1,54 @@
import unittest
from unittest.mock import patch, MagicMock
from src.cli import kill_process
class TestStopServer(unittest.TestCase):
    """Exercise ``kill_process`` with ``subprocess.run`` and ``print`` replaced by mocks."""

    @patch("builtins.print")
    @patch("subprocess.run")
    def test_stop_server_no_process(self, mock_run, mock_print):
        """A non-zero return code from the lookup is reported as "no process found"."""
        # Simulate nothing listening on the port.
        mock_run.return_value.returncode = 1

        kill_process(port=51000)

        mock_print.assert_called_with("No process found listening on port 51000.")

    @patch("builtins.print")
    @patch("subprocess.run")
    def test_stop_server_process_killed(self, mock_run, mock_print):
        """A single listed PID is killed and its termination is reported."""
        lsof_result = MagicMock(returncode=0, stdout="uvicorn 1234 user LISTEN\n")
        kill_result = MagicMock(returncode=0)
        # Non-zero on the follow-up check: the process is gone.
        recheck_result = MagicMock(returncode=1)
        mock_run.side_effect = [lsof_result, kill_result, recheck_result]

        kill_process(port=51000, wait=True, timeout=5)

        mock_print.assert_any_call("Killing model server process with PID 1234")
        mock_print.assert_any_call("Process 1234 has been killed.")

    @patch("builtins.print")
    @patch("subprocess.run")
    def test_stop_server_multiple_pids(self, mock_run, mock_print):
        """Every PID in the simulated lsof output is killed and reported."""
        lsof_result = MagicMock(
            returncode=0,
            stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n",
        )
        mock_run.side_effect = [
            lsof_result,
            MagicMock(returncode=0),  # kill command for PID 1234
            MagicMock(returncode=1),  # PID 1234 no longer found
            MagicMock(returncode=0),  # kill command for PID 5678
            MagicMock(returncode=1),  # PID 5678 no longer found
        ]

        kill_process(port=51000, wait=True, timeout=5)

        # Both PIDs must have been announced and confirmed as killed.
        for pid in (1234, 5678):
            mock_print.assert_any_call(f"Killing model server process with PID {pid}")
            mock_print.assert_any_call(f"Process {pid} has been killed.")


if __name__ == "__main__":
    unittest.main()