mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
fix test
This commit is contained in:
parent
2405fb36e3
commit
b2ef3f7266
4 changed files with 32 additions and 39 deletions
|
|
@ -18,6 +18,7 @@ class ArchGuardHanlder:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.model = model_dict["model"]
|
self.model = model_dict["model"]
|
||||||
|
self.model_name = model_dict["model_name"]
|
||||||
self.tokenizer = model_dict["tokenizer"]
|
self.tokenizer = model_dict["tokenizer"]
|
||||||
self.device = model_dict["device"]
|
self.device = model_dict["device"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,9 @@ arch_guard_model_type = {
|
||||||
|
|
||||||
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
|
||||||
# Test for `get_guardrail_handler()` function on `cpu`
|
# Test for `get_guardrail_handler()` function on `cpu`
|
||||||
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
|
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||||
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
|
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
|
||||||
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||||
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
|
def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
|
|
@ -22,21 +22,19 @@ def test_guardrail_handler_on_cpu(mock_auto_model, mock_ov_model, mock_tokenizer
|
||||||
|
|
||||||
guardrail = get_guardrail_handler(device=device)
|
guardrail = get_guardrail_handler(device=device)
|
||||||
|
|
||||||
mock_tokenizer.assert_called_once_with(
|
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||||
guardrail["model_name"], trust_remote_code=True
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_ov_model.assert_called_once_with(
|
mock_ov_model.assert_called_once_with(
|
||||||
guardrail["model_name"],
|
guardrail.model_name,
|
||||||
device_map=device,
|
device_map=device,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Test for `get_guardrail_handler()` function on `cuda`
|
# Test for `get_guardrail_handler()` function on `cuda`
|
||||||
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
|
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||||
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
|
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
|
||||||
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||||
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
|
def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|
||||||
|
|
@ -45,21 +43,19 @@ def test_guardrail_handler_on_cuda(mock_auto_model, mock_ov_model, mock_tokenize
|
||||||
|
|
||||||
guardrail = get_guardrail_handler(device=device)
|
guardrail = get_guardrail_handler(device=device)
|
||||||
|
|
||||||
mock_tokenizer.assert_called_once_with(
|
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||||
guardrail["model_name"], trust_remote_code=True
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_auto_model.assert_called_once_with(
|
mock_auto_model.assert_called_once_with(
|
||||||
guardrail["model_name"],
|
guardrail.model_name,
|
||||||
device_map=device,
|
device_map=device,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Test for `get_guardrail_handler()` function on `mps`
|
# Test for `get_guardrail_handler()` function on `mps`
|
||||||
@patch("src.core.guardrail.AutoTokenizer.from_pretrained")
|
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
|
||||||
@patch("src.core.guardrail.OVModelForSequenceClassification.from_pretrained")
|
@patch("src.core.guardrails.OVModelForSequenceClassification.from_pretrained")
|
||||||
@patch("src.core.guardrail.AutoModelForSequenceClassification.from_pretrained")
|
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
|
||||||
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
|
def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||||
device = "mps"
|
device = "mps"
|
||||||
|
|
||||||
|
|
@ -68,12 +64,10 @@ def test_guardrail_handler_on_mps(mock_auto_model, mock_ov_model, mock_tokenizer
|
||||||
|
|
||||||
guardrail = get_guardrail_handler(device=device)
|
guardrail = get_guardrail_handler(device=device)
|
||||||
|
|
||||||
mock_tokenizer.assert_called_once_with(
|
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
|
||||||
guardrail["model_name"], trust_remote_code=True
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_auto_model.assert_called_once_with(
|
mock_auto_model.assert_called_once_with(
|
||||||
guardrail["model_name"],
|
guardrail.model_name,
|
||||||
device_map=device,
|
device_map=device,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ async def test_guardrail_endpoint():
|
||||||
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
|
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
|
||||||
response = client.post("/guardrails", json=request_data)
|
response = client.post("/guardrails", json=request_data)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert "jailbreak_verdict" in response.json()
|
assert "response" in response.json()
|
||||||
|
|
||||||
|
|
||||||
# [TODO] Review: check the following code. Seems something wrong with asyncio package❗
|
# [TODO] Review: check the following code. Seems something wrong with asyncio package❗
|
||||||
|
|
@ -50,4 +50,4 @@ async def test_function_calling_endpoint():
|
||||||
}
|
}
|
||||||
response = await client.post("/function_calling", json=request_data)
|
response = await client.post("/function_calling", json=request_data)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert "choices" in response.json()
|
assert "result" in response.json()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
from src.core.cli import kill_process
|
from src.commons.utils import kill_processes
|
||||||
|
|
||||||
|
|
||||||
class TestStopServer(unittest.TestCase):
|
class TestStopServer(unittest.TestCase):
|
||||||
|
|
@ -10,8 +9,8 @@ class TestStopServer(unittest.TestCase):
|
||||||
# Mock subprocess.run to simulate no process listening on the port
|
# Mock subprocess.run to simulate no process listening on the port
|
||||||
mock_run.return_value.returncode = 1
|
mock_run.return_value.returncode = 1
|
||||||
with patch("builtins.print") as mock_print:
|
with patch("builtins.print") as mock_print:
|
||||||
kill_process(port=51000)
|
kill_processes(port_processes=[""], wait=True, timeout=5)
|
||||||
mock_print.assert_called_with("No process found listening on port 51000.")
|
mock_print.assert_not_called()
|
||||||
|
|
||||||
@patch("subprocess.run")
|
@patch("subprocess.run")
|
||||||
def test_stop_server_process_killed(self, mock_run):
|
def test_stop_server_process_killed(self, mock_run):
|
||||||
|
|
@ -22,18 +21,15 @@ class TestStopServer(unittest.TestCase):
|
||||||
MagicMock(returncode=1), # for checking the process after it is killed
|
MagicMock(returncode=1), # for checking the process after it is killed
|
||||||
]
|
]
|
||||||
with patch("builtins.print") as mock_print:
|
with patch("builtins.print") as mock_print:
|
||||||
kill_process(port=51000, wait=True, timeout=5)
|
kill_processes(
|
||||||
mock_print.assert_any_call("Killing model server process with PID 1234")
|
port_processes=["uvicorn 1234 user LISTEN\n"], wait=True, timeout=5
|
||||||
mock_print.assert_any_call("Process 1234 has been killed.")
|
)
|
||||||
|
mock_print.assert_any_call("Killing process with PID 1234...")
|
||||||
|
|
||||||
@patch("subprocess.run")
|
@patch("subprocess.run")
|
||||||
def test_stop_server_multiple_pids(self, mock_run):
|
def test_stop_server_multiple_pids(self, mock_run):
|
||||||
# Simulate lsof returning multiple process ids (e.g., 1234 and 5678)
|
# Simulate lsof returning multiple process ids (e.g., 1234 and 5678)
|
||||||
mock_run.side_effect = [
|
mock_run.side_effect = [
|
||||||
MagicMock(
|
|
||||||
returncode=0,
|
|
||||||
stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n",
|
|
||||||
), # lsof output
|
|
||||||
MagicMock(returncode=0), # first kill command for PID 1234
|
MagicMock(returncode=0), # first kill command for PID 1234
|
||||||
MagicMock(returncode=1), # PID 1234 is successfully terminated
|
MagicMock(returncode=1), # PID 1234 is successfully terminated
|
||||||
MagicMock(returncode=0), # second kill command for PID 5678
|
MagicMock(returncode=0), # second kill command for PID 5678
|
||||||
|
|
@ -41,13 +37,15 @@ class TestStopServer(unittest.TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("builtins.print") as mock_print:
|
with patch("builtins.print") as mock_print:
|
||||||
kill_process(port=51000, wait=True, timeout=5)
|
kill_processes(
|
||||||
|
port_processes=["uvicorn 1234 user LISTEN", "uvicorn 5678 user LISTEN"],
|
||||||
|
wait=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
|
||||||
# Assert that the function tried to kill both PIDs
|
# Assert that the function tried to kill both PIDs
|
||||||
mock_print.assert_any_call("Killing model server process with PID 1234")
|
mock_print.assert_any_call("Killing process with PID 1234...")
|
||||||
mock_print.assert_any_call("Process 1234 has been killed.")
|
mock_print.assert_any_call("Killing process with PID 5678...")
|
||||||
mock_print.assert_any_call("Killing model server process with PID 5678")
|
|
||||||
mock_print.assert_any_call("Process 5678 has been killed.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue