456 lines
20 KiB
Python
456 lines
20 KiB
Python
"""Backend probe / discovery primitives.
|
||
|
||
The ``fetch`` class wraps the three discovery paths the router uses:
|
||
* ``available_models`` — what the endpoint advertises (Ollama ``/api/tags``
|
||
or OpenAI-style ``/v1/models``)
|
||
* ``loaded_models`` — what is currently resident (Ollama ``/api/ps`` or
|
||
llama-server ``/v1/models`` filtered on ``status == "loaded"``)
|
||
* ``endpoint_details`` — arbitrary detail fetch used by management routes
|
||
|
||
The two model-listing paths (``available_models`` and ``loaded_models``) go
through three layers of cache: success cache, error cache, and an in-flight
request map; ``endpoint_details`` only consults the error cache.
Stale-while-revalidate refreshes happen in
background tasks tracked by the ``_bg_refresh_*`` maps in ``state``.
|
||
|
||
``_raw_probe`` and ``_endpoint_health`` are the lower-level dual probes
|
||
used by ``/health`` and ``/api/config`` to distinguish a healthy daemon
|
||
with a broken model-introspection path from a dead daemon.
|
||
"""
|
||
import asyncio
|
||
import time
|
||
from typing import List, Optional, Set
|
||
|
||
import aiohttp
|
||
|
||
from config import get_config
|
||
from state import (
|
||
_models_cache,
|
||
_models_cache_lock,
|
||
_loaded_models_cache,
|
||
_loaded_models_cache_lock,
|
||
_available_error_cache,
|
||
_available_error_cache_lock,
|
||
_loaded_error_cache,
|
||
_loaded_error_cache_lock,
|
||
_inflight_available_models,
|
||
_inflight_loaded_models,
|
||
_inflight_lock,
|
||
_bg_refresh_available,
|
||
_bg_refresh_loaded,
|
||
_bg_refresh_lock,
|
||
default_headers,
|
||
)
|
||
from backends.sessions import get_probe_session
|
||
from backends.health import (
|
||
_is_fresh,
|
||
_ensure_success,
|
||
_format_connection_issue,
|
||
_is_llama_model_loaded,
|
||
)
|
||
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible
|
||
|
||
|
||
class fetch:
    """Namespace for the cached backend discovery probes.

    Every member is a static method and is invoked through the class itself
    (for example ``fetch.available_models(endpoint)``).
    """

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _probe_headers(api_key: Optional[str] = None) -> dict:
        """Build the request headers used by the probes in this class.

        ``Referer`` is always set (taken from ``default_headers`` with a
        fallback); a Bearer ``Authorization`` header is added only when
        ``api_key`` is not None.
        """
        headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
        if api_key is not None:
            headers["Authorization"] = "Bearer " + api_key
        return headers

    @staticmethod
    async def _await_coalesced(inflight: dict, endpoint: str, start) -> Set[str]:
        """Await the shared in-flight fetch for ``endpoint``, starting one if needed.

        Args:
            inflight: The in-flight map (endpoint -> task) to coalesce on.
            endpoint: Key identifying the fetch.
            start: Zero-argument callable returning the coroutine to run when
                no fetch is currently in flight. It is only called when a new
                task is actually created.

        Returns:
            The result of the shared task.
        """
        async with _inflight_lock:
            task = inflight.get(endpoint)
            # A finished task can only be a leftover from a fetch whose
            # waiters were all cancelled; never hand out its old result.
            if task is None or task.done():
                task = asyncio.create_task(start())
                inflight[endpoint] = task

        try:
            # shield(): awaiting a task directly would propagate *this*
            # caller's cancellation into the shared task and thereby cancel
            # the fetch for every other coalesced waiter as well.
            return await asyncio.shield(task)
        finally:
            async with _inflight_lock:
                # Only untrack a finished task. If we were cancelled while
                # the fetch is still running, later callers must keep
                # coalescing onto it instead of starting a duplicate.
                if task.done() and inflight.get(endpoint) is task:
                    del inflight[endpoint]

    # ------------------------------------------------------------------
    # Available models
    # ------------------------------------------------------------------
    @staticmethod
    async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
        """Perform the actual HTTP request that lists an endpoint's models.

        Called by ``available_models()`` after the caches and the in-flight
        map have been consulted, and by the background refresh.

        On success the result is stored in ``_models_cache`` and any error
        recorded for the endpoint is cleared. On any exception the failure
        time is stored in ``_available_error_cache`` and an empty set is
        returned.
        """
        cfg = get_config()
        headers = fetch._probe_headers(api_key)

        ep_base = endpoint.rstrip("/")
        is_llama_server = endpoint in cfg.llama_server_endpoints
        has_v1_prefix = "/v1" in endpoint
        if is_llama_server and not has_v1_prefix:
            # llama-server configured without the /v1 prefix
            endpoint_url, key = f"{ep_base}/v1/models", "data"
        elif has_v1_prefix:
            # OpenAI-style base URL that already contains /v1
            endpoint_url, key = f"{ep_base}/models", "data"
        else:
            # Ollama
            endpoint_url, key = f"{ep_base}/api/tags", "models"

        client: aiohttp.ClientSession = get_probe_session(endpoint)
        try:
            async with client.get(endpoint_url, headers=headers) as resp:
                await _ensure_success(resp)
                data = await resp.json()

            # Entries are identified by "id" (OpenAI style) or "name" (Ollama).
            models = {
                name
                for item in data.get(key, [])
                if (name := item.get("id") or item.get("name"))
            }

            async with _models_cache_lock:
                _models_cache[endpoint] = (models, time.time())
            # Probe succeeded - drop any recorded failure so callers that
            # fast-fail on the error cache (e.g. endpoint_details) stop
            # treating a recovered endpoint as down. Mirrors the
            # loaded-models probe.
            async with _available_error_cache_lock:
                _available_error_cache.pop(endpoint, None)
            return models
        except Exception as e:
            # Treat any error as if the endpoint offers no models
            message = _format_connection_issue(endpoint_url, e)
            print(f"[fetch.available_models] {message}")
            async with _available_error_cache_lock:
                _available_error_cache[endpoint] = time.time()
            return set()

    @staticmethod
    async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None:
        """Refresh the available-models cache in the background.

        Used for the stale-while-revalidate pattern. Deduplicates: if a
        refresh task for ``endpoint`` is already registered in
        ``_bg_refresh_available`` and not finished, this call returns
        without starting another one.
        """
        async with _bg_refresh_lock:
            running = _bg_refresh_available.get(endpoint)
            if running is not None and not running.done():
                return  # A refresh is already running for this endpoint
            task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
            _bg_refresh_available[endpoint] = task

        try:
            await task
        except Exception as e:
            # Best effort - the cache stays stale but usable
            print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
        finally:
            async with _bg_refresh_lock:
                # Only untrack our own task, not a newer replacement
                if _bg_refresh_available.get(endpoint) is task:
                    _bg_refresh_available.pop(endpoint, None)

    @staticmethod
    async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
        """Return the set of model names the endpoint *advertises*.

        The listing URL depends on the endpoint type (``/api/tags`` for
        Ollama, ``/models`` or ``/v1/models`` for OpenAI-style and
        llama-server endpoints).

        Caching behaviour:

        * success cache fresh within 300s - returned immediately;
        * success cache fresh within 600s - the stale value is returned and
          a background refresh is started (stale-while-revalidate);
        * older entries are dropped and fetched again synchronously;
        * a failure younger than 30s short-circuits to an empty set; a
          failure between 30s and 300s also returns an empty set but starts
          a background probe; older failures are dropped.

        Concurrent callers that reach the fetch step share one HTTP request
        (request coalescing). Cancelling one caller does not cancel the
        shared request.

        Returns:
            The model names, or an empty set if the request fails
            (e.g. timeout, 5xx, or malformed response).
        """
        async with _models_cache_lock:
            cached = _models_cache.get(endpoint)
            if cached is not None:
                models, cached_at = cached

                # FRESH: return immediately
                if _is_fresh(cached_at, 300):
                    return models

                # STALE: serve the old value and refresh in the background
                if _is_fresh(cached_at, 600):
                    asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
                    return models

                # EXPIRED: too stale, must refresh synchronously
                del _models_cache[endpoint]

        async with _available_error_cache_lock:
            failed_at = _available_error_cache.get(endpoint)
            if failed_at is not None:
                err_age = time.time() - failed_at
                if err_age < 30:
                    # Very fresh error - endpoint likely still down, bail fast
                    return set()
                if err_age < 300:
                    # Older error - endpoint may have recovered, probe in background
                    asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
                    return set()
                # Error expired - remove and fall through to a fresh fetch
                del _available_error_cache[endpoint]

        return await fetch._await_coalesced(
            _inflight_available_models,
            endpoint,
            lambda: fetch._fetch_available_models_internal(endpoint, api_key),
        )

    # ------------------------------------------------------------------
    # Loaded models
    # ------------------------------------------------------------------
    @staticmethod
    async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
        """Perform the actual HTTP request that lists currently loaded models.

        Called by ``loaded_models()`` after the caches and the in-flight map
        have been consulted, and by the background refresh.

        * llama-server endpoints (``cfg.llama_server_endpoints``): queries
          ``<endpoint>/models`` and keeps the entries accepted by
          ``_is_llama_model_loaded``. The configured API key, if any, is
          sent as a Bearer token so that a build/config protecting the
          route does not 401 this probe.
        * all other endpoints: queries Ollama's ``<endpoint>/api/ps``.

        On success ``_loaded_models_cache`` is updated and the endpoint's
        entry in ``_loaded_error_cache`` is cleared; on any exception the
        failure time is recorded there and an empty set is returned.
        """
        client: aiohttp.ClientSession = get_probe_session(endpoint)
        cfg = get_config()
        # Tolerate a configured trailing slash, like the other probes do.
        ep_base = endpoint.rstrip("/")

        is_llama_server = endpoint in cfg.llama_server_endpoints
        if is_llama_server:
            probe_url = f"{ep_base}/models"
            request_kwargs = {"headers": fetch._probe_headers(cfg.api_keys.get(endpoint))}
        else:
            probe_url = f"{ep_base}/api/ps"
            request_kwargs = {}

        try:
            async with client.get(probe_url, **request_kwargs) as resp:
                await _ensure_success(resp)
                data = await resp.json()

            if is_llama_server:
                # Keep loaded models only
                models = {
                    item.get("id")
                    for item in data.get("data", [])
                    if item.get("id") and _is_llama_model_loaded(item)
                }
            else:
                # Response format: {"models": [{"name": "model1"}, {"name": "model2"}]}
                models = {m.get("name") for m in data.get("models", []) if m.get("name")}

            async with _loaded_models_cache_lock:
                _loaded_models_cache[endpoint] = (models, time.time())
            # Probe succeeded - clear any stale error so the endpoint
            # becomes routable again.
            async with _loaded_error_cache_lock:
                _loaded_error_cache.pop(endpoint, None)
            return models
        except Exception as e:
            # If anything goes wrong we simply assume the endpoint has no models
            message = _format_connection_issue(probe_url, e)
            print(f"[fetch.loaded_models] {message}")
            # Record the failure so `choose_endpoint` can avoid routing
            # to an unhealthy backend and repeated probes short-circuit.
            async with _loaded_error_cache_lock:
                _loaded_error_cache[endpoint] = time.time()
            return set()

    @staticmethod
    async def _refresh_loaded_models(endpoint: str) -> None:
        """Refresh the loaded-models cache in the background.

        Used for the stale-while-revalidate pattern. Deduplicates: if a
        refresh task for ``endpoint`` is already registered in
        ``_bg_refresh_loaded`` and not finished, this call returns without
        starting another one.
        """
        async with _bg_refresh_lock:
            running = _bg_refresh_loaded.get(endpoint)
            if running is not None and not running.done():
                return  # A refresh is already running for this endpoint
            task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
            _bg_refresh_loaded[endpoint] = task

        try:
            await task
        except Exception as e:
            # Best effort - the cache stays stale but usable
            print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
        finally:
            async with _bg_refresh_lock:
                # Only untrack our own task, not a newer replacement
                if _bg_refresh_loaded.get(endpoint) is task:
                    _bg_refresh_loaded.pop(endpoint, None)

    @staticmethod
    async def loaded_models(endpoint: str) -> Set[str]:
        """Return the set of model names currently loaded on the endpoint.

        Endpoints for which ``is_ext_openai_endpoint`` is true are never
        probed and always yield an empty set.

        Caching behaviour:

        * success cache fresh within 10s - returned immediately;
        * success cache fresh within 60s - the stale value is returned and
          a background refresh is started;
        * older entries are dropped and fetched again synchronously;
        * a recorded failure that is fresh within 300s short-circuits to an
          empty set; older failures are dropped.

        Concurrent callers that reach the fetch step share one HTTP request
        (request coalescing). Cancelling one caller does not cancel the
        shared request.

        Returns:
            The loaded model names, or an empty set if the request fails.
        """
        if is_ext_openai_endpoint(endpoint):
            return set()

        async with _loaded_models_cache_lock:
            cached = _loaded_models_cache.get(endpoint)
            if cached is not None:
                models, cached_at = cached

                # FRESH: return immediately
                if _is_fresh(cached_at, 10):
                    return models

                # STALE: serve the old value and refresh in the background
                if _is_fresh(cached_at, 60):
                    asyncio.create_task(fetch._refresh_loaded_models(endpoint))
                    return models

                # EXPIRED: too stale, must refresh synchronously
                del _loaded_models_cache[endpoint]

        async with _loaded_error_cache_lock:
            if endpoint in _loaded_error_cache:
                if _is_fresh(_loaded_error_cache[endpoint], 300):
                    return set()
                # Error expired - remove it
                del _loaded_error_cache[endpoint]

        return await fetch._await_coalesced(
            _inflight_loaded_models,
            endpoint,
            lambda: fetch._fetch_loaded_models_internal(endpoint),
        )

    # ------------------------------------------------------------------
    # Arbitrary detail fetch
    # ------------------------------------------------------------------
    @staticmethod
    async def endpoint_details(
        endpoint: str,
        route: str,
        detail: str,
        api_key: Optional[str] = None,
        skip_error_cache: bool = False,
        timeout: Optional[float] = None,
    ) -> List[dict]:
        """Query ``<endpoint>/<route>`` and return the ``detail`` field of the JSON reply.

        Args:
            endpoint: Base URL of the backend.
            route: Path to request; joined to ``endpoint`` with exactly one
                slash.
            detail: Key to read from the decoded JSON object. A missing key
                yields ``[]``.
            api_key: Sent as a Bearer token when not None.
            skip_error_cache: When False (the default), the call is
                short-circuited if the endpoint has a failure recorded in
                ``_available_error_cache`` that is fresh within 300s, and a
                new failure is recorded there. Pass True from health-check
                routes that must always probe; in that case the error cache
                is neither read nor written.
            timeout: Overrides the session default for this single request
                (seconds, total).

        Returns:
            The value stored under ``detail``, or ``[]`` if the request is
            short-circuited or fails for any reason.
        """
        # Fast-fail if the endpoint is known to be down (unless caller opts out)
        if not skip_error_cache:
            async with _available_error_cache_lock:
                if endpoint in _available_error_cache and _is_fresh(_available_error_cache[endpoint], 300):
                    return []

        headers = fetch._probe_headers(api_key)
        request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}"
        client: aiohttp.ClientSession = get_probe_session(endpoint)
        req_kwargs = {}
        if timeout is not None:
            req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)

        try:
            async with client.get(request_url, headers=headers, **req_kwargs) as resp:
                await _ensure_success(resp)
                data = await resp.json()
            return data.get(detail, [])
        except Exception as e:
            # If anything goes wrong we cannot reply details
            message = _format_connection_issue(request_url, e)
            print(f"[fetch.endpoint_details] {message}")
            if not skip_error_cache:
                async with _available_error_cache_lock:
                    _available_error_cache[endpoint] = time.time()
            return []
|
||
|
||
|
||
# -------------------------------------------------------------
|
||
# Endpoint health probes (shared by /api/config and /health)
|
||
# -------------------------------------------------------------
|
||
async def _raw_probe(
    ep: str,
    route: str,
    api_key: Optional[str] = None,
    timeout: Optional[float] = None,
) -> tuple[bool, object]:
    """Issue one GET against ``<ep>/<route>`` and report whether it worked.

    In contrast to `fetch.endpoint_details`, which yields ``[]`` for both
    an empty answer and a failure, the outcome is made explicit.

    Args:
        ep: Base URL of the endpoint.
        route: Path to request; joined to ``ep`` with exactly one slash.
        api_key: Sent as a Bearer token when not None.
        timeout: Total timeout in seconds for this request; when None no
            per-request timeout is passed.

    Returns:
        ``(True, decoded_json)`` on success, otherwise
        ``(False, error_message)`` where the message comes from
        `_format_connection_issue`.
    """
    request_headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
    if api_key is not None:
        request_headers["Authorization"] = "Bearer " + api_key

    url = f"{ep.rstrip('/')}/{route.lstrip('/')}"

    # Only pass a timeout when the caller asked for one.
    extra = {} if timeout is None else {"timeout": aiohttp.ClientTimeout(total=timeout)}

    try:
        session: aiohttp.ClientSession = get_probe_session(ep)
        async with session.get(url, headers=request_headers, **extra) as response:
            await _ensure_success(response)
            return True, await response.json()
    except Exception as exc:
        return False, _format_connection_issue(url, exc)
|
||
|
||
|
||
async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
    """Probe an endpoint and return `{status, version?, detail?}`.

    Endpoints for which `is_openai_compatible` is true get a single
    `/models` probe, authenticated with the key configured for ``ep``.

    All other (Ollama) endpoints get `/api/version` and `/api/ps` probed
    concurrently, so that a daemon which is reachable but has a broken
    model-introspection path (issue #83) is reported as `error` rather
    than `ok`.

    Args:
        ep: Base URL of the endpoint.
        timeout: Total per-request timeout in seconds, forwarded to
            `_raw_probe`.

    Returns:
        A dict whose ``status`` is ``"ok"`` or ``"error"``. ``version`` is
        present when the version probe succeeded (always ``"latest"`` for
        OpenAI-compatible endpoints); ``detail`` carries the failure
        message on errors.
    """
    if is_openai_compatible(ep):
        reachable, payload = await _raw_probe(
            ep, "/models", get_config().api_keys.get(ep), timeout=timeout,
        )
        if not reachable:
            return {"status": "error", "detail": str(payload)}
        return {"status": "ok", "version": "latest"}

    version_result, ps_result = await asyncio.gather(
        _raw_probe(ep, "/api/version", timeout=timeout),
        _raw_probe(ep, "/api/ps", timeout=timeout),
    )
    version_ok, version_payload = version_result
    ps_ok, ps_payload = ps_result

    version_value = None
    if version_ok and isinstance(version_payload, dict):
        version_value = version_payload.get("version")

    if version_ok:
        if ps_ok:
            return {"status": "ok", "version": version_value}
        # Partial failure: the daemon answers but /api/ps does not. Report
        # "error" and still include `version` so the operator can see the
        # daemon itself is alive.
        return {
            "status": "error",
            "version": version_value,
            "detail": f"/api/ps: {ps_payload}",
        }

    if ps_ok:
        # Partial failure the other way round: only the version probe failed.
        return {
            "status": "error",
            "detail": f"/api/version: {version_payload}",
        }

    # Both probes failed.
    return {"status": "error", "detail": str(version_payload)}
|