feat: add llama-swap as a backend
This commit is contained in:
parent
c8da58430a
commit
aa8baebac5
17 changed files with 544 additions and 52 deletions
|
|
@ -4,10 +4,14 @@ endpoints:
|
|||
llama_server_endpoints:
|
||||
- http://192.168.0.51:12434/v1
|
||||
|
||||
llama_swap_endpoints:
|
||||
- http://192.168.0.51:12435/v1
|
||||
|
||||
max_concurrent_connections: 2
|
||||
|
||||
api_keys:
|
||||
"http://192.168.0.51:12434": "ollama"
|
||||
"http://192.168.0.51:12434/v1": "llama"
|
||||
"http://192.168.0.51:12435/v1": "llama-swap"
|
||||
|
||||
cache_enabled: false
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ def mock_config():
|
|||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = [TEST_LLAMA]
|
||||
cfg.llama_swap_endpoints = []
|
||||
cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
|
|
@ -70,6 +71,7 @@ def mock_config_no_llama():
|
|||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = []
|
||||
cfg.llama_swap_endpoints = []
|
||||
cfg.api_keys = {TEST_OLLAMA: "ollama"}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
|
|
@ -83,6 +85,7 @@ def mock_config_with_key():
|
|||
cfg = MagicMock()
|
||||
cfg.endpoints = [TEST_OLLAMA]
|
||||
cfg.llama_server_endpoints = []
|
||||
cfg.llama_swap_endpoints = []
|
||||
cfg.api_keys = {}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = "test-secret-key"
|
||||
|
|
|
|||
|
|
@ -12,10 +12,11 @@ EP3 = "http://ep3:11434"
|
|||
LLAMA_EP = "http://llama:8080/v1"
|
||||
|
||||
|
||||
def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
|
||||
def _make_cfg(endpoints, llama_eps=None, swap_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = endpoints
|
||||
cfg.llama_server_endpoints = llama_eps or []
|
||||
cfg.llama_swap_endpoints = swap_eps or []
|
||||
cfg.api_keys = {}
|
||||
cfg.max_concurrent_connections = max_conn
|
||||
cfg.endpoint_config = endpoint_config or {}
|
||||
|
|
@ -46,6 +47,27 @@ class TestChooseEndpointBasic:
|
|||
assert ep == EP1
|
||||
assert tracking == "llama3.2:latest"
|
||||
|
||||
async def test_llama_swap_endpoint_is_a_candidate(self):
|
||||
swap_ep = "http://swap:8080/v1"
|
||||
cfg = _make_cfg([EP1], swap_eps=[swap_ep])
|
||||
|
||||
async def available(ep, *_):
|
||||
# Only the llama-swap backend advertises this model
|
||||
return {"org/model:Q4_K_M"} if ep == swap_ep else set()
|
||||
|
||||
async def loaded(ep):
|
||||
return {"org/model:Q4_K_M"} if ep == swap_ep else set()
|
||||
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(router.fetch, "available_models", side_effect=available),
|
||||
patch.object(router.fetch, "loaded_models", side_effect=loaded),
|
||||
):
|
||||
ep, tracking = await router.choose_endpoint("org/model:Q4_K_M")
|
||||
assert ep == swap_ep
|
||||
# llama-swap models are tracked under their normalized name
|
||||
assert tracking == "model"
|
||||
|
||||
async def test_raises_when_no_endpoint_has_model(self):
|
||||
cfg = _make_cfg([EP1, EP2])
|
||||
with (
|
||||
|
|
|
|||
|
|
@ -20,10 +20,11 @@ MOCK_OLLAMA_EP = "http://mock-ollama:11434"
|
|||
MOCK_LLAMA_EP = "http://mock-llama:8080/v1"
|
||||
|
||||
|
||||
def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
|
||||
def _make_cfg(ollama_eps=None, llama_eps=None, swap_eps=None, api_keys=None):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP]
|
||||
cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP]
|
||||
cfg.llama_swap_endpoints = swap_eps or []
|
||||
cfg.api_keys = api_keys or {}
|
||||
cfg.max_concurrent_connections = 2
|
||||
cfg.router_api_key = None
|
||||
|
|
@ -228,6 +229,30 @@ class TestFetchLoadedModels:
|
|||
models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
|
||||
assert "always-on-model" in models
|
||||
|
||||
async def test_llama_swap_reads_running_state_ready(self):
|
||||
# llama-swap omits the /v1/models status field, so loaded workers come
|
||||
# from /running (a root route — the /v1 suffix must be stripped).
|
||||
swap_ep = "http://mock-swap:8080/v1"
|
||||
cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
"http://mock-swap:8080/running",
|
||||
payload={"running": [
|
||||
{"model": "org/ready-model:Q4_K_M", "state": "ready"},
|
||||
{"model": "org/starting-model:Q8_0", "state": "starting"},
|
||||
]},
|
||||
)
|
||||
models = await router.fetch.loaded_models(swap_ep)
|
||||
assert models == {"org/ready-model:Q4_K_M"}
|
||||
|
||||
async def test_llama_swap_records_error_on_failure(self):
|
||||
swap_ep = "http://mock-swap:8080/v1"
|
||||
cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get("http://mock-swap:8080/running", status=502, payload={})
|
||||
await router.fetch.loaded_models(swap_ep)
|
||||
assert swap_ep in router._loaded_error_cache
|
||||
|
||||
async def test_returns_empty_on_error(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
|
|
|
|||
109
test/test_llama_swap.py
Normal file
109
test/test_llama_swap.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Tests for llama-swap specific behavior: unload dispatch + /upstream resolution."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import router
|
||||
import backends.control as control
|
||||
import api.openai as openai_api
|
||||
|
||||
SWAP_EP = "http://swap:8080/v1"
|
||||
SERVER_EP = "http://server:8080/v1"
|
||||
|
||||
|
||||
def _cfg(*, server=None, swap=None, api_keys=None):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = []
|
||||
cfg.llama_server_endpoints = server or []
|
||||
cfg.llama_swap_endpoints = swap or []
|
||||
cfg.api_keys = api_keys or {}
|
||||
return cfg
|
||||
|
||||
|
||||
class _RecordingSession:
|
||||
"""Captures the most recent ``post`` call and returns a 200 response."""
|
||||
|
||||
def __init__(self, status=200):
|
||||
self.calls = []
|
||||
self._status = status
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.calls.append((url, kwargs))
|
||||
resp = MagicMock()
|
||||
resp.status = self._status
|
||||
|
||||
class _Ctx:
|
||||
async def __aenter__(self_):
|
||||
return resp
|
||||
|
||||
async def __aexit__(self_, *exc):
|
||||
return False
|
||||
|
||||
return _Ctx()
|
||||
|
||||
|
||||
class TestUnloadDispatch:
|
||||
async def test_llama_swap_uses_path_param(self):
|
||||
sess = _RecordingSession()
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(control, "get_probe_session", lambda ep: sess),
|
||||
):
|
||||
ok = await control.unload_model(SWAP_EP, "org/model:Q4_K_M")
|
||||
assert ok is True
|
||||
url, kwargs = sess.calls[0]
|
||||
# /v1 stripped, model id is a path param, no JSON body
|
||||
assert url == "http://swap:8080/api/models/unload/org/model:Q4_K_M"
|
||||
assert kwargs.get("json") is None
|
||||
|
||||
async def test_llama_server_uses_body(self):
|
||||
sess = _RecordingSession()
|
||||
cfg = _cfg(server=[SERVER_EP])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(control, "get_probe_session", lambda ep: sess),
|
||||
):
|
||||
ok = await control.unload_model(SERVER_EP, "org/model:Q4_K_M")
|
||||
assert ok is True
|
||||
url, kwargs = sess.calls[0]
|
||||
assert url == "http://server:8080/models/unload"
|
||||
assert kwargs.get("json") == {"model": "org/model:Q4_K_M"}
|
||||
|
||||
async def test_unload_failure_returns_false(self):
|
||||
sess = _RecordingSession(status=500)
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(router, "config", cfg),
|
||||
patch.object(control, "get_probe_session", lambda ep: sess),
|
||||
):
|
||||
ok = await control.unload_model(SWAP_EP, "m")
|
||||
assert ok is False
|
||||
|
||||
|
||||
class TestUpstreamResolution:
|
||||
async def test_resolves_endpoint_that_advertises_model(self):
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(openai_api, "get_config", lambda: cfg),
|
||||
patch.object(openai_api.fetch, "available_models",
|
||||
AsyncMock(return_value={"org/model:Q4_K_M"})),
|
||||
):
|
||||
ep = await openai_api._resolve_llama_swap_endpoint("org/model:Q4_K_M")
|
||||
assert ep == SWAP_EP
|
||||
|
||||
async def test_returns_none_when_unserved(self):
|
||||
cfg = _cfg(swap=[SWAP_EP])
|
||||
with (
|
||||
patch.object(openai_api, "get_config", lambda: cfg),
|
||||
patch.object(openai_api.fetch, "available_models",
|
||||
AsyncMock(return_value=set())),
|
||||
):
|
||||
ep = await openai_api._resolve_llama_swap_endpoint("missing")
|
||||
assert ep is None
|
||||
|
||||
async def test_returns_none_without_swap_endpoints(self):
|
||||
cfg = _cfg(swap=[])
|
||||
with patch.object(openai_api, "get_config", lambda: cfg):
|
||||
ep = await openai_api._resolve_llama_swap_endpoint("any")
|
||||
assert ep is None
|
||||
|
|
@ -277,3 +277,49 @@ class TestGetTrackingModel:
|
|||
with patch.object(router, "config", cfg):
|
||||
result = router.get_tracking_model(ep, "unsloth/model:Q8_0")
|
||||
assert result == "model"
|
||||
|
||||
|
||||
class TestLlamaSwapClassification:
|
||||
def _cfg(self, *, server=None, swap=None):
|
||||
cfg = MagicMock()
|
||||
cfg.endpoints = []
|
||||
cfg.llama_server_endpoints = server or []
|
||||
cfg.llama_swap_endpoints = swap or []
|
||||
return cfg
|
||||
|
||||
def test_is_llama_swap_only_for_swap_list(self):
|
||||
from backends.normalize import is_llama_swap
|
||||
swap_ep = "http://host:8890/v1"
|
||||
server_ep = "http://host:8889/v1"
|
||||
cfg = self._cfg(server=[server_ep], swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert is_llama_swap(swap_ep) is True
|
||||
assert is_llama_swap(server_ep) is False
|
||||
|
||||
def test_is_llama_server_covers_both(self):
|
||||
from backends.normalize import is_llama_server
|
||||
swap_ep = "http://host:8890/v1"
|
||||
server_ep = "http://host:8889/v1"
|
||||
cfg = self._cfg(server=[server_ep], swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert is_llama_server(swap_ep) is True
|
||||
assert is_llama_server(server_ep) is True
|
||||
assert is_llama_server("http://host:11434") is False
|
||||
|
||||
def test_swap_is_openai_compatible_not_ext(self):
|
||||
swap_ep = "http://host:8890/v1"
|
||||
cfg = self._cfg(swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.is_openai_compatible(swap_ep) is True
|
||||
assert router.is_ext_openai_endpoint(swap_ep) is False
|
||||
|
||||
def test_swap_tracking_model_normalized(self):
|
||||
swap_ep = "http://host:8890/v1"
|
||||
cfg = self._cfg(swap=[swap_ep])
|
||||
with patch.object(router, "config", cfg):
|
||||
assert router.get_tracking_model(swap_ep, "unsloth/model:Q8_0") == "model"
|
||||
|
||||
def test_llama_endpoints_dedupes_and_orders(self):
|
||||
from backends.normalize import llama_endpoints
|
||||
cfg = self._cfg(server=["a", "b"], swap=["b", "c"])
|
||||
assert llama_endpoints(cfg) == ["a", "b", "c"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue