mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
Reorganize model_server
This commit is contained in:
parent
a40cdc7b75
commit
b4f4695f16
20 changed files with 20 additions and 20 deletions
0
model_server/tests/__init__.py
Normal file
0
model_server/tests/__init__.py
Normal file
106
model_server/tests/core/test_function_calling.py
Normal file
106
model_server/tests/core/test_function_calling.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
import json
|
||||
import pytest
|
||||
|
||||
from fastapi import Response
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.base_handler import (
|
||||
Message,
|
||||
ChatMessage,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
|
||||
|
||||
def sample_messages():
|
||||
# Ensure fields are explicitly set with valid data or empty values
|
||||
return [
|
||||
Message(role="user", content="Hello!", tool_calls=[], tool_call_id=""),
|
||||
Message(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[{"function": {"name": "sample_tool"}}],
|
||||
tool_call_id="sample_id",
|
||||
),
|
||||
Message(
|
||||
role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def sample_request(sample_messages):
|
||||
return ChatMessage(
|
||||
messages=sample_messages,
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sample_tool",
|
||||
"description": "A sample tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@patch("app.commons.globals.handler_map")
|
||||
def test_process_messages(mock_hanlder):
|
||||
messages = sample_messages()
|
||||
processed = handler_map["Arch-Function"]._process_messages(messages)
|
||||
|
||||
assert len(processed) == 3
|
||||
assert processed[0] == {"role": "user", "content": "Hello!"}
|
||||
assert processed[1] == {
|
||||
"role": "assistant",
|
||||
"content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
|
||||
}
|
||||
assert processed[2] == {
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{json.dumps('Response from tool')}\n</tool_response>",
|
||||
}
|
||||
|
||||
|
||||
# [TODO] Review: Add tests for both `ArchIntentHandler` and `ArchFunctionHandler`. The following test may be outdated.
|
||||
|
||||
|
||||
# [TODO] Review: Update the following test
|
||||
@patch("app.commons.constants.arch_function_client")
|
||||
@patch("app.commons.constants.arch_function_hanlder")
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion(mock_hanlder, mock_client):
|
||||
# Mock the model list return for client
|
||||
mock_client.models.list.return_value = MagicMock(
|
||||
data=[MagicMock(id="sample_model")]
|
||||
)
|
||||
request = sample_request(sample_messages())
|
||||
# Simulate stream response as list of tokens
|
||||
mock_response = AsyncMock()
|
||||
mock_response.__aiter__.return_value = [
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
|
||||
MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
|
||||
]
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
# Mock the tool formatter
|
||||
mock_hanlder._format_system_prompt.return_value = "<formatted_tools>"
|
||||
|
||||
response = Response()
|
||||
chat_response = await chat_completion(request, response)
|
||||
|
||||
assert isinstance(chat_response, ChatCompletionResponse)
|
||||
assert chat_response.choices[0].message.content is not None
|
||||
|
||||
first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
|
||||
assert first_call_args["stream"] == True
|
||||
assert "model" in first_call_args
|
||||
assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
|
||||
|
||||
# Check that the arguments for the second call to 'create' include the pre-fill completion
|
||||
second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
|
||||
assert second_call_args["stream"] == False
|
||||
assert "model" in second_call_args
|
||||
assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
|
||||
79
model_server/tests/core/test_guardrails.py
Normal file
79
model_server/tests/core/test_guardrails.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
from unittest.mock import patch, MagicMock
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
|
||||
# Mock constants
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for `get_guardrail_handler()` function on `cpu`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cpu"
|
||||
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_ov_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `cuda`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "cuda"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `mps`
|
||||
@patch("app.model_handler.guardrail.AutoTokenizer.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.model_handler.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
device = "mps"
|
||||
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
guardrail = get_guardrail_handler(device=device)
|
||||
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
guardrail["model_name"], trust_remote_code=True
|
||||
)
|
||||
|
||||
mock_auto_model.assert_called_once_with(
|
||||
guardrail["model_name"],
|
||||
device_map=device,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
153
model_server/tests/core/test_hallucination.py
Normal file
153
model_server/tests/core/test_hallucination.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
import json
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
||||
from src.core.hallucination_handler import HallucinationStateHandler
|
||||
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(__file__)
|
||||
|
||||
# Construct the full path to the JSON file
|
||||
json_file_path = os.path.join(current_dir, "test_cases.json")
|
||||
|
||||
with open(json_file_path) as f:
|
||||
test_cases = json.load(f)
|
||||
|
||||
get_weather_api = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get current weather at a location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "str",
|
||||
"description": "The location to get the weather for",
|
||||
"format": "City, State",
|
||||
},
|
||||
"unit": {
|
||||
"type": "str",
|
||||
"description": "The unit to return the weather in.",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
},
|
||||
"days": {
|
||||
"type": "str",
|
||||
"description": "the number of days for the request.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "days"],
|
||||
},
|
||||
},
|
||||
}
|
||||
function_description = get_weather_api["function"]
|
||||
if type(function_description) != list:
|
||||
function_description = [get_weather_api["function"]]
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
@pytest.mark.parametrize("case", test_cases)
|
||||
def test_hallucination(case):
|
||||
state = HallucinationStateHandler(
|
||||
response_iterator=None, function=function_description
|
||||
)
|
||||
for token, logprob in zip(case["tokens"], case["logprobs"]):
|
||||
if token != "</tool_call>":
|
||||
state.append_and_check_token_hallucination(token, logprob)
|
||||
if state.hallucination:
|
||||
break
|
||||
assert state.hallucination == case["expect"]
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
|
||||
def test_hallucination_prompt(is_hallucinate_sample):
|
||||
TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
TOOL_PROMPT = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
|
||||
def convert_tools(tools):
|
||||
return "\n".join([json.dumps(tool) for tool in tools])
|
||||
|
||||
def format_prompt(tools):
|
||||
tool_text = convert_tools(tools)
|
||||
|
||||
return (
|
||||
TASK_PROMPT
|
||||
+ "\n\n"
|
||||
+ TOOL_PROMPT.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ FORMAT_PROMPT
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
openai_format_tools = [get_weather_api]
|
||||
|
||||
system_prompt = format_prompt(openai_format_tools)
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
|
||||
|
||||
# List models API
|
||||
model = client.models.list().data[0].id
|
||||
assert model == "Arch-Function"
|
||||
if not is_hallucinate_sample:
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
# {"role": "user", "content": "can you help me check weather?"},
|
||||
{"role": "user", "content": "How is the weather in Seattle in 7 days?"},
|
||||
# {"role": "assistant", "content": "Of course!"},
|
||||
# {"role": "user", "content": "Seattle please"}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
# {"role": "user", "content": "can you help me check weather?"},
|
||||
{"role": "user", "content": "How is the weather in Seattle in days?"},
|
||||
# {"role": "assistant", "content": "Of course!"},
|
||||
# {"role": "user", "content": "Seattle please"}
|
||||
]
|
||||
|
||||
extra_body = {
|
||||
"temperature": 0.6,
|
||||
"top_p": 1.0,
|
||||
"top_k": 50,
|
||||
# "continue_final_message": True,
|
||||
# "add_generation_prompt": False,
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model="Arch-Function", messages=messages, extra_body=extra_body, stream=True
|
||||
)
|
||||
|
||||
hallu = HallucinationStateHandler(
|
||||
response_iterator=resp, function=function_description
|
||||
)
|
||||
|
||||
for token in hallu:
|
||||
assert len(hallu.tokens) >= 0
|
||||
assert hallu.hallucination == is_hallucinate_sample
|
||||
50
model_server/tests/core/test_state.py
Normal file
50
model_server/tests/core/test_state.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from src.commons.globals import handler_map
|
||||
from src.core.function_calling import Message
|
||||
|
||||
|
||||
test_input_history = [
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_3394"},
|
||||
{"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_5306"},
|
||||
]
|
||||
|
||||
|
||||
def test_update_fc_history():
|
||||
message_history = []
|
||||
|
||||
for h in test_input_history:
|
||||
message_history.append(Message(**h))
|
||||
|
||||
updated_history = handler_map["Arch-Function"]._process_messages(message_history)
|
||||
assert len(updated_history) == 7
|
||||
# ensure that tool role does not exist anymore
|
||||
assert all([h["role"] != "tool" for h in updated_history])
|
||||
53
model_server/tests/test_app.py
Normal file
53
model_server/tests/test_app.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import pytest
|
||||
import httpx
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from src.main import app
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# [TODO] Review: check the following code
|
||||
# Unit tests for the health check endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthz():
|
||||
response = client.get("/healthz")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
# [TODO] Review: check the following code
|
||||
# Unit test for the models endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_models():
|
||||
response = client.get("/models")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["object"] == "list"
|
||||
assert len(response.json()["data"]) > 0
|
||||
|
||||
|
||||
# [TODO] Review: check the following code
|
||||
# Unit test for the guardrail endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_guardrail_endpoint():
|
||||
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
|
||||
response = client.post("/guardrails", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "jailbreak_verdict" in response.json()
|
||||
|
||||
|
||||
# [TODO] Review: check the following code
|
||||
# Unit test for the function calling endpoint
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calling_endpoint():
|
||||
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
|
||||
request_data = {
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"model": "Arch-Function",
|
||||
"tools": [],
|
||||
"metadata": {"x-arch-state": "[]"},
|
||||
}
|
||||
response = await client.post("/function_calling", json=request_data)
|
||||
assert response.status_code == 200
|
||||
assert "choices" in response.json()
|
||||
54
model_server/tests/test_cli_stop_server.py
Normal file
54
model_server/tests/test_cli_stop_server.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import unittest
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.cli import kill_process
|
||||
|
||||
|
||||
class TestStopServer(unittest.TestCase):
|
||||
@patch("subprocess.run")
|
||||
def test_stop_server_no_process(self, mock_run):
|
||||
# Mock subprocess.run to simulate no process listening on the port
|
||||
mock_run.return_value.returncode = 1
|
||||
with patch("builtins.print") as mock_print:
|
||||
kill_process(port=51000)
|
||||
mock_print.assert_called_with("No process found listening on port 51000.")
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_stop_server_process_killed(self, mock_run):
|
||||
# Simulate lsof returning a process id
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=0, stdout="uvicorn 1234 user LISTEN\n"),
|
||||
MagicMock(returncode=0), # for killing the process
|
||||
MagicMock(returncode=1), # for checking the process after it is killed
|
||||
]
|
||||
with patch("builtins.print") as mock_print:
|
||||
kill_process(port=51000, wait=True, timeout=5)
|
||||
mock_print.assert_any_call("Killing model server process with PID 1234")
|
||||
mock_print.assert_any_call("Process 1234 has been killed.")
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_stop_server_multiple_pids(self, mock_run):
|
||||
# Simulate lsof returning multiple process ids (e.g., 1234 and 5678)
|
||||
mock_run.side_effect = [
|
||||
MagicMock(
|
||||
returncode=0,
|
||||
stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n",
|
||||
), # lsof output
|
||||
MagicMock(returncode=0), # first kill command for PID 1234
|
||||
MagicMock(returncode=1), # PID 1234 is successfully terminated
|
||||
MagicMock(returncode=0), # second kill command for PID 5678
|
||||
MagicMock(returncode=1), # PID 5678 is successfully terminated
|
||||
]
|
||||
|
||||
with patch("builtins.print") as mock_print:
|
||||
kill_process(port=51000, wait=True, timeout=5)
|
||||
|
||||
# Assert that the function tried to kill both PIDs
|
||||
mock_print.assert_any_call("Killing model server process with PID 1234")
|
||||
mock_print.assert_any_call("Process 1234 has been killed.")
|
||||
mock_print.assert_any_call("Killing model server process with PID 5678")
|
||||
mock_print.assert_any_call("Process 5678 has been killed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue