279 lines
9.8 KiB
Python
279 lines
9.8 KiB
Python
"""Unit tests for pure helper functions in router.py (no network, no app)."""
|
|
import time
|
|
import asyncio
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import aiohttp
|
|
import pytest
|
|
|
|
import router
|
|
|
|
|
|
class TestMaskSecrets:
    """Tests for ``router._mask_secrets``: credential redaction in free-form text."""

    def test_masks_openai_key(self):
        """An ``sk-`` style key is replaced by the ``sk-***redacted***`` marker."""
        masked = router._mask_secrets(
            "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ"
        )
        assert "sk-***redacted***" in masked
        assert "sk-abcd1234" not in masked

    def test_masks_api_key_assignment(self):
        """The value after ``api_key:`` is hidden and a redaction marker is inserted."""
        masked = router._mask_secrets("api_key: supersecretvalue123")
        assert "supersecretvalue123" not in masked
        assert "***redacted***" in masked

    def test_masks_api_key_with_colon(self):
        """The hyphenated ``api-key`` spelling is also redacted.

        Note: despite the name, the distinguishing feature of this case is the
        hyphen; the ``api_key`` test above uses a colon separator as well.
        """
        masked = router._mask_secrets("api-key: mykey")
        assert "mykey" not in masked

    def test_empty_string_returns_empty(self):
        """Empty input is passed through unchanged."""
        assert router._mask_secrets("") == ""

    def test_none_returns_none(self):
        """``None`` is passed through rather than raising."""
        assert router._mask_secrets(None) is None

    def test_no_secrets_unchanged(self):
        """Text without anything secret-looking is returned verbatim."""
        plain_line = "this is a normal log line"
        assert router._mask_secrets(plain_line) == plain_line
|
|
|
|
|
|
class TestIsFresh:
    """Tests for ``router._is_fresh(cached_at, ttl)`` against ``time.time()`` timestamps."""

    def test_fresh_within_ttl(self):
        """An entry cached 10 seconds ago is still fresh with a TTL of 300."""
        stamp = time.time() - 10
        assert router._is_fresh(stamp, 300) is True

    def test_expired_beyond_ttl(self):
        """An entry cached 400 seconds ago has expired with a TTL of 300."""
        stamp = time.time() - 400
        assert router._is_fresh(stamp, 300) is False

    def test_exactly_at_boundary(self):
        """Smoke test at age == TTL: only the return type is checked."""
        # The clock keeps moving between taking the timestamp and the call, so the
        # outcome exactly at the boundary is timing-dependent; just verify it runs.
        outcome = router._is_fresh(time.time() - 300, 300)
        assert isinstance(outcome, bool)

    def test_just_cached(self):
        """A timestamp taken right now is fresh even with a TTL of 1."""
        now = time.time()
        assert router._is_fresh(now, 1) is True
|
|
|
|
|
|
class TestNormalizeLlamaModelName:
    """Tests for ``router._normalize_llama_model_name``: drops path prefixes and ``:quant`` tags."""

    def test_strips_hf_prefix(self):
        """A single ``org/`` repository prefix is removed."""
        normalized = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF")
        assert normalized == "gpt-oss-20b-GGUF"

    def test_strips_quant_suffix(self):
        """A ``:Q8_0`` quantization tag is removed."""
        normalized = router._normalize_llama_model_name("model:Q8_0")
        assert normalized == "model"

    def test_strips_both(self):
        """Prefix and quantization tag are removed together."""
        normalized = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16")
        assert normalized == "gpt-oss-20b-GGUF"

    def test_no_prefix_no_suffix(self):
        """A bare name comes back untouched."""
        normalized = router._normalize_llama_model_name("plain-model")
        assert normalized == "plain-model"

    def test_multiple_slashes(self):
        """Only the final path segment survives when there are several slashes."""
        normalized = router._normalize_llama_model_name("org/user/model-name:Q4_K_M")
        assert normalized == "model-name"
|
|
|
|
|
|
class TestExtractLlamaQuant:
    """Tests for ``router._extract_llama_quant``: returns the text after ``:`` or ``""``."""

    def test_extracts_quant(self):
        """The tag is returned even when a repository prefix is present."""
        quant = router._extract_llama_quant("unsloth/model:Q8_0")
        assert quant == "Q8_0"

    def test_no_quant_returns_empty(self):
        """No ``:`` tag yields an empty string."""
        quant = router._extract_llama_quant("plain-model")
        assert quant == ""

    def test_f16(self):
        """Float formats such as ``F16`` are returned as-is."""
        quant = router._extract_llama_quant("model:F16")
        assert quant == "F16"

    def test_q4_k_m(self):
        """Multi-part k-quant tags keep their underscores."""
        quant = router._extract_llama_quant("model:Q4_K_M")
        assert quant == "Q4_K_M"
|
|
|
|
|
|
class TestIsUnixSocketEndpoint:
    """Tests for ``router._is_unix_socket_endpoint``: hosts ending in ``.sock`` are sockets."""

    def test_sock_endpoint_detected(self):
        """An IP-like host with a ``.sock`` suffix is treated as a socket."""
        detected = router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1")
        assert detected is True

    def test_regular_http_not_sock(self):
        """A plain ``host:port`` URL is not a socket."""
        detected = router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1")
        assert detected is False

    def test_ollama_not_sock(self):
        """The default Ollama URL is not a socket."""
        detected = router._is_unix_socket_endpoint("http://localhost:11434")
        assert detected is False

    def test_dot_sock_in_host_detected(self):
        """A named host with a ``.sock`` suffix is treated as a socket."""
        detected = router._is_unix_socket_endpoint("http://llama.sock/v1")
        assert detected is True
|
|
|
|
|
|
class TestGetSocketPath:
    """Tests for ``router._get_socket_path``: maps a ``*.sock`` endpoint to a file path."""

    def test_returns_run_user_path(self):
        """The socket file lives under ``/run/user/<uid>/`` and is named after the host."""
        import os

        # os.getuid() exists only on POSIX platforms. The expected path is built
        # from it, so skip instead of erroring with AttributeError elsewhere.
        if not hasattr(os, "getuid"):
            pytest.skip("os.getuid() is unavailable on this platform")

        path = router._get_socket_path("http://192.168.0.52.sock/v1")
        assert path == f"/run/user/{os.getuid()}/192.168.0.52.sock"
|
|
|
|
|
|
class TestIsBase64:
    """Tests for ``router.is_base64``."""

    def test_valid_base64(self):
        """Output of ``base64.b64encode`` is recognised."""
        import base64

        encoded = base64.b64encode(b"hello world").decode()
        assert router.is_base64(encoded) is True

    def test_invalid_base64(self):
        """Characters outside the base64 alphabet are rejected."""
        assert router.is_base64("not-base64!@#$") is False

    def test_empty_string(self):
        """The empty string is accepted, since it decodes to empty bytes."""
        assert router.is_base64("") is True

    def test_non_string(self):
        """Non-string input is not reported as base64.

        The function falls through without returning True (it returns None),
        so only falsiness is asserted.
        """
        assert not router.is_base64(12345)
|
|
|
|
|
|
class TestIsLlamaModelLoaded:
    """Tests for ``router._is_llama_model_loaded`` across the shapes of the ``status`` field."""

    def test_status_dict_loaded(self):
        """A nested ``{"value": "loaded"}`` status counts as loaded."""
        entry = {"id": "m", "status": {"value": "loaded"}}
        assert router._is_llama_model_loaded(entry) is True

    def test_status_dict_unloaded(self):
        """A nested ``{"value": "unloaded"}`` status counts as not loaded."""
        entry = {"id": "m", "status": {"value": "unloaded"}}
        assert router._is_llama_model_loaded(entry) is False

    def test_status_string_loaded(self):
        """A plain ``"loaded"`` string status counts as loaded."""
        entry = {"id": "m", "status": "loaded"}
        assert router._is_llama_model_loaded(entry) is True

    def test_status_string_unloaded(self):
        """A plain ``"unloaded"`` string status counts as not loaded."""
        entry = {"id": "m", "status": "unloaded"}
        assert router._is_llama_model_loaded(entry) is False

    def test_no_status_field_always_loaded(self):
        """Without a ``status`` key the model is treated as available (single-model server)."""
        entry = {"id": "m"}
        assert router._is_llama_model_loaded(entry) is True

    def test_status_none_always_loaded(self):
        """A ``None`` status is treated like a missing one."""
        entry = {"id": "m", "status": None}
        assert router._is_llama_model_loaded(entry) is True
|
|
|
|
|
|
class TestEp2Base:
    """Tests for ``router.ep2base``: an endpoint URL becomes a ``/v1`` base URL."""

    def test_adds_v1_to_ollama(self):
        """A bare Ollama URL gains a ``/v1`` suffix."""
        base_url = router.ep2base("http://localhost:11434")
        assert base_url == "http://localhost:11434/v1"

    def test_keeps_v1_if_present(self):
        """A URL already ending in ``/v1`` is not extended again."""
        base_url = router.ep2base("http://host/v1")
        assert base_url == "http://host/v1"

    def test_llama_server_endpoint_unchanged(self):
        """A llama-server style ``/v1`` URL comes back identical."""
        endpoint = "http://192.168.0.50:8889/v1"
        base_url = router.ep2base(endpoint)
        assert base_url == endpoint
|
|
|
|
|
|
class TestDedupeOnKeys:
    """Tests for ``router.dedupe_on_keys(items, keys)``."""

    def test_removes_duplicate_by_single_key(self):
        """Items sharing a ``name`` collapse to one; surviving names keep first-seen order."""
        records = [{"name": "a", "x": 1}, {"name": "b", "x": 2}, {"name": "a", "x": 3}]
        unique = router.dedupe_on_keys(records, ["name"])
        assert len(unique) == 2
        assert [record["name"] for record in unique] == ["a", "b"]

    def test_removes_duplicate_by_two_keys(self):
        """Items equal on every listed key are collapsed."""
        records = [
            {"digest": "abc", "name": "m1"},
            {"digest": "abc", "name": "m1"},
            {"digest": "def", "name": "m2"},
        ]
        unique = router.dedupe_on_keys(records, ["digest", "name"])
        assert len(unique) == 2

    def test_empty_list(self):
        """An empty input produces an empty list."""
        assert router.dedupe_on_keys([], ["name"]) == []

    def test_no_duplicates(self):
        """Distinct items are all kept."""
        records = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
        unique = router.dedupe_on_keys(records, ["name"])
        assert len(unique) == 3
|
|
|
|
|
|
class TestFormatConnectionIssue:
    """Tests for ``router._format_connection_issue(url, exc)`` message text."""

    def test_connector_error_message(self):
        """A refused connection names the host and includes the OS error text or code."""
        conn_key = MagicMock(host="localhost", port=11434)
        refused = OSError(111, "Connection refused")
        err = aiohttp.ClientConnectorError(connection_key=conn_key, os_error=refused)

        message = router._format_connection_issue("http://localhost:11434", err)

        assert "localhost" in message
        assert any(marker in message for marker in ("Connection refused", "111"))

    def test_timeout_error_message(self):
        """A timeout is reported as such and names the target host:port."""
        message = router._format_connection_issue(
            "http://host:1234", asyncio.TimeoutError()
        )
        assert "Timed out" in message
        assert "host:1234" in message

    def test_generic_error(self):
        """Any other exception still names the target and carries the exception text."""
        message = router._format_connection_issue("http://host:1234", ValueError("boom"))
        assert "host:1234" in message
        assert "boom" in message
|
|
|
|
|
|
class TestIsExtOpenaiEndpoint:
    """Tests for ``router.is_ext_openai_endpoint`` with ``router.config`` patched."""

    @staticmethod
    def _make_config(endpoints, llama_server_endpoints):
        """Return a MagicMock config with the two endpoint lists these tests control.

        Any other attribute read from the mock yields a new MagicMock, as with the
        inline setup this helper replaces.
        """
        cfg = MagicMock()
        cfg.endpoints = endpoints
        cfg.llama_server_endpoints = llama_server_endpoints
        return cfg

    def test_openai_com_is_ext(self):
        """The public OpenAI API is an external OpenAI endpoint."""
        cfg = self._make_config([], [])
        with patch.object(router, "config", cfg):
            assert router.is_ext_openai_endpoint("https://api.openai.com/v1") is True

    def test_ollama_default_port_not_ext(self):
        """A configured Ollama endpoint on the default port is not external."""
        cfg = self._make_config(["http://host:11434"], [])
        with patch.object(router, "config", cfg):
            assert router.is_ext_openai_endpoint("http://host:11434") is False

    def test_llama_server_not_ext(self):
        """A configured llama-server ``/v1`` endpoint is not external."""
        cfg = self._make_config([], ["http://host:8080/v1"])
        with patch.object(router, "config", cfg):
            assert router.is_ext_openai_endpoint("http://host:8080/v1") is False

    def test_no_v1_not_ext(self):
        """A URL without a ``/v1`` path is not treated as an external OpenAI endpoint."""
        # This used to repeat test_ollama_default_port_not_ext exactly (same config
        # and same Ollama-port URL). Use a non-Ollama port that is absent from both
        # configured lists, so only the missing "/v1" can make this non-external.
        cfg = self._make_config([], [])
        with patch.object(router, "config", cfg):
            assert router.is_ext_openai_endpoint("http://host:8000") is False
|
|
|
|
|
|
class TestIsOpenaiCompatible:
    """Tests for ``router.is_openai_compatible`` with ``router.config`` patched."""

    @staticmethod
    def _config_with_llama(*llama_endpoints):
        """Return a MagicMock config whose ``llama_server_endpoints`` lists the given URLs."""
        cfg = MagicMock()
        cfg.llama_server_endpoints = list(llama_endpoints)
        return cfg

    def test_v1_endpoint_compatible(self):
        """Any ``/v1`` URL is considered OpenAI-compatible."""
        with patch.object(router, "config", self._config_with_llama()):
            assert router.is_openai_compatible("http://host/v1") is True

    def test_ollama_not_compatible(self):
        """A bare Ollama URL not listed as a llama-server is not compatible."""
        with patch.object(router, "config", self._config_with_llama()):
            assert router.is_openai_compatible("http://localhost:11434") is False

    def test_llama_server_in_list_compatible(self):
        """A listed llama-server URL is compatible even without ``/v1``."""
        endpoint = "http://host:8080"
        with patch.object(router, "config", self._config_with_llama(endpoint)):
            assert router.is_openai_compatible(endpoint) is True
|
|
|
|
|
|
class TestGetTrackingModel:
    """Tests for ``router.get_tracking_model(endpoint, model)`` with ``router.config`` patched."""

    @staticmethod
    def _config_with_llama(*llama_endpoints):
        """Return a MagicMock config whose ``llama_server_endpoints`` lists the given URLs."""
        cfg = MagicMock()
        cfg.llama_server_endpoints = list(llama_endpoints)
        return cfg

    def test_ollama_adds_latest(self):
        """An untagged Ollama model name gets a ``:latest`` tag."""
        with patch.object(router, "config", self._config_with_llama()):
            tracked = router.get_tracking_model("http://ollama:11434", "llama3.2")
            assert tracked == "llama3.2:latest"

    def test_ollama_keeps_existing_tag(self):
        """An Ollama model that already has a tag is left alone."""
        with patch.object(router, "config", self._config_with_llama()):
            tracked = router.get_tracking_model("http://ollama:11434", "llama3.2:7b")
            assert tracked == "llama3.2:7b"

    def test_llama_server_normalizes(self):
        """On a llama-server endpoint the repo prefix and quant tag are stripped."""
        endpoint = "http://host:8080/v1"
        with patch.object(router, "config", self._config_with_llama(endpoint)):
            tracked = router.get_tracking_model(endpoint, "unsloth/model:Q8_0")
        assert tracked == "model"
|