From aa8baebac5184b8c0abaffea676802007d96dc2f Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Sun, 14 Jun 2026 16:34:31 +0200 Subject: [PATCH] feat: add llama-swap as a backend --- api/management.py | 14 +++-- api/ollama.py | 85 +++++++++++++++++++++------ api/openai.py | 95 +++++++++++++++++++++++++++--- backends/control.py | 50 ++++++++++++++++ backends/normalize.py | 39 +++++++++---- backends/probe.py | 38 +++++++++++- config.py | 4 ++ config.yaml | 13 ++++- doc/configuration.md | 31 ++++++++++ router.py | 6 +- routing.py | 8 ++- test/config_test.yaml | 4 ++ test/conftest.py | 3 + test/test_choose_endpoint.py | 24 +++++++- test/test_fetch.py | 27 ++++++++- test/test_llama_swap.py | 109 +++++++++++++++++++++++++++++++++++ test/test_unit_helpers.py | 46 +++++++++++++++ 17 files changed, 544 insertions(+), 52 deletions(-) create mode 100644 backends/control.py create mode 100644 test/test_llama_swap.py diff --git a/api/management.py b/api/management.py index ac1f356..0e9ecc2 100644 --- a/api/management.py +++ b/api/management.py @@ -27,7 +27,7 @@ from state import ( _affinity_lock, ) from sse import subscribe, unsubscribe -from backends.normalize import _normalize_llama_model_name +from backends.normalize import _normalize_llama_model_name, is_llama_server, llama_endpoints from backends.probe import _endpoint_health @@ -127,7 +127,6 @@ async def affinity_stats(request: Request): now = time.monotonic() entries: list[dict] = [] - llama_eps = set(config.llama_server_endpoints) async with _affinity_lock: for fp, (ep, mdl, expires_at) in list(_affinity_map.items()): remaining = expires_at - now @@ -136,7 +135,7 @@ async def affinity_stats(request: Request): continue # Mirror the normalisation used by /api/ps_details so the dashboard # can join affinity entries to PS rows by (endpoint, model). 
- display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl + display_model = _normalize_llama_model_name(mdl) if is_llama_server(ep) else mdl entries.append({ "endpoint": ep, "model": display_model, @@ -175,9 +174,12 @@ async def config_proxy(request: Request): ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints]) llama_results = [] - if config.llama_server_endpoints: + # llama-server and llama-swap render identically in the dashboard ("llama" rows), + # so health-check both and merge them into one list. + llama_eps = llama_endpoints(config) + if llama_eps: llama_results = await asyncio.gather( - *[check(ep) for ep in config.llama_server_endpoints] + *[check(ep) for ep in llama_eps] ) return { @@ -227,7 +229,7 @@ async def health_proxy(request: Request): # purposes. Probing /api/version alone would miss the case where the # Ollama process is up but /api/ps is failing — see issue #83. all_endpoints = list(config.endpoints) - llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] + llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints] all_endpoints += llama_eps_extra probe_results = await asyncio.gather( diff --git a/api/ollama.py b/api/ollama.py index afba243..0fc98aa 100644 --- a/api/ollama.py +++ b/api/ollama.py @@ -40,9 +40,12 @@ from backends.health import ( from backends.normalize import ( dedupe_on_keys, is_openai_compatible, + is_llama_server, + llama_endpoints, _normalize_llama_model_name, _extract_llama_quant, ) +from backends.control import unload_model from backends.probe import fetch from backends.sessions import _make_openai_client, get_ollama_client, get_probe_session from requests.chat import _make_moe_requests @@ -372,7 +375,7 @@ async def chat_proxy(request: Request): if use_openai: start_ts = time.perf_counter() # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = 
_normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) @@ -935,8 +938,8 @@ async def tags_proxy(request: Request): # 1. Query all endpoints for models tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep], skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" in ep] - # Also query llama-server endpoints not already covered by config.endpoints - llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] + # Also query llama-server / llama-swap endpoints not already covered by config.endpoints + llama_eps_for_tags = [ep for ep in llama_endpoints(config) if ep not in config.endpoints] tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps_for_tags] all_models = await asyncio.gather(*tasks) @@ -960,27 +963,42 @@ async def tags_proxy(request: Request): ) +async def _fetch_llama_swap_running(endpoint: str) -> list[dict]: + """Return the list of ready (`state == "ready"`) workers from a llama-swap + endpoint's `/running` route. llama-swap omits the per-model `status` field on + `/v1/models`, so running workers must be read here instead. 
+ """ + config = get_config() + base_url = endpoint.rstrip("/").removesuffix("/v1") + return await fetch.endpoint_details( + base_url, "/running", "running", config.api_keys.get(endpoint), + skip_error_cache=True, timeout=8, + ) + + @router.get("/api/ps") async def ps_proxy(request: Request): """ - Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models. + Proxy a ps request to all Ollama, llama-server and llama-swap endpoints and reply a unique list of all running models. For Ollama endpoints: queries /api/ps For llama-server endpoints: queries /v1/models with status.value == "loaded" + For llama-swap endpoints: queries /running (state == "ready") """ config = get_config() # 1. Query Ollama endpoints for running models via /api/ps ollama_tasks = [fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] # 2. Query llama-server endpoints for loaded models via /v1/models - # Also query endpoints from llama_server_endpoints that may not be in config.endpoints - all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) llama_tasks = [ fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) - for ep in all_llama_endpoints + for ep in config.llama_server_endpoints ] + # 3. 
Query llama-swap endpoints for running workers via /running + swap_tasks = [_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints] ollama_loaded = await asyncio.gather(*ollama_tasks) if ollama_tasks else [] llama_loaded = await asyncio.gather(*llama_tasks) if llama_tasks else [] + swap_running = await asyncio.gather(*swap_tasks) if swap_tasks else [] models = {'models': []} # Add Ollama models (if any) @@ -1003,6 +1021,21 @@ async def ps_proxy(request: Request): "status": item.get("status"), "details": {"quantization_level": quant} if quant else {} }) + # Add llama-swap running workers (already filtered on state == "ready") + if swap_running: + for runlist in swap_running: + for item in runlist: + if item.get("state") != "ready": + continue + raw_id = item.get("model", "") + normalized = _normalize_llama_model_name(raw_id) + quant = _extract_llama_quant(raw_id) + models['models'].append({ + "name": normalized, + "id": normalized, + "digest": "", + "details": {"quantization_level": quant} if quant else {} + }) # 3. 
Return a JSONResponse with deduplicated currently deployed models # Deduplicate on 'name' rather than 'digest': llama-server models always @@ -1101,16 +1134,7 @@ async def ps_details_proxy(request: Request): is_generation = "temperature" in dgs if is_sleeping: - unload_url = f"{base_url}/models/unload" - try: - async with client.post( - unload_url, - json={"model": model_id}, - headers=headers, - ) as unload_resp: - print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}") - except Exception as ue: - print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}") + await unload_model(endpoint, model_id) return n_ctx, is_sleeping, is_generation except Exception as e: @@ -1131,4 +1155,31 @@ async def ps_details_proxy(request: Request): if not is_sleeping: models.append(model_dict) + # Add llama-swap running workers (read from /running; no status/props/auto-unload — + # llama-swap omits the status field on /v1/models and manages its own TTL eviction). 
+ if config.llama_swap_endpoints: + swap_running = await asyncio.gather( + *[_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints] + ) + for endpoint, runlist in zip(config.llama_swap_endpoints, swap_running): + for item in runlist: + if not isinstance(item, dict) or item.get("state") != "ready": + continue + raw_id = item.get("model", "") + if not raw_id: + continue + normalized = _normalize_llama_model_name(raw_id) + quant = _extract_llama_quant(raw_id) + models.append({ + "name": normalized, + "id": normalized, + "original_name": raw_id, + "digest": "", + "details": {"quantization_level": quant} if quant else {}, + "endpoint": endpoint, + "state": item.get("state"), + "ttl": item.get("ttl"), + "proxy": item.get("proxy"), + }) + return JSONResponse(content={"models": models}, status_code=200) diff --git a/api/openai.py b/api/openai.py index ab24f54..1f0d22d 100644 --- a/api/openai.py +++ b/api/openai.py @@ -34,6 +34,8 @@ from backends.normalize import ( ep2base, is_ext_openai_endpoint, is_openai_compatible, + is_llama_server, + llama_endpoints, _normalize_llama_model_name, ) from backends.probe import fetch @@ -353,7 +355,7 @@ async def openai_chat_completions_proxy(request: Request): resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) send_params = {**params, "messages": resolved_msgs} # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) @@ -658,9 +660,9 @@ async def openai_models_proxy(request: Request): ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in 
config.endpoints if "/v1" not in ep] # 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)] - # 3. Query llama-server endpoints for loaded models via /v1/models - # Also query endpoints from llama_server_endpoints that may not be in config.endpoints - all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) + # 3. Query llama-server / llama-swap endpoints for advertised models via /v1/models + # Also query endpoints that may not be in config.endpoints + all_llama_endpoints = llama_endpoints(config) llama_tasks = [ fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in all_llama_endpoints @@ -783,10 +785,10 @@ async def rerank_proxy(request: Request): upstream_payload[optional_key] = payload[optional_key] # Determine upstream URL: - # llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints) + # llama-server / llama-swap expose /v1/rerank (base already contains /v1) # External OpenAI endpoints expose /rerank under their /v1 base - if endpoint in config.llama_server_endpoints: - # llama-server: endpoint may or may not already contain /v1 + if is_llama_server(endpoint): + # llama-server / llama-swap: endpoint may or may not already contain /v1 if "/v1" in endpoint: rerank_url = f"{endpoint}/rerank" else: @@ -823,3 +825,82 @@ async def rerank_proxy(request: Request): return JSONResponse(content=data) finally: await decrement_usage(endpoint, tracking_model) + + +async def _resolve_llama_swap_endpoint(model_id: str) -> str | None: + """Pick the llama-swap endpoint that serves ``model_id``. + + Prefers an endpoint that already has the worker running; falls back to any + that advertises the model. 
Returns None if none do. + """ + config = get_config() + swap_eps = config.llama_swap_endpoints + if not swap_eps: + return None + + advertised = await asyncio.gather( + *[fetch.available_models(ep, config.api_keys.get(ep)) for ep in swap_eps] + ) + candidates = [ep for ep, models in zip(swap_eps, advertised) if model_id in models] + if not candidates: + return None + if len(candidates) == 1: + return candidates[0] + + loaded = await asyncio.gather(*[fetch.loaded_models(ep) for ep in candidates]) + for ep, lm in zip(candidates, loaded): + if model_id in lm: + return ep + return candidates[0] + + +@router.api_route("/upstream/{model_id}/{path:path}", methods=["GET", "POST"]) +async def llama_swap_upstream(model_id: str, path: str, request: Request): + """Bypass llama-swap and reach a model's underlying llama-server worker directly + via llama-swap's ``/upstream/:model_id`` route. + + Lets clients use llama-server features that llama-swap itself does not forward + (e.g. token-array prompts), while still letting the router pick the backend that + actually hosts the model. ``/upstream`` is a root route, so the ``/v1`` suffix is + stripped from the configured endpoint. 
+ """ + config = get_config() + endpoint = await _resolve_llama_swap_endpoint(model_id) + if endpoint is None: + raise HTTPException( + status_code=404, + detail=f"No configured llama-swap endpoint serves model '{model_id}'.", + ) + + base_url = endpoint.rstrip("/").removesuffix("/v1") + url = f"{base_url}/upstream/{model_id}/{path}" + if request.url.query: + url = f"{url}?{request.url.query}" + + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + content_type = request.headers.get("content-type") + if content_type: + headers["Content-Type"] = content_type + api_key = config.api_keys.get(endpoint) + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + + body = await request.body() + client: aiohttp.ClientSession = get_session(endpoint) + try: + resp = await client.request(request.method, url, data=body or None, headers=headers) + except Exception as e: + raise HTTPException(status_code=502, detail=f"Upstream request to {url} failed: {e}") + + async def _iter(): + try: + async for chunk in resp.content.iter_any(): + yield chunk + finally: + resp.release() + + return StreamingResponse( + _iter(), + status_code=resp.status, + media_type=resp.headers.get("Content-Type"), + ) diff --git a/backends/control.py b/backends/control.py new file mode 100644 index 0000000..fda5fe3 --- /dev/null +++ b/backends/control.py @@ -0,0 +1,50 @@ +"""Backend control operations (model unload). + +llama-server and llama-swap evict a resident model through different routes: + * llama-server → ``POST {base}/models/unload`` with body ``{"model": id}`` + * llama-swap → ``POST {base}/api/models/unload/{id}`` (path parameter) + +``unload_model`` dispatches on the configured backend type so callers don't +have to know which one they are talking to. Both routes live at the endpoint +root, so any ``/v1`` suffix is stripped first. 
+""" +from typing import Optional + +import aiohttp + +from config import get_config +from state import default_headers +from backends.sessions import get_probe_session +from backends.normalize import is_llama_swap +from backends.health import _format_connection_issue + + +async def unload_model(endpoint: str, model_id: str) -> bool: + """Ask ``endpoint`` to unload ``model_id``. Returns True on a 2xx response. + + ``model_id`` must be the backend's native model identifier (the raw HF id + for llama-server / llama-swap), not the router-normalized display name. + """ + cfg = get_config() + base_url = endpoint.rstrip("/").removesuffix("/v1") + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + api_key: Optional[str] = cfg.api_keys.get(endpoint) + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + + if is_llama_swap(endpoint): + url = f"{base_url}/api/models/unload/{model_id}" + json_body = None + else: + url = f"{base_url}/models/unload" + json_body = {"model": model_id} + + client: aiohttp.ClientSession = get_probe_session(endpoint) + try: + async with client.post(url, json=json_body, headers=headers) as resp: + ok = resp.status < 400 + print(f"[unload_model] {model_id} on {endpoint}: {resp.status}") + return ok + except Exception as e: + print(f"[unload_model] {_format_connection_issue(url, e)}") + return False diff --git a/backends/normalize.py b/backends/normalize.py index 6603f9d..41fc199 100644 --- a/backends/normalize.py +++ b/backends/normalize.py @@ -50,27 +50,46 @@ def dedupe_on_keys(dicts, key_fields): return out +def is_llama_swap(endpoint: str) -> bool: + """True if the endpoint is a configured llama-swap front.""" + return endpoint in get_config().llama_swap_endpoints + + +def is_llama_server(endpoint: str) -> bool: + """True for a llama.cpp llama-server OR a llama-swap front. 
+ + Both speak the same OpenAI-compatible surface, so the router treats them + identically everywhere except loaded-model detection and model unload. + """ + cfg = get_config() + return endpoint in cfg.llama_server_endpoints or endpoint in cfg.llama_swap_endpoints + + +def llama_endpoints(cfg) -> list: + """Combined, de-duplicated llama-server + llama-swap endpoints (order preserved).""" + return list(dict.fromkeys([*cfg.llama_server_endpoints, *cfg.llama_swap_endpoints])) + + def is_ext_openai_endpoint(endpoint: str) -> bool: """ - Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server). + Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama, llama-server or llama-swap). Returns True for: - External services like OpenAI.com, Groq, etc. Returns False for: - Ollama endpoints (without /v1, or with /v1 but default port 11434) - - llama-server endpoints (explicitly configured in llama_server_endpoints) + - llama-server / llama-swap endpoints (explicitly configured) """ - cfg = get_config() - # Check if it's a llama-server endpoint (has /v1 and is in the configured list) - if endpoint in cfg.llama_server_endpoints: + # Check if it's a llama-server / llama-swap endpoint (has /v1 and is in a configured list) + if is_llama_server(endpoint): return False if "/v1" not in endpoint: return False base_endpoint = endpoint.replace('/v1', '') - if base_endpoint in cfg.endpoints: + if base_endpoint in get_config().endpoints: return False # It's Ollama's /v1 # Check for default Ollama port @@ -83,9 +102,9 @@ def is_ext_openai_endpoint(endpoint: str) -> bool: def is_openai_compatible(endpoint: str) -> bool: """ Return True if the endpoint speaks the OpenAI API (not native Ollama). - This includes external OpenAI endpoints AND llama-server endpoints. + This includes external OpenAI endpoints AND llama-server / llama-swap endpoints. 
""" - return "/v1" in endpoint or endpoint in get_config().llama_server_endpoints + return "/v1" in endpoint or is_llama_server(endpoint) def get_tracking_model(endpoint: str, model: str) -> str: @@ -102,8 +121,8 @@ def get_tracking_model(endpoint: str, model: str) -> str: if is_ext_openai_endpoint(endpoint): return model - # llama-server endpoints use normalized names in PS - if endpoint in get_config().llama_server_endpoints: + # llama-server / llama-swap endpoints use normalized names in PS + if is_llama_server(endpoint): return _normalize_llama_model_name(model) # Ollama endpoints: append ":latest" if no version suffix diff --git a/backends/probe.py b/backends/probe.py index 3ce089f..f59e65e 100644 --- a/backends/probe.py +++ b/backends/probe.py @@ -46,7 +46,7 @@ from backends.health import ( _format_connection_issue, _is_llama_model_loaded, ) -from backends.normalize import is_ext_openai_endpoint, is_openai_compatible +from backends.normalize import is_ext_openai_endpoint, is_openai_compatible, is_llama_server, is_llama_swap class fetch: @@ -61,10 +61,10 @@ class fetch: headers["Authorization"] = "Bearer " + api_key ep_base = endpoint.rstrip("/") - if endpoint in cfg.llama_server_endpoints and "/v1" not in endpoint: + if is_llama_server(endpoint) and "/v1" not in endpoint: endpoint_url = f"{ep_base}/v1/models" key = "data" - elif "/v1" in endpoint or endpoint in cfg.llama_server_endpoints: + elif "/v1" in endpoint or is_llama_server(endpoint): endpoint_url = f"{ep_base}/models" key = "data" else: @@ -194,6 +194,38 @@ class fetch: client: aiohttp.ClientSession = get_probe_session(endpoint) cfg = get_config() + # llama-swap: loaded/running workers are reported at /running (state == "ready"), + # NOT via a status field on /v1/models (which it omits). /running is a root route, + # so strip any /v1 suffix from the configured endpoint. 
+ if is_llama_swap(endpoint): + base_url = endpoint.rstrip("/").removesuffix("/v1") + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + api_key = cfg.api_keys.get(endpoint) + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + try: + async with client.get(f"{base_url}/running", headers=headers) as resp: + await _ensure_success(resp) + data = await resp.json() + + models = { + item.get("model") + for item in data.get("running", []) + if item.get("model") and item.get("state") == "ready" + } + + async with _loaded_models_cache_lock: + _loaded_models_cache[endpoint] = (models, time.time()) + async with _loaded_error_cache_lock: + _loaded_error_cache.pop(endpoint, None) + return models + except Exception as e: + message = _format_connection_issue(f"{base_url}/running", e) + print(f"[fetch.loaded_models] {message}") + async with _loaded_error_cache_lock: + _loaded_error_cache[endpoint] = time.time() + return set() + # Check if this is a llama-server endpoint if endpoint in cfg.llama_server_endpoints: # Query /v1/models for llama-server. Send the configured key as a diff --git a/config.py b/config.py index 143a2f9..03d8e94 100644 --- a/config.py +++ b/config.py @@ -23,6 +23,10 @@ class Config(BaseSettings): ) # List of llama-server endpoints (OpenAI-compatible with /v1/models status info) llama_server_endpoints: List[str] = Field(default_factory=list) + # List of llama-swap endpoints (OpenAI-compatible front for multiple llama-server + # workers). Same surface as llama_server_endpoints, but loaded models are read from + # /running (not /v1/models status) and unload uses POST /api/models/unload/:model_id. 
+ llama_swap_endpoints: List[str] = Field(default_factory=list) # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL max_concurrent_connections: int = 1 # Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}} diff --git a/config.yaml b/config.yaml index 2107a3c..51ebb1b 100644 --- a/config.yaml +++ b/config.yaml @@ -6,7 +6,15 @@ endpoints: - https://api.openai.com/v1 llama_server_endpoints: - - http://192.168.0.50:8889/v1 + - http://192.168.0.51:8889/v1 + +# llama-swap endpoints (OpenAI-compatible front for multiple llama-server workers). +# Same surface as llama_server_endpoints, but the router reads loaded/running workers +# from /running (state == "ready") instead of a /v1/models status field, and unloads via +# POST /api/models/unload/:model_id. The router also exposes /upstream/:model_id/*path +# to bypass llama-swap and reach a model's underlying llama-server worker directly. +llama_swap_endpoints: + - http://192.168.0.52:8890/v1 # Maximum concurrent connections *per endpoint‑model pair* (equals to OLLAMA_NUM_PARALLEL) # This is the global default; individual endpoints can override it via endpoint_config below. 
@@ -57,7 +65,8 @@ api_keys: "http://192.168.0.51:11434": "ollama" "http://192.168.0.52:11434": "ollama" "https://api.openai.com/v1": "${OPENAI_KEY}" - "http://192.168.0.50:8889/v1": "llama" + "http://192.168.0.51:8889/v1": "llama" + "http://192.168.0.52:8890/v1": "llama-swap" # ------------------------------------------------------------- # Semantic LLM Cache (optional — disabled by default) diff --git a/doc/configuration.md b/doc/configuration.md index 1addd66..e067207 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -78,6 +78,37 @@ endpoints: - OpenAI-compatible endpoints use `/v1` prefix - The router automatically detects endpoint type based on URL pattern +### `llama_server_endpoints` + +**Type**: `list[str]` (optional) + +**Default**: `[]` + +**Description**: List of [llama.cpp `llama-server`](https://github.com/ggml-org/llama.cpp) endpoints (OpenAI-compatible, configured with the `/v1` suffix). The router reads each backend's loaded models from `/v1/models` (entries with `status == "loaded"`) and unloads idle models via `POST /models/unload`. + +```yaml +llama_server_endpoints: + - http://192.168.0.50:8889/v1 +``` + +### `llama_swap_endpoints` + +**Type**: `list[str]` (optional) + +**Default**: `[]` + +**Description**: List of [llama-swap](https://github.com/mostlygeek/llama-swap) endpoints (OpenAI-compatible, configured with the `/v1` suffix). llama-swap fronts multiple `llama-server` workers behind one address. It is treated like `llama_server_endpoints` for routing, model discovery, and reranking, but differs in two ways the router handles automatically: + +- **Loaded-model detection** — llama-swap's `/v1/models` omits the per-model `status` field, so running workers are read from `GET /running` (entries with `state == "ready"`). +- **Model unload** — done via `POST /api/models/unload/:model_id` (path parameter), not the `llama-server` body form. 
+ +The router also exposes a passthrough route, `GET|POST /upstream/:model_id/*path`, which forwards the request (method, query string, and body) directly to a model's underlying `llama-server` worker (via llama-swap's `/upstream`), letting clients use `llama-server` features that llama-swap does not forward (e.g. token-array prompts). + +```yaml +llama_swap_endpoints: + - http://192.168.0.50:8890/v1 +``` + ### `max_concurrent_connections` **Type**: `int` diff --git a/router.py b/router.py index a2f9dd8..aca2d01 100644 --- a/router.py +++ b/router.py @@ -231,6 +231,7 @@ from backends.health import ( from backends.normalize import ( is_ext_openai_endpoint, is_openai_compatible, + llama_endpoints, get_tracking_model, ) @@ -310,6 +311,7 @@ async def startup_event() -> None: f"Loaded configuration from {config_path}:\n" f" endpoints={config.endpoints},\n" f" llama_server_endpoints={config.llama_server_endpoints},\n" + f" llama_swap_endpoints={config.llama_swap_endpoints},\n" f" max_concurrent_connections={config.max_concurrent_connections},\n" f" endpoint_config={config.endpoint_config},\n" f" priority_routing={config.priority_routing}" @@ -374,7 +376,7 @@ async def startup_event() -> None: app_state["httpx_clients"][ep] = httpx.AsyncClient(timeout=30.0) # Create per-endpoint Unix socket sessions for .sock endpoints - for ep in config.llama_server_endpoints: + for ep in llama_endpoints(config): if _is_unix_socket_endpoint(ep): sock_path = _get_socket_path(ep) sock_connector = aiohttp.UnixConnector(path=sock_path) @@ -391,7 +393,7 @@ async def startup_event() -> None: # client (/api/chat, /api/generate) and the OpenAI client (/v1/* routes), # so warm both; OpenAI-compatible endpoints only need the OpenAI client. 
_warm_endpoints = config.endpoints + [ - ep for ep in config.llama_server_endpoints if ep not in config.endpoints + ep for ep in llama_endpoints(config) if ep not in config.endpoints ] for ep in _warm_endpoints: try: diff --git a/routing.py b/routing.py index ecf6803..0a1cc7f 100644 --- a/routing.py +++ b/routing.py @@ -32,6 +32,8 @@ from backends.health import _is_fresh from backends.normalize import ( is_ext_openai_endpoint, is_openai_compatible, + is_llama_server, + llama_endpoints, get_tracking_model, ) from backends.probe import fetch @@ -93,8 +95,8 @@ async def choose_endpoint(model: str, reserve: bool = True, """ config = get_config() # 1️⃣ Gather advertised‑model sets for all endpoints concurrently - # Include both config.endpoints and config.llama_server_endpoints - llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] + # Include config.endpoints plus any llama-server / llama-swap endpoints + llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints] all_endpoints = config.endpoints + llama_eps_extra tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)] @@ -114,7 +116,7 @@ async def choose_endpoint(model: str, reserve: bool = True, model_without_latest = model.split(":latest")[0] candidate_endpoints = [ ep for ep, models in zip(all_endpoints, advertised_sets) - if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints) + if model_without_latest in models and (is_ext_openai_endpoint(ep) or is_llama_server(ep)) ] if not candidate_endpoints: # Only add :latest suffix if model doesn't already have a version suffix diff --git a/test/config_test.yaml b/test/config_test.yaml index 30f2fa3..ed96542 100644 --- a/test/config_test.yaml +++ b/test/config_test.yaml @@ -4,10 +4,14 @@ endpoints: llama_server_endpoints: - http://192.168.0.51:12434/v1 +llama_swap_endpoints: + - http://192.168.0.51:12435/v1 + 
max_concurrent_connections: 2 api_keys: "http://192.168.0.51:12434": "ollama" "http://192.168.0.51:12434/v1": "llama" + "http://192.168.0.51:12435/v1": "llama-swap" cache_enabled: false diff --git a/test/conftest.py b/test/conftest.py index c5142da..da7dacf 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -57,6 +57,7 @@ def mock_config(): cfg = MagicMock() cfg.endpoints = [TEST_OLLAMA] cfg.llama_server_endpoints = [TEST_LLAMA] + cfg.llama_swap_endpoints = [] cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"} cfg.max_concurrent_connections = 2 cfg.router_api_key = None @@ -70,6 +71,7 @@ def mock_config_no_llama(): cfg = MagicMock() cfg.endpoints = [TEST_OLLAMA] cfg.llama_server_endpoints = [] + cfg.llama_swap_endpoints = [] cfg.api_keys = {TEST_OLLAMA: "ollama"} cfg.max_concurrent_connections = 2 cfg.router_api_key = None @@ -83,6 +85,7 @@ def mock_config_with_key(): cfg = MagicMock() cfg.endpoints = [TEST_OLLAMA] cfg.llama_server_endpoints = [] + cfg.llama_swap_endpoints = [] cfg.api_keys = {} cfg.max_concurrent_connections = 2 cfg.router_api_key = "test-secret-key" diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py index be75f82..17650c4 100644 --- a/test/test_choose_endpoint.py +++ b/test/test_choose_endpoint.py @@ -12,10 +12,11 @@ EP3 = "http://ep3:11434" LLAMA_EP = "http://llama:8080/v1" -def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False): +def _make_cfg(endpoints, llama_eps=None, swap_eps=None, max_conn=2, endpoint_config=None, priority_routing=False): cfg = MagicMock() cfg.endpoints = endpoints cfg.llama_server_endpoints = llama_eps or [] + cfg.llama_swap_endpoints = swap_eps or [] cfg.api_keys = {} cfg.max_concurrent_connections = max_conn cfg.endpoint_config = endpoint_config or {} @@ -46,6 +47,27 @@ class TestChooseEndpointBasic: assert ep == EP1 assert tracking == "llama3.2:latest" + async def test_llama_swap_endpoint_is_a_candidate(self): + swap_ep = 
"http://swap:8080/v1" + cfg = _make_cfg([EP1], swap_eps=[swap_ep]) + + async def available(ep, *_): + # Only the llama-swap backend advertises this model + return {"org/model:Q4_K_M"} if ep == swap_ep else set() + + async def loaded(ep): + return {"org/model:Q4_K_M"} if ep == swap_ep else set() + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", side_effect=loaded), + ): + ep, tracking = await router.choose_endpoint("org/model:Q4_K_M") + assert ep == swap_ep + # llama-swap models are tracked under their normalized name + assert tracking == "model" + async def test_raises_when_no_endpoint_has_model(self): cfg = _make_cfg([EP1, EP2]) with ( diff --git a/test/test_fetch.py b/test/test_fetch.py index 76121e1..dae51e4 100644 --- a/test/test_fetch.py +++ b/test/test_fetch.py @@ -20,10 +20,11 @@ MOCK_OLLAMA_EP = "http://mock-ollama:11434" MOCK_LLAMA_EP = "http://mock-llama:8080/v1" -def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None): +def _make_cfg(ollama_eps=None, llama_eps=None, swap_eps=None, api_keys=None): cfg = MagicMock() cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP] cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP] + cfg.llama_swap_endpoints = swap_eps or [] cfg.api_keys = api_keys or {} cfg.max_concurrent_connections = 2 cfg.router_api_key = None @@ -228,6 +229,30 @@ class TestFetchLoadedModels: models = await router.fetch.loaded_models(MOCK_LLAMA_EP) assert "always-on-model" in models + async def test_llama_swap_reads_running_state_ready(self): + # llama-swap omits the /v1/models status field, so loaded workers come + # from /running (a root route — the /v1 suffix must be stripped). 
+ swap_ep = "http://mock-swap:8080/v1" + cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep]) + with patch.object(router, "config", cfg), mock_probe() as m: + m.add_get( + "http://mock-swap:8080/running", + payload={"running": [ + {"model": "org/ready-model:Q4_K_M", "state": "ready"}, + {"model": "org/starting-model:Q8_0", "state": "starting"}, + ]}, + ) + models = await router.fetch.loaded_models(swap_ep) + assert models == {"org/ready-model:Q4_K_M"} + + async def test_llama_swap_records_error_on_failure(self): + swap_ep = "http://mock-swap:8080/v1" + cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep]) + with patch.object(router, "config", cfg), mock_probe() as m: + m.add_get("http://mock-swap:8080/running", status=502, payload={}) + await router.fetch.loaded_models(swap_ep) + assert swap_ep in router._loaded_error_cache + async def test_returns_empty_on_error(self): cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) with patch.object(router, "config", cfg), mock_probe() as m: diff --git a/test/test_llama_swap.py b/test/test_llama_swap.py new file mode 100644 index 0000000..d0427bf --- /dev/null +++ b/test/test_llama_swap.py @@ -0,0 +1,109 @@ +"""Tests for llama-swap specific behavior: unload dispatch + /upstream resolution.""" +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import router +import backends.control as control +import api.openai as openai_api + +SWAP_EP = "http://swap:8080/v1" +SERVER_EP = "http://server:8080/v1" + + +def _cfg(*, server=None, swap=None, api_keys=None): + cfg = MagicMock() + cfg.endpoints = [] + cfg.llama_server_endpoints = server or [] + cfg.llama_swap_endpoints = swap or [] + cfg.api_keys = api_keys or {} + return cfg + + +class _RecordingSession: + """Captures the most recent ``post`` call and returns a 200 response.""" + + def __init__(self, status=200): + self.calls = [] + self._status = status + + def post(self, url, **kwargs): + self.calls.append((url, kwargs)) + resp = MagicMock() + resp.status = 
self._status + + class _Ctx: + async def __aenter__(self_): + return resp + + async def __aexit__(self_, *exc): + return False + + return _Ctx() + + +class TestUnloadDispatch: + async def test_llama_swap_uses_path_param(self): + sess = _RecordingSession() + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(router, "config", cfg), + patch.object(control, "get_probe_session", lambda ep: sess), + ): + ok = await control.unload_model(SWAP_EP, "org/model:Q4_K_M") + assert ok is True + url, kwargs = sess.calls[0] + # /v1 stripped, model id is a path param, no JSON body + assert url == "http://swap:8080/api/models/unload/org/model:Q4_K_M" + assert kwargs.get("json") is None + + async def test_llama_server_uses_body(self): + sess = _RecordingSession() + cfg = _cfg(server=[SERVER_EP]) + with ( + patch.object(router, "config", cfg), + patch.object(control, "get_probe_session", lambda ep: sess), + ): + ok = await control.unload_model(SERVER_EP, "org/model:Q4_K_M") + assert ok is True + url, kwargs = sess.calls[0] + assert url == "http://server:8080/models/unload" + assert kwargs.get("json") == {"model": "org/model:Q4_K_M"} + + async def test_unload_failure_returns_false(self): + sess = _RecordingSession(status=500) + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(router, "config", cfg), + patch.object(control, "get_probe_session", lambda ep: sess), + ): + ok = await control.unload_model(SWAP_EP, "m") + assert ok is False + + +class TestUpstreamResolution: + async def test_resolves_endpoint_that_advertises_model(self): + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(openai_api, "get_config", lambda: cfg), + patch.object(openai_api.fetch, "available_models", + AsyncMock(return_value={"org/model:Q4_K_M"})), + ): + ep = await openai_api._resolve_llama_swap_endpoint("org/model:Q4_K_M") + assert ep == SWAP_EP + + async def test_returns_none_when_unserved(self): + cfg = _cfg(swap=[SWAP_EP]) + with ( + patch.object(openai_api, "get_config", lambda: cfg), + 
patch.object(openai_api.fetch, "available_models", + AsyncMock(return_value=set())), + ): + ep = await openai_api._resolve_llama_swap_endpoint("missing") + assert ep is None + + async def test_returns_none_without_swap_endpoints(self): + cfg = _cfg(swap=[]) + with patch.object(openai_api, "get_config", lambda: cfg): + ep = await openai_api._resolve_llama_swap_endpoint("any") + assert ep is None diff --git a/test/test_unit_helpers.py b/test/test_unit_helpers.py index d38eb37..def7082 100644 --- a/test/test_unit_helpers.py +++ b/test/test_unit_helpers.py @@ -277,3 +277,49 @@ class TestGetTrackingModel: with patch.object(router, "config", cfg): result = router.get_tracking_model(ep, "unsloth/model:Q8_0") assert result == "model" + + +class TestLlamaSwapClassification: + def _cfg(self, *, server=None, swap=None): + cfg = MagicMock() + cfg.endpoints = [] + cfg.llama_server_endpoints = server or [] + cfg.llama_swap_endpoints = swap or [] + return cfg + + def test_is_llama_swap_only_for_swap_list(self): + from backends.normalize import is_llama_swap + swap_ep = "http://host:8890/v1" + server_ep = "http://host:8889/v1" + cfg = self._cfg(server=[server_ep], swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert is_llama_swap(swap_ep) is True + assert is_llama_swap(server_ep) is False + + def test_is_llama_server_covers_both(self): + from backends.normalize import is_llama_server + swap_ep = "http://host:8890/v1" + server_ep = "http://host:8889/v1" + cfg = self._cfg(server=[server_ep], swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert is_llama_server(swap_ep) is True + assert is_llama_server(server_ep) is True + assert is_llama_server("http://host:11434") is False + + def test_swap_is_openai_compatible_not_ext(self): + swap_ep = "http://host:8890/v1" + cfg = self._cfg(swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert router.is_openai_compatible(swap_ep) is True + assert router.is_ext_openai_endpoint(swap_ep) is False + + 
def test_swap_tracking_model_normalized(self): + swap_ep = "http://host:8890/v1" + cfg = self._cfg(swap=[swap_ep]) + with patch.object(router, "config", cfg): + assert router.get_tracking_model(swap_ep, "unsloth/model:Q8_0") == "model" + + def test_llama_endpoints_dedupes_and_orders(self): + from backends.normalize import llama_endpoints + cfg = self._cfg(server=["a", "b"], swap=["b", "c"]) + assert llama_endpoints(cfg) == ["a", "b", "c"]