feat: add llama-swap as a backend
This commit is contained in:
parent
c8da58430a
commit
aa8baebac5
17 changed files with 544 additions and 52 deletions
50
backends/control.py
Normal file
50
backends/control.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Backend control operations (model unload).
|
||||
|
||||
llama-server and llama-swap evict a resident model through different routes:
|
||||
* llama-server → ``POST {base}/models/unload`` with body ``{"model": id}``
|
||||
* llama-swap → ``POST {base}/api/models/unload/{id}`` (path parameter)
|
||||
|
||||
``unload_model`` dispatches on the configured backend type so callers don't
|
||||
have to know which one they are talking to. Both routes live at the endpoint
|
||||
root, so any ``/v1`` suffix is stripped first.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from config import get_config
|
||||
from state import default_headers
|
||||
from backends.sessions import get_probe_session
|
||||
from backends.normalize import is_llama_swap
|
||||
from backends.health import _format_connection_issue
|
||||
|
||||
|
||||
async def unload_model(endpoint: str, model_id: str) -> bool:
|
||||
"""Ask ``endpoint`` to unload ``model_id``. Returns True on a 2xx response.
|
||||
|
||||
``model_id`` must be the backend's native model identifier (the raw HF id
|
||||
for llama-server / llama-swap), not the router-normalized display name.
|
||||
"""
|
||||
cfg = get_config()
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
api_key: Optional[str] = cfg.api_keys.get(endpoint)
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
if is_llama_swap(endpoint):
|
||||
url = f"{base_url}/api/models/unload/{model_id}"
|
||||
json_body = None
|
||||
else:
|
||||
url = f"{base_url}/models/unload"
|
||||
json_body = {"model": model_id}
|
||||
|
||||
client: aiohttp.ClientSession = get_probe_session(endpoint)
|
||||
try:
|
||||
async with client.post(url, json=json_body, headers=headers) as resp:
|
||||
ok = resp.status < 400
|
||||
print(f"[unload_model] {model_id} on {endpoint}: {resp.status}")
|
||||
return ok
|
||||
except Exception as e:
|
||||
print(f"[unload_model] {_format_connection_issue(url, e)}")
|
||||
return False
|
||||
|
|
@ -50,27 +50,46 @@ def dedupe_on_keys(dicts, key_fields):
|
|||
return out
|
||||
|
||||
|
||||
def is_llama_swap(endpoint: str) -> bool:
|
||||
"""True if the endpoint is a configured llama-swap front."""
|
||||
return endpoint in get_config().llama_swap_endpoints
|
||||
|
||||
|
||||
def is_llama_server(endpoint: str) -> bool:
|
||||
"""True for a llama.cpp llama-server OR a llama-swap front.
|
||||
|
||||
Both speak the same OpenAI-compatible surface, so the router treats them
|
||||
identically everywhere except loaded-model detection and model unload.
|
||||
"""
|
||||
cfg = get_config()
|
||||
return endpoint in cfg.llama_server_endpoints or endpoint in cfg.llama_swap_endpoints
|
||||
|
||||
|
||||
def llama_endpoints(cfg) -> list:
|
||||
"""Combined, de-duplicated llama-server + llama-swap endpoints (order preserved)."""
|
||||
return list(dict.fromkeys([*cfg.llama_server_endpoints, *cfg.llama_swap_endpoints]))
|
||||
|
||||
|
||||
def is_ext_openai_endpoint(endpoint: str) -> bool:
|
||||
"""
|
||||
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server).
|
||||
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama, llama-server or llama-swap).
|
||||
|
||||
Returns True for:
|
||||
- External services like OpenAI.com, Groq, etc.
|
||||
|
||||
Returns False for:
|
||||
- Ollama endpoints (without /v1, or with /v1 but default port 11434)
|
||||
- llama-server endpoints (explicitly configured in llama_server_endpoints)
|
||||
- llama-server / llama-swap endpoints (explicitly configured)
|
||||
"""
|
||||
cfg = get_config()
|
||||
# Check if it's a llama-server endpoint (has /v1 and is in the configured list)
|
||||
if endpoint in cfg.llama_server_endpoints:
|
||||
# Check if it's a llama-server / llama-swap endpoint (has /v1 and is in a configured list)
|
||||
if is_llama_server(endpoint):
|
||||
return False
|
||||
|
||||
if "/v1" not in endpoint:
|
||||
return False
|
||||
|
||||
base_endpoint = endpoint.replace('/v1', '')
|
||||
if base_endpoint in cfg.endpoints:
|
||||
if base_endpoint in get_config().endpoints:
|
||||
return False # It's Ollama's /v1
|
||||
|
||||
# Check for default Ollama port
|
||||
|
|
@ -83,9 +102,9 @@ def is_ext_openai_endpoint(endpoint: str) -> bool:
|
|||
def is_openai_compatible(endpoint: str) -> bool:
|
||||
"""
|
||||
Return True if the endpoint speaks the OpenAI API (not native Ollama).
|
||||
This includes external OpenAI endpoints AND llama-server endpoints.
|
||||
This includes external OpenAI endpoints AND llama-server / llama-swap endpoints.
|
||||
"""
|
||||
return "/v1" in endpoint or endpoint in get_config().llama_server_endpoints
|
||||
return "/v1" in endpoint or is_llama_server(endpoint)
|
||||
|
||||
|
||||
def get_tracking_model(endpoint: str, model: str) -> str:
|
||||
|
|
@ -102,8 +121,8 @@ def get_tracking_model(endpoint: str, model: str) -> str:
|
|||
if is_ext_openai_endpoint(endpoint):
|
||||
return model
|
||||
|
||||
# llama-server endpoints use normalized names in PS
|
||||
if endpoint in get_config().llama_server_endpoints:
|
||||
# llama-server / llama-swap endpoints use normalized names in PS
|
||||
if is_llama_server(endpoint):
|
||||
return _normalize_llama_model_name(model)
|
||||
|
||||
# Ollama endpoints: append ":latest" if no version suffix
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from backends.health import (
|
|||
_format_connection_issue,
|
||||
_is_llama_model_loaded,
|
||||
)
|
||||
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible
|
||||
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible, is_llama_server, is_llama_swap
|
||||
|
||||
|
||||
class fetch:
|
||||
|
|
@ -61,10 +61,10 @@ class fetch:
|
|||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
ep_base = endpoint.rstrip("/")
|
||||
if endpoint in cfg.llama_server_endpoints and "/v1" not in endpoint:
|
||||
if is_llama_server(endpoint) and "/v1" not in endpoint:
|
||||
endpoint_url = f"{ep_base}/v1/models"
|
||||
key = "data"
|
||||
elif "/v1" in endpoint or endpoint in cfg.llama_server_endpoints:
|
||||
elif "/v1" in endpoint or is_llama_server(endpoint):
|
||||
endpoint_url = f"{ep_base}/models"
|
||||
key = "data"
|
||||
else:
|
||||
|
|
@ -194,6 +194,38 @@ class fetch:
|
|||
client: aiohttp.ClientSession = get_probe_session(endpoint)
|
||||
cfg = get_config()
|
||||
|
||||
# llama-swap: loaded/running workers are reported at /running (state == "ready"),
|
||||
# NOT via a status field on /v1/models (which it omits). /running is a root route,
|
||||
# so strip any /v1 suffix from the configured endpoint.
|
||||
if is_llama_swap(endpoint):
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
api_key = cfg.api_keys.get(endpoint)
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
try:
|
||||
async with client.get(f"{base_url}/running", headers=headers) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
|
||||
models = {
|
||||
item.get("model")
|
||||
for item in data.get("running", [])
|
||||
if item.get("model") and item.get("state") == "ready"
|
||||
}
|
||||
|
||||
async with _loaded_models_cache_lock:
|
||||
_loaded_models_cache[endpoint] = (models, time.time())
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache.pop(endpoint, None)
|
||||
return models
|
||||
except Exception as e:
|
||||
message = _format_connection_issue(f"{base_url}/running", e)
|
||||
print(f"[fetch.loaded_models] {message}")
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
|
||||
# Check if this is a llama-server endpoint
|
||||
if endpoint in cfg.llama_server_endpoints:
|
||||
# Query /v1/models for llama-server. Send the configured key as a
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue