refac: modularize backend IV
This commit is contained in:
parent
c88ba1e5a4
commit
3a9854c5db
8 changed files with 822 additions and 666 deletions
0
backends/__init__.py
Normal file
0
backends/__init__.py
Normal file
136
backends/health.py
Normal file
136
backends/health.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""Backend health probes and error classification helpers.
|
||||
|
||||
Contains:
|
||||
* cache-freshness check (``_is_fresh``)
|
||||
* aiohttp response success assertion (``_ensure_success``)
|
||||
* human-readable connection-issue formatter
|
||||
* upstream-error detection that distinguishes connection failures from
|
||||
legitimate 4xx responses (``_is_backend_connection_error``)
|
||||
* per-(endpoint, model) unhealthy marker that feeds ``choose_endpoint``
|
||||
* llama-server status interpretation (``_is_llama_model_loaded`` etc.)
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
from fastapi import HTTPException
|
||||
|
||||
from security import _mask_secrets
|
||||
from state import _completion_error_cache, _completion_error_cache_lock
|
||||
|
||||
|
||||
def _is_fresh(cached_at: float, ttl: int) -> bool:
|
||||
return (time.time() - cached_at) < ttl
|
||||
|
||||
|
||||
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
|
||||
if resp.status >= 400:
|
||||
text = await resp.text()
|
||||
raise HTTPException(status_code=resp.status, detail=_mask_secrets(text))
|
||||
|
||||
|
||||
def _format_connection_issue(url: str, error: Exception) -> str:
|
||||
"""
|
||||
Provide a human-friendly error string for connection failures so operators
|
||||
know which endpoint and address failed from inside the container.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
host_hint = parsed.hostname or ""
|
||||
port_hint = parsed.port or ""
|
||||
|
||||
if isinstance(error, aiohttp.ClientConnectorError):
|
||||
resolved_host = getattr(error, "host", host_hint) or host_hint or "?"
|
||||
resolved_port = getattr(error, "port", port_hint) or port_hint or "?"
|
||||
parts = [
|
||||
f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).",
|
||||
"Ensure the endpoint address is reachable from within the container.",
|
||||
]
|
||||
if resolved_host in {"localhost", "127.0.0.1"}:
|
||||
parts.append(
|
||||
"Inside Docker, 'localhost' refers to the container itself; use "
|
||||
"'host.docker.internal' or a Docker network alias if the service "
|
||||
"runs on the host machine."
|
||||
)
|
||||
os_error = getattr(error, "os_error", None)
|
||||
if isinstance(os_error, OSError):
|
||||
errno = getattr(os_error, "errno", None)
|
||||
strerror = os_error.strerror or str(os_error)
|
||||
if errno is not None or strerror:
|
||||
parts.append(f"OS error [{errno}]: {strerror}.")
|
||||
elif os_error:
|
||||
parts.append(f"OS error: {os_error}.")
|
||||
parts.append(f"Original error: {error}.")
|
||||
return " ".join(parts)
|
||||
|
||||
if isinstance(error, asyncio.TimeoutError):
|
||||
return (
|
||||
f"Timed out waiting for {url}. "
|
||||
"The remote endpoint may be offline or slow to respond."
|
||||
)
|
||||
|
||||
return f"Error while contacting {url}: {error}"
|
||||
|
||||
|
||||
def _is_backend_connection_error(exc: Exception) -> bool:
|
||||
"""True for upstream connection-class failures observed via the OpenAI client.
|
||||
|
||||
Targets the case where a llama-server in router mode keeps answering
|
||||
/v1/models but its delegated worker for a specific model is dead, so
|
||||
chat/completions calls return 5xx with 'proxy error: Could not establish
|
||||
connection' (or the SDK raises APIConnectionError outright).
|
||||
|
||||
Excludes BadRequestError with exceed_context_size_error by design — those
|
||||
must stay on the reactive-trim path.
|
||||
"""
|
||||
if isinstance(exc, openai.APIConnectionError):
|
||||
return True
|
||||
if isinstance(exc, openai.InternalServerError):
|
||||
msg = str(exc).lower()
|
||||
return (
|
||||
"proxy error" in msg
|
||||
or "could not establish connection" in msg
|
||||
or "connection refused" in msg
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _mark_backend_unhealthy(endpoint: str, model: str, reason: str = "") -> None:
|
||||
"""Record (endpoint, model) as broken so choose_endpoint avoids it.
|
||||
|
||||
Cleared only by TTL — the dead-worker failure mode is invisible to the
|
||||
/v1/models / /api/ps probes that clear _loaded_error_cache, so we cannot
|
||||
rely on a successful probe as a recovery signal.
|
||||
"""
|
||||
async with _completion_error_cache_lock:
|
||||
_completion_error_cache[(endpoint, model)] = time.time()
|
||||
print(f"[health] marked unhealthy ep={endpoint} model={model} reason={reason[:120]}", flush=True)
|
||||
|
||||
|
||||
def _is_llama_model_loaded(item: dict) -> bool:
|
||||
"""Return True if a llama-server /v1/models item has status 'loaded'.
|
||||
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
|
||||
If no status field is present, the model is always-loaded (not dynamically managed)."""
|
||||
status = item.get("status")
|
||||
if status is None:
|
||||
return True # No status field: model is always loaded (e.g. single-model servers)
|
||||
if isinstance(status, dict):
|
||||
return status.get("value") == "loaded"
|
||||
if isinstance(status, str):
|
||||
return status == "loaded"
|
||||
return False
|
||||
|
||||
|
||||
def _is_llama_model_loaded_or_sleeping(item: dict) -> bool:
|
||||
"""Return True if status is 'loaded' or 'sleeping'.
|
||||
Newer llama-server versions report 'sleeping' in /v1/models when a model is idle;
|
||||
ps_details needs to include these so _fetch_llama_props can detect and unload them."""
|
||||
status = item.get("status")
|
||||
if status is None:
|
||||
return True
|
||||
if isinstance(status, dict):
|
||||
return status.get("value") in ("loaded", "sleeping")
|
||||
if isinstance(status, str):
|
||||
return status in ("loaded", "sleeping")
|
||||
return False
|
||||
113
backends/normalize.py
Normal file
113
backends/normalize.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
"""Endpoint URL, model-name, and endpoint-classification helpers.
|
||||
|
||||
The endpoint classifiers read live config via ``get_config()`` so that the
|
||||
startup-time rebind of ``config`` in router.py is picked up at call time.
|
||||
"""
|
||||
from config import get_config
|
||||
|
||||
|
||||
def _normalize_llama_model_name(name: str) -> str:
|
||||
"""Extract the model name from a huggingface-style identifier.
|
||||
e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF'
|
||||
"""
|
||||
if "/" in name:
|
||||
name = name.rsplit("/", 1)[1]
|
||||
if ":" in name:
|
||||
name = name.split(":")[0]
|
||||
return name
|
||||
|
||||
|
||||
def _extract_llama_quant(name: str) -> str:
|
||||
"""Extract the quantization level from a huggingface-style identifier.
|
||||
e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0'
|
||||
Returns empty string if no quant suffix is present.
|
||||
"""
|
||||
if ":" in name:
|
||||
return name.rsplit(":", 1)[1]
|
||||
return ""
|
||||
|
||||
|
||||
def ep2base(ep):
|
||||
if "/v1" in ep:
|
||||
base_url = ep
|
||||
else:
|
||||
base_url = ep + "/v1"
|
||||
return base_url
|
||||
|
||||
|
||||
def dedupe_on_keys(dicts, key_fields):
|
||||
"""
|
||||
Helper function to deduplicate endpoint details based on given dict keys.
|
||||
"""
|
||||
seen = set()
|
||||
out = []
|
||||
for d in dicts:
|
||||
# Build a tuple of the values for the chosen keys
|
||||
key = tuple(d.get(k) for k in key_fields)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
out.append(d)
|
||||
return out
|
||||
|
||||
|
||||
def is_ext_openai_endpoint(endpoint: str) -> bool:
|
||||
"""
|
||||
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server).
|
||||
|
||||
Returns True for:
|
||||
- External services like OpenAI.com, Groq, etc.
|
||||
|
||||
Returns False for:
|
||||
- Ollama endpoints (without /v1, or with /v1 but default port 11434)
|
||||
- llama-server endpoints (explicitly configured in llama_server_endpoints)
|
||||
"""
|
||||
cfg = get_config()
|
||||
# Check if it's a llama-server endpoint (has /v1 and is in the configured list)
|
||||
if endpoint in cfg.llama_server_endpoints:
|
||||
return False
|
||||
|
||||
if "/v1" not in endpoint:
|
||||
return False
|
||||
|
||||
base_endpoint = endpoint.replace('/v1', '')
|
||||
if base_endpoint in cfg.endpoints:
|
||||
return False # It's Ollama's /v1
|
||||
|
||||
# Check for default Ollama port
|
||||
if ':11434' in endpoint:
|
||||
return False # It's Ollama
|
||||
|
||||
return True # It's an external OpenAI endpoint
|
||||
|
||||
|
||||
def is_openai_compatible(endpoint: str) -> bool:
|
||||
"""
|
||||
Return True if the endpoint speaks the OpenAI API (not native Ollama).
|
||||
This includes external OpenAI endpoints AND llama-server endpoints.
|
||||
"""
|
||||
return "/v1" in endpoint or endpoint in get_config().llama_server_endpoints
|
||||
|
||||
|
||||
def get_tracking_model(endpoint: str, model: str) -> str:
|
||||
"""
|
||||
Normalize model name for tracking purposes so it matches the PS table key.
|
||||
|
||||
- For llama-server endpoints: strips HF prefix and quantization suffix
|
||||
- For Ollama endpoints: appends ":latest" if no version suffix is present
|
||||
- For external OpenAI endpoints: returns as-is (not shown in PS)
|
||||
|
||||
This ensures consistent model naming across all routes for usage tracking.
|
||||
"""
|
||||
# External OpenAI endpoints are not shown in PS, keep as-is
|
||||
if is_ext_openai_endpoint(endpoint):
|
||||
return model
|
||||
|
||||
# llama-server endpoints use normalized names in PS
|
||||
if endpoint in get_config().llama_server_endpoints:
|
||||
return _normalize_llama_model_name(model)
|
||||
|
||||
# Ollama endpoints: append ":latest" if no version suffix
|
||||
if ":" not in model:
|
||||
return model + ":latest"
|
||||
|
||||
return model
|
||||
449
backends/probe.py
Normal file
449
backends/probe.py
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
"""Backend probe / discovery primitives.
|
||||
|
||||
The ``fetch`` class wraps the three discovery paths the router uses:
|
||||
* ``available_models`` — what the endpoint advertises (Ollama ``/api/tags``
|
||||
or OpenAI-style ``/v1/models``)
|
||||
* ``loaded_models`` — what is currently resident (Ollama ``/api/ps`` or
|
||||
llama-server ``/v1/models`` filtered on ``status == "loaded"``)
|
||||
* ``endpoint_details`` — arbitrary detail fetch used by management routes
|
||||
|
||||
Each path goes through three layers of cache: success cache, error cache,
|
||||
and an in-flight request map. Stale-while-revalidate refreshes happen in
|
||||
background tasks tracked by the ``_bg_refresh_*`` maps in ``state``.
|
||||
|
||||
``_raw_probe`` and ``_endpoint_health`` are the lower-level dual probes
|
||||
used by ``/health`` and ``/api/config`` to distinguish a healthy daemon
|
||||
with a broken model-introspection path from a dead daemon.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional, Set
|
||||
|
||||
import aiohttp
|
||||
|
||||
from config import get_config
|
||||
from state import (
|
||||
_models_cache,
|
||||
_models_cache_lock,
|
||||
_loaded_models_cache,
|
||||
_loaded_models_cache_lock,
|
||||
_available_error_cache,
|
||||
_available_error_cache_lock,
|
||||
_loaded_error_cache,
|
||||
_loaded_error_cache_lock,
|
||||
_inflight_available_models,
|
||||
_inflight_loaded_models,
|
||||
_inflight_lock,
|
||||
_bg_refresh_available,
|
||||
_bg_refresh_loaded,
|
||||
_bg_refresh_lock,
|
||||
default_headers,
|
||||
)
|
||||
from backends.sessions import get_session
|
||||
from backends.health import (
|
||||
_is_fresh,
|
||||
_ensure_success,
|
||||
_format_connection_issue,
|
||||
_is_llama_model_loaded,
|
||||
)
|
||||
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible
|
||||
|
||||
|
||||
class fetch:
|
||||
async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||
"""
|
||||
Internal function that performs the actual HTTP request to fetch available models.
|
||||
This is called by available_models() after checking caches and in-flight requests.
|
||||
"""
|
||||
cfg = get_config()
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
ep_base = endpoint.rstrip("/")
|
||||
if endpoint in cfg.llama_server_endpoints and "/v1" not in endpoint:
|
||||
endpoint_url = f"{ep_base}/v1/models"
|
||||
key = "data"
|
||||
elif "/v1" in endpoint or endpoint in cfg.llama_server_endpoints:
|
||||
endpoint_url = f"{ep_base}/models"
|
||||
key = "data"
|
||||
else:
|
||||
endpoint_url = f"{ep_base}/api/tags"
|
||||
key = "models"
|
||||
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
try:
|
||||
async with client.get(endpoint_url, headers=headers) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
|
||||
items = data.get(key, [])
|
||||
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
|
||||
|
||||
async with _models_cache_lock:
|
||||
_models_cache[endpoint] = (models, time.time())
|
||||
return models
|
||||
except Exception as e:
|
||||
# Treat any error as if the endpoint offers no models
|
||||
message = _format_connection_issue(endpoint_url, e)
|
||||
print(f"[fetch.available_models] {message}")
|
||||
# Update error cache with lock protection
|
||||
async with _available_error_cache_lock:
|
||||
_available_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
|
||||
async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None:
|
||||
"""
|
||||
Background task to refresh available models cache without blocking the caller.
|
||||
Used for stale-while-revalidate pattern.
|
||||
Deduplicates: only one background refresh runs per endpoint at a time.
|
||||
"""
|
||||
async with _bg_refresh_lock:
|
||||
if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
|
||||
return # A refresh is already running for this endpoint
|
||||
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
|
||||
_bg_refresh_available[endpoint] = task
|
||||
|
||||
try:
|
||||
await task
|
||||
except Exception as e:
|
||||
# Silently fail - cache will remain stale but functional
|
||||
print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
|
||||
finally:
|
||||
async with _bg_refresh_lock:
|
||||
if _bg_refresh_available.get(endpoint) is task:
|
||||
_bg_refresh_available.pop(endpoint, None)
|
||||
|
||||
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||
"""
|
||||
Query <endpoint>/api/tags and return a set of all model names that the
|
||||
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
|
||||
every model that is installed on the Ollama instance, regardless of
|
||||
whether the model is currently loaded into memory.
|
||||
|
||||
Uses request coalescing to prevent cache stampede: if multiple requests
|
||||
arrive when cache is expired, only one actual HTTP request is made.
|
||||
|
||||
Uses stale-while-revalidate: when the cache is between 300-600s old,
|
||||
the stale data is returned immediately while a background refresh runs.
|
||||
This prevents model blackouts caused by transient timeouts.
|
||||
|
||||
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
|
||||
set is returned.
|
||||
"""
|
||||
# Check models cache with lock protection
|
||||
async with _models_cache_lock:
|
||||
if endpoint in _models_cache:
|
||||
models, cached_at = _models_cache[endpoint]
|
||||
|
||||
# FRESH: <= 300s old - return immediately
|
||||
if _is_fresh(cached_at, 300):
|
||||
return models
|
||||
|
||||
# STALE: 300-600s old - return stale data and refresh in background
|
||||
if _is_fresh(cached_at, 600):
|
||||
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
|
||||
return models # Return stale data immediately
|
||||
|
||||
# EXPIRED: > 600s old - too stale, must refresh synchronously
|
||||
del _models_cache[endpoint]
|
||||
|
||||
# Check error cache with lock protection
|
||||
async with _available_error_cache_lock:
|
||||
if endpoint in _available_error_cache:
|
||||
err_age = time.time() - _available_error_cache[endpoint]
|
||||
if err_age < 30:
|
||||
# Very fresh error (<30s) – endpoint likely still down, bail fast
|
||||
return set()
|
||||
elif err_age < 300:
|
||||
# Stale error (30-300s) – endpoint may have recovered, probe in background
|
||||
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
|
||||
return set()
|
||||
# Error expired (>300s) – remove and fall through to fresh fetch
|
||||
del _available_error_cache[endpoint]
|
||||
|
||||
# Request coalescing: check if another request is already fetching this endpoint
|
||||
async with _inflight_lock:
|
||||
if endpoint in _inflight_available_models:
|
||||
# Another request is already fetching - wait for it
|
||||
task = _inflight_available_models[endpoint]
|
||||
else:
|
||||
# Create new fetch task
|
||||
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
|
||||
_inflight_available_models[endpoint] = task
|
||||
|
||||
try:
|
||||
# Wait for the fetch to complete (either ours or another request's)
|
||||
result = await task
|
||||
return result
|
||||
finally:
|
||||
# Clean up in-flight tracking (only if we created it)
|
||||
async with _inflight_lock:
|
||||
if _inflight_available_models.get(endpoint) == task:
|
||||
_inflight_available_models.pop(endpoint, None)
|
||||
|
||||
|
||||
async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
|
||||
"""
|
||||
Internal function that performs the actual HTTP request to fetch loaded models.
|
||||
This is called by loaded_models() after checking caches and in-flight requests.
|
||||
|
||||
For Ollama endpoints: queries /api/ps and returns model names
|
||||
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
|
||||
"""
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
|
||||
# Check if this is a llama-server endpoint
|
||||
if endpoint in get_config().llama_server_endpoints:
|
||||
# Query /v1/models for llama-server
|
||||
try:
|
||||
async with client.get(f"{endpoint}/models") as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
|
||||
# Filter for loaded models only
|
||||
items = data.get("data", [])
|
||||
models = {
|
||||
item.get("id")
|
||||
for item in items
|
||||
if item.get("id") and _is_llama_model_loaded(item)
|
||||
}
|
||||
|
||||
# Update cache with lock protection
|
||||
async with _loaded_models_cache_lock:
|
||||
_loaded_models_cache[endpoint] = (models, time.time())
|
||||
# Probe succeeded — clear any stale error so the endpoint
|
||||
# becomes routable again.
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache.pop(endpoint, None)
|
||||
return models
|
||||
except Exception as e:
|
||||
# If anything goes wrong we simply assume the endpoint has no models
|
||||
message = _format_connection_issue(f"{endpoint}/models", e)
|
||||
print(f"[fetch.loaded_models] {message}")
|
||||
# Record the failure so `choose_endpoint` can avoid routing
|
||||
# to an unhealthy backend and repeated probes short-circuit.
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
else:
|
||||
# Original Ollama /api/ps logic
|
||||
try:
|
||||
async with client.get(f"{endpoint}/api/ps") as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
# The response format is:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
|
||||
# Update cache with lock protection
|
||||
async with _loaded_models_cache_lock:
|
||||
_loaded_models_cache[endpoint] = (models, time.time())
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache.pop(endpoint, None)
|
||||
return models
|
||||
except Exception as e:
|
||||
# If anything goes wrong we simply assume the endpoint has no models
|
||||
message = _format_connection_issue(f"{endpoint}/api/ps", e)
|
||||
print(f"[fetch.loaded_models] {message}")
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
|
||||
async def _refresh_loaded_models(endpoint: str) -> None:
|
||||
"""
|
||||
Background task to refresh loaded models cache without blocking the caller.
|
||||
Used for stale-while-revalidate pattern.
|
||||
Deduplicates: only one background refresh runs per endpoint at a time.
|
||||
"""
|
||||
async with _bg_refresh_lock:
|
||||
if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
|
||||
return # A refresh is already running for this endpoint
|
||||
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
|
||||
_bg_refresh_loaded[endpoint] = task
|
||||
|
||||
try:
|
||||
await task
|
||||
except Exception as e:
|
||||
# Silently fail - cache will remain stale but functional
|
||||
print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
|
||||
finally:
|
||||
async with _bg_refresh_lock:
|
||||
if _bg_refresh_loaded.get(endpoint) is task:
|
||||
_bg_refresh_loaded.pop(endpoint, None)
|
||||
|
||||
async def loaded_models(endpoint: str) -> Set[str]:
|
||||
"""
|
||||
Query <endpoint>/api/ps and return a set of model names that are currently
|
||||
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
|
||||
set is returned.
|
||||
|
||||
Uses request coalescing to prevent cache stampede and stale-while-revalidate
|
||||
to serve requests immediately even when cache is stale (refreshing in background).
|
||||
"""
|
||||
if is_ext_openai_endpoint(endpoint):
|
||||
return set()
|
||||
|
||||
# Check loaded models cache with lock protection
|
||||
async with _loaded_models_cache_lock:
|
||||
if endpoint in _loaded_models_cache:
|
||||
models, cached_at = _loaded_models_cache[endpoint]
|
||||
|
||||
# FRESH: < 10s old - return immediately
|
||||
if _is_fresh(cached_at, 10):
|
||||
return models
|
||||
|
||||
# STALE: 10-60s old - return stale data and refresh in background
|
||||
if _is_fresh(cached_at, 60):
|
||||
# Kick off background refresh (fire-and-forget)
|
||||
asyncio.create_task(fetch._refresh_loaded_models(endpoint))
|
||||
return models # Return stale data immediately
|
||||
|
||||
# EXPIRED: > 60s old - too stale, must refresh synchronously
|
||||
del _loaded_models_cache[endpoint]
|
||||
|
||||
# Check error cache with lock protection
|
||||
async with _loaded_error_cache_lock:
|
||||
if endpoint in _loaded_error_cache:
|
||||
if _is_fresh(_loaded_error_cache[endpoint], 300):
|
||||
return set()
|
||||
# Error expired - remove it
|
||||
del _loaded_error_cache[endpoint]
|
||||
|
||||
# Request coalescing: check if another request is already fetching this endpoint
|
||||
async with _inflight_lock:
|
||||
if endpoint in _inflight_loaded_models:
|
||||
# Another request is already fetching - wait for it
|
||||
task = _inflight_loaded_models[endpoint]
|
||||
else:
|
||||
# Create new fetch task
|
||||
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
|
||||
_inflight_loaded_models[endpoint] = task
|
||||
|
||||
try:
|
||||
# Wait for the fetch to complete (either ours or another request's)
|
||||
result = await task
|
||||
return result
|
||||
finally:
|
||||
# Clean up in-flight tracking (only if we created it)
|
||||
async with _inflight_lock:
|
||||
if _inflight_loaded_models.get(endpoint) == task:
|
||||
_inflight_loaded_models.pop(endpoint, None)
|
||||
|
||||
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None, skip_error_cache: bool = False, timeout: float = None) -> List[dict]:
|
||||
"""
|
||||
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
|
||||
for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
|
||||
|
||||
When ``skip_error_cache`` is False (the default), the call is short-circuited
|
||||
if the endpoint recently failed (recorded in ``_available_error_cache``).
|
||||
Pass ``skip_error_cache=True`` from health-check routes that must always probe.
|
||||
|
||||
``timeout`` overrides the session default for this single request (seconds, total).
|
||||
"""
|
||||
# Fast-fail if the endpoint is known to be down (unless caller opts out)
|
||||
if not skip_error_cache:
|
||||
async with _available_error_cache_lock:
|
||||
if endpoint in _available_error_cache:
|
||||
if _is_fresh(_available_error_cache[endpoint], 300):
|
||||
return []
|
||||
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}"
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
req_kwargs = {}
|
||||
if timeout is not None:
|
||||
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
|
||||
try:
|
||||
async with client.get(request_url, headers=headers, **req_kwargs) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
detail = data.get(detail, [])
|
||||
return detail
|
||||
except Exception as e:
|
||||
# If anything goes wrong we cannot reply details
|
||||
message = _format_connection_issue(request_url, e)
|
||||
print(f"[fetch.endpoint_details] {message}")
|
||||
if not skip_error_cache:
|
||||
async with _available_error_cache_lock:
|
||||
_available_error_cache[endpoint] = time.time()
|
||||
return []
|
||||
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# Endpoint health probes (shared by /api/config and /health)
|
||||
# -------------------------------------------------------------
|
||||
async def _raw_probe(
|
||||
ep: str,
|
||||
route: str,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> tuple[bool, object]:
|
||||
"""Direct HTTP probe that distinguishes success from failure
|
||||
(unlike `fetch.endpoint_details`, which returns [] on either).
|
||||
Returns `(ok, payload_or_error_message)`.
|
||||
"""
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
url = f"{ep.rstrip('/')}/{route.lstrip('/')}"
|
||||
req_kwargs = {}
|
||||
if timeout is not None:
|
||||
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
|
||||
try:
|
||||
client: aiohttp.ClientSession = get_session(ep)
|
||||
async with client.get(url, headers=headers, **req_kwargs) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
return True, data
|
||||
except Exception as exc:
|
||||
return False, _format_connection_issue(url, exc)
|
||||
|
||||
|
||||
async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
|
||||
"""Probe an endpoint and return `{status, version?, detail?}`.
|
||||
|
||||
Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so
|
||||
that a daemon which is reachable but has a broken model-introspection
|
||||
path (issue #83) is reported as `error` rather than `ok`.
|
||||
OpenAI-compatible endpoints use a single `/models` probe.
|
||||
"""
|
||||
if is_openai_compatible(ep):
|
||||
ok, payload = await _raw_probe(
|
||||
ep, "/models", get_config().api_keys.get(ep), timeout=timeout,
|
||||
)
|
||||
if ok:
|
||||
return {"status": "ok", "version": "latest"}
|
||||
return {"status": "error", "detail": str(payload)}
|
||||
|
||||
(version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather(
|
||||
_raw_probe(ep, "/api/version", timeout=timeout),
|
||||
_raw_probe(ep, "/api/ps", timeout=timeout),
|
||||
)
|
||||
|
||||
version_value = (
|
||||
version_payload.get("version")
|
||||
if version_ok and isinstance(version_payload, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
if version_ok and ps_ok:
|
||||
return {"status": "ok", "version": version_value}
|
||||
if not version_ok and not ps_ok:
|
||||
return {"status": "error", "detail": str(version_payload)}
|
||||
# Partial failure — daemon reachable but one probe failed. Report
|
||||
# as "error" so callers can surface the issue; include `version` so
|
||||
# the operator knows the daemon itself is alive.
|
||||
if not ps_ok:
|
||||
return {
|
||||
"status": "error",
|
||||
"version": version_value,
|
||||
"detail": f"/api/ps: {ps_payload}",
|
||||
}
|
||||
return {
|
||||
"status": "error",
|
||||
"detail": f"/api/version: {version_payload}",
|
||||
}
|
||||
72
backends/sessions.py
Normal file
72
backends/sessions.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""aiohttp / OpenAI client factories aware of Unix-socket endpoints.
|
||||
|
||||
Unix socket endpoints follow the ``.sock`` hostname convention (e.g.
|
||||
``http://192.168.0.52.sock/v1``) and resolve to ``/run/user/<uid>/<host>``.
|
||||
Their sessions/clients live in ``state.app_state`` so that startup can
|
||||
populate them once and routes can reuse them.
|
||||
"""
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
|
||||
from state import app_state
|
||||
from backends.normalize import ep2base
|
||||
|
||||
|
||||
def _is_unix_socket_endpoint(endpoint: str) -> bool:
|
||||
"""Return True if endpoint uses Unix socket (.sock hostname convention).
|
||||
|
||||
Detects URLs like http://192.168.0.52.sock/v1 where the host ends with
|
||||
.sock, indicating the connection should use a Unix domain socket at
|
||||
/tmp/<host> instead of TCP.
|
||||
"""
|
||||
try:
|
||||
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
|
||||
return host.endswith(".sock")
|
||||
except IndexError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_socket_path(endpoint: str) -> str:
|
||||
"""Derive Unix socket file path from a .sock endpoint URL.
|
||||
|
||||
http://192.168.0.52.sock/v1 -> /run/user/<uid>/192.168.0.52.sock
|
||||
"""
|
||||
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
|
||||
return f"/run/user/{os.getuid()}/{host}"
|
||||
|
||||
|
||||
def get_session(endpoint: str) -> aiohttp.ClientSession:
|
||||
"""Return the appropriate aiohttp session for the given endpoint.
|
||||
|
||||
Unix socket endpoints (.sock) get their own UnixConnector session.
|
||||
All other endpoints share the main TCP session.
|
||||
"""
|
||||
if _is_unix_socket_endpoint(endpoint):
|
||||
sess = app_state["socket_sessions"].get(endpoint)
|
||||
if sess is not None:
|
||||
return sess
|
||||
return app_state["session"]
|
||||
|
||||
|
||||
def _make_openai_client(
|
||||
endpoint: str,
|
||||
default_headers: dict | None = None,
|
||||
api_key: str = "no-key",
|
||||
) -> openai.AsyncOpenAI:
|
||||
"""Return an AsyncOpenAI client configured for the given endpoint.
|
||||
|
||||
For Unix socket endpoints, injects a pre-created httpx UDS transport
|
||||
so the OpenAI SDK connects via the socket instead of TCP.
|
||||
"""
|
||||
base_url = ep2base(endpoint)
|
||||
kwargs: dict = {"api_key": api_key}
|
||||
if default_headers is not None:
|
||||
kwargs["default_headers"] = default_headers
|
||||
if _is_unix_socket_endpoint(endpoint):
|
||||
http_client = app_state["httpx_clients"].get(endpoint)
|
||||
if http_client is not None:
|
||||
kwargs["http_client"] = http_client
|
||||
base_url = "http://localhost/v1"
|
||||
return openai.AsyncOpenAI(base_url=base_url, **kwargs)
|
||||
13
config.py
13
config.py
|
|
@ -124,3 +124,16 @@ def _config_path_from_env() -> Path:
|
|||
if candidate:
|
||||
return Path(candidate).expanduser()
|
||||
return Path("config.yaml")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared config accessor
|
||||
# ------------------------------------------------------------------
|
||||
# Submodules read config at call time via get_config() instead of importing
|
||||
# a bound name. The single source of truth is ``router.config`` — the lazy
|
||||
# import below resolves it after router.py has finished loading, and lets
|
||||
# tests that ``patch.object(router, "config", cfg)`` flow through.
|
||||
def get_config() -> "Config":
|
||||
"""Return the currently active Config from router.py."""
|
||||
import router # lazy to avoid module-load circular import
|
||||
return router.config
|
||||
|
|
|
|||
698
router.py
698
router.py
|
|
@ -75,7 +75,8 @@ from db import TokenDatabase
|
|||
from cache import init_llm_cache, get_llm_cache, openai_nonstream_to_sse
|
||||
|
||||
|
||||
# Create the global config object – it will be overwritten on startup
|
||||
# Create the global config object – it will be overwritten on startup.
|
||||
# Submodules read it lazily via config.get_config().
|
||||
config = Config.from_yaml(_config_path_from_env())
|
||||
|
||||
# -------------------------------------------------------------
|
||||
|
|
@ -90,11 +91,7 @@ app.add_middleware(
|
|||
allow_methods=["GET", "POST", "DELETE"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
default_headers={
|
||||
"HTTP-Referer": "https://nomyo.ai",
|
||||
"Referer": "https://nomyo.ai",
|
||||
"X-Title": "NOMYO Router",
|
||||
}
|
||||
from state import default_headers
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# Router-level authentication (optional)
|
||||
|
|
@ -205,254 +202,36 @@ from fingerprint import _conversation_fingerprint
|
|||
db: "TokenDatabase" = None
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 4. Helperfunctions
|
||||
# 4. Helperfunctions
|
||||
# -------------------------------------------------------------
|
||||
def _is_fresh(cached_at: float, ttl: int) -> bool:
|
||||
return (time.time() - cached_at) < ttl
|
||||
|
||||
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
|
||||
if resp.status >= 400:
|
||||
text = await resp.text()
|
||||
raise HTTPException(status_code=resp.status, detail=_mask_secrets(text))
|
||||
|
||||
def _format_connection_issue(url: str, error: Exception) -> str:
|
||||
"""
|
||||
Provide a human-friendly error string for connection failures so operators
|
||||
know which endpoint and address failed from inside the container.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
host_hint = parsed.hostname or ""
|
||||
port_hint = parsed.port or ""
|
||||
|
||||
if isinstance(error, aiohttp.ClientConnectorError):
|
||||
resolved_host = getattr(error, "host", host_hint) or host_hint or "?"
|
||||
resolved_port = getattr(error, "port", port_hint) or port_hint or "?"
|
||||
parts = [
|
||||
f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).",
|
||||
"Ensure the endpoint address is reachable from within the container.",
|
||||
]
|
||||
if resolved_host in {"localhost", "127.0.0.1"}:
|
||||
parts.append(
|
||||
"Inside Docker, 'localhost' refers to the container itself; use "
|
||||
"'host.docker.internal' or a Docker network alias if the service "
|
||||
"runs on the host machine."
|
||||
)
|
||||
os_error = getattr(error, "os_error", None)
|
||||
if isinstance(os_error, OSError):
|
||||
errno = getattr(os_error, "errno", None)
|
||||
strerror = os_error.strerror or str(os_error)
|
||||
if errno is not None or strerror:
|
||||
parts.append(f"OS error [{errno}]: {strerror}.")
|
||||
elif os_error:
|
||||
parts.append(f"OS error: {os_error}.")
|
||||
parts.append(f"Original error: {error}.")
|
||||
return " ".join(parts)
|
||||
|
||||
if isinstance(error, asyncio.TimeoutError):
|
||||
return (
|
||||
f"Timed out waiting for {url}. "
|
||||
"The remote endpoint may be offline or slow to respond."
|
||||
)
|
||||
|
||||
return f"Error while contacting {url}: {error}"
|
||||
|
||||
def _normalize_llama_model_name(name: str) -> str:
|
||||
"""Extract the model name from a huggingface-style identifier.
|
||||
e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF'
|
||||
"""
|
||||
if "/" in name:
|
||||
name = name.rsplit("/", 1)[1]
|
||||
if ":" in name:
|
||||
name = name.split(":")[0]
|
||||
return name
|
||||
|
||||
def _extract_llama_quant(name: str) -> str:
|
||||
"""Extract the quantization level from a huggingface-style identifier.
|
||||
e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0'
|
||||
Returns empty string if no quant suffix is present.
|
||||
"""
|
||||
if ":" in name:
|
||||
return name.rsplit(":", 1)[1]
|
||||
return ""
|
||||
from backends.normalize import (
|
||||
_normalize_llama_model_name,
|
||||
_extract_llama_quant,
|
||||
ep2base,
|
||||
dedupe_on_keys,
|
||||
)
|
||||
from backends.sessions import (
|
||||
_is_unix_socket_endpoint,
|
||||
_get_socket_path,
|
||||
get_session,
|
||||
_make_openai_client,
|
||||
)
|
||||
from backends.health import (
|
||||
_is_fresh,
|
||||
_ensure_success,
|
||||
_format_connection_issue,
|
||||
_is_backend_connection_error,
|
||||
_mark_backend_unhealthy,
|
||||
_is_llama_model_loaded,
|
||||
_is_llama_model_loaded_or_sleeping,
|
||||
)
|
||||
|
||||
|
||||
def _is_unix_socket_endpoint(endpoint: str) -> bool:
|
||||
"""Return True if endpoint uses Unix socket (.sock hostname convention).
|
||||
|
||||
Detects URLs like http://192.168.0.52.sock/v1 where the host ends with
|
||||
.sock, indicating the connection should use a Unix domain socket at
|
||||
/tmp/<host> instead of TCP.
|
||||
"""
|
||||
try:
|
||||
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
|
||||
return host.endswith(".sock")
|
||||
except IndexError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_socket_path(endpoint: str) -> str:
|
||||
"""Derive Unix socket file path from a .sock endpoint URL.
|
||||
|
||||
http://192.168.0.52.sock/v1 -> /run/user/<uid>/192.168.0.52.sock
|
||||
"""
|
||||
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
|
||||
return f"/run/user/{os.getuid()}/{host}"
|
||||
|
||||
|
||||
def get_session(endpoint: str) -> aiohttp.ClientSession:
|
||||
"""Return the appropriate aiohttp session for the given endpoint.
|
||||
|
||||
Unix socket endpoints (.sock) get their own UnixConnector session.
|
||||
All other endpoints share the main TCP session.
|
||||
"""
|
||||
if _is_unix_socket_endpoint(endpoint):
|
||||
sess = app_state["socket_sessions"].get(endpoint)
|
||||
if sess is not None:
|
||||
return sess
|
||||
return app_state["session"]
|
||||
|
||||
|
||||
def _make_openai_client(
|
||||
endpoint: str,
|
||||
default_headers: dict | None = None,
|
||||
api_key: str = "no-key",
|
||||
) -> openai.AsyncOpenAI:
|
||||
"""Return an AsyncOpenAI client configured for the given endpoint.
|
||||
|
||||
For Unix socket endpoints, injects a pre-created httpx UDS transport
|
||||
so the OpenAI SDK connects via the socket instead of TCP.
|
||||
"""
|
||||
base_url = ep2base(endpoint)
|
||||
kwargs: dict = {"api_key": api_key}
|
||||
if default_headers is not None:
|
||||
kwargs["default_headers"] = default_headers
|
||||
if _is_unix_socket_endpoint(endpoint):
|
||||
http_client = app_state["httpx_clients"].get(endpoint)
|
||||
if http_client is not None:
|
||||
kwargs["http_client"] = http_client
|
||||
base_url = "http://localhost/v1"
|
||||
return openai.AsyncOpenAI(base_url=base_url, **kwargs)
|
||||
|
||||
|
||||
def _is_backend_connection_error(exc: Exception) -> bool:
|
||||
"""True for upstream connection-class failures observed via the OpenAI client.
|
||||
|
||||
Targets the case where a llama-server in router mode keeps answering
|
||||
/v1/models but its delegated worker for a specific model is dead, so
|
||||
chat/completions calls return 5xx with 'proxy error: Could not establish
|
||||
connection' (or the SDK raises APIConnectionError outright).
|
||||
|
||||
Excludes BadRequestError with exceed_context_size_error by design — those
|
||||
must stay on the reactive-trim path.
|
||||
"""
|
||||
if isinstance(exc, openai.APIConnectionError):
|
||||
return True
|
||||
if isinstance(exc, openai.InternalServerError):
|
||||
msg = str(exc).lower()
|
||||
return (
|
||||
"proxy error" in msg
|
||||
or "could not establish connection" in msg
|
||||
or "connection refused" in msg
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def _mark_backend_unhealthy(endpoint: str, model: str, reason: str = "") -> None:
|
||||
"""Record (endpoint, model) as broken so choose_endpoint avoids it.
|
||||
|
||||
Cleared only by TTL — the dead-worker failure mode is invisible to the
|
||||
/v1/models / /api/ps probes that clear _loaded_error_cache, so we cannot
|
||||
rely on a successful probe as a recovery signal.
|
||||
"""
|
||||
async with _completion_error_cache_lock:
|
||||
_completion_error_cache[(endpoint, model)] = time.time()
|
||||
print(f"[health] marked unhealthy ep={endpoint} model={model} reason={reason[:120]}", flush=True)
|
||||
|
||||
|
||||
def _is_llama_model_loaded(item: dict) -> bool:
|
||||
"""Return True if a llama-server /v1/models item has status 'loaded'.
|
||||
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
|
||||
If no status field is present, the model is always-loaded (not dynamically managed)."""
|
||||
status = item.get("status")
|
||||
if status is None:
|
||||
return True # No status field: model is always loaded (e.g. single-model servers)
|
||||
if isinstance(status, dict):
|
||||
return status.get("value") == "loaded"
|
||||
if isinstance(status, str):
|
||||
return status == "loaded"
|
||||
return False
|
||||
|
||||
def _is_llama_model_loaded_or_sleeping(item: dict) -> bool:
|
||||
"""Return True if status is 'loaded' or 'sleeping'.
|
||||
Newer llama-server versions report 'sleeping' in /v1/models when a model is idle;
|
||||
ps_details needs to include these so _fetch_llama_props can detect and unload them."""
|
||||
status = item.get("status")
|
||||
if status is None:
|
||||
return True
|
||||
if isinstance(status, dict):
|
||||
return status.get("value") in ("loaded", "sleeping")
|
||||
if isinstance(status, str):
|
||||
return status in ("loaded", "sleeping")
|
||||
return False
|
||||
|
||||
def is_ext_openai_endpoint(endpoint: str) -> bool:
|
||||
"""
|
||||
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server).
|
||||
|
||||
Returns True for:
|
||||
- External services like OpenAI.com, Groq, etc.
|
||||
|
||||
Returns False for:
|
||||
- Ollama endpoints (without /v1, or with /v1 but default port 11434)
|
||||
- llama-server endpoints (explicitly configured in llama_server_endpoints)
|
||||
"""
|
||||
# Check if it's a llama-server endpoint (has /v1 and is in the configured list)
|
||||
if endpoint in config.llama_server_endpoints:
|
||||
return False
|
||||
|
||||
if "/v1" not in endpoint:
|
||||
return False
|
||||
|
||||
base_endpoint = endpoint.replace('/v1', '')
|
||||
if base_endpoint in config.endpoints:
|
||||
return False # It's Ollama's /v1
|
||||
|
||||
# Check for default Ollama port
|
||||
if ':11434' in endpoint:
|
||||
return False # It's Ollama
|
||||
|
||||
return True # It's an external OpenAI endpoint
|
||||
|
||||
def is_openai_compatible(endpoint: str) -> bool:
|
||||
"""
|
||||
Return True if the endpoint speaks the OpenAI API (not native Ollama).
|
||||
This includes external OpenAI endpoints AND llama-server endpoints.
|
||||
"""
|
||||
return "/v1" in endpoint or endpoint in config.llama_server_endpoints
|
||||
|
||||
def get_tracking_model(endpoint: str, model: str) -> str:
|
||||
"""
|
||||
Normalize model name for tracking purposes so it matches the PS table key.
|
||||
|
||||
- For llama-server endpoints: strips HF prefix and quantization suffix
|
||||
- For Ollama endpoints: appends ":latest" if no version suffix is present
|
||||
- For external OpenAI endpoints: returns as-is (not shown in PS)
|
||||
|
||||
This ensures consistent model naming across all routes for usage tracking.
|
||||
"""
|
||||
# External OpenAI endpoints are not shown in PS, keep as-is
|
||||
if is_ext_openai_endpoint(endpoint):
|
||||
return model
|
||||
|
||||
# llama-server endpoints use normalized names in PS
|
||||
if endpoint in config.llama_server_endpoints:
|
||||
return _normalize_llama_model_name(model)
|
||||
|
||||
# Ollama endpoints: append ":latest" if no version suffix
|
||||
if ":" not in model:
|
||||
return model + ":latest"
|
||||
|
||||
return model
|
||||
from backends.normalize import (
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
get_tracking_model,
|
||||
)
|
||||
|
||||
async def token_worker() -> None:
|
||||
try:
|
||||
|
|
@ -601,348 +380,8 @@ async def flush_remaining_buffers() -> None:
|
|||
# Do not raise during shutdown – log and continue teardown
|
||||
print(f"[shutdown] Error flushing remaining buffers: {e}")
|
||||
|
||||
class fetch:
|
||||
async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||
"""
|
||||
Internal function that performs the actual HTTP request to fetch available models.
|
||||
This is called by available_models() after checking caches and in-flight requests.
|
||||
"""
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
from backends.probe import fetch
|
||||
|
||||
ep_base = endpoint.rstrip("/")
|
||||
if endpoint in config.llama_server_endpoints and "/v1" not in endpoint:
|
||||
endpoint_url = f"{ep_base}/v1/models"
|
||||
key = "data"
|
||||
elif "/v1" in endpoint or endpoint in config.llama_server_endpoints:
|
||||
endpoint_url = f"{ep_base}/models"
|
||||
key = "data"
|
||||
else:
|
||||
endpoint_url = f"{ep_base}/api/tags"
|
||||
key = "models"
|
||||
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
try:
|
||||
async with client.get(endpoint_url, headers=headers) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
|
||||
items = data.get(key, [])
|
||||
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
|
||||
|
||||
async with _models_cache_lock:
|
||||
_models_cache[endpoint] = (models, time.time())
|
||||
return models
|
||||
except Exception as e:
|
||||
# Treat any error as if the endpoint offers no models
|
||||
message = _format_connection_issue(endpoint_url, e)
|
||||
print(f"[fetch.available_models] {message}")
|
||||
# Update error cache with lock protection
|
||||
async with _available_error_cache_lock:
|
||||
_available_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
|
||||
async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None:
|
||||
"""
|
||||
Background task to refresh available models cache without blocking the caller.
|
||||
Used for stale-while-revalidate pattern.
|
||||
Deduplicates: only one background refresh runs per endpoint at a time.
|
||||
"""
|
||||
async with _bg_refresh_lock:
|
||||
if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
|
||||
return # A refresh is already running for this endpoint
|
||||
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
|
||||
_bg_refresh_available[endpoint] = task
|
||||
|
||||
try:
|
||||
await task
|
||||
except Exception as e:
|
||||
# Silently fail - cache will remain stale but functional
|
||||
print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
|
||||
finally:
|
||||
async with _bg_refresh_lock:
|
||||
if _bg_refresh_available.get(endpoint) is task:
|
||||
_bg_refresh_available.pop(endpoint, None)
|
||||
|
||||
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
|
||||
"""
|
||||
Query <endpoint>/api/tags and return a set of all model names that the
|
||||
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
|
||||
every model that is installed on the Ollama instance, regardless of
|
||||
whether the model is currently loaded into memory.
|
||||
|
||||
Uses request coalescing to prevent cache stampede: if multiple requests
|
||||
arrive when cache is expired, only one actual HTTP request is made.
|
||||
|
||||
Uses stale-while-revalidate: when the cache is between 300-600s old,
|
||||
the stale data is returned immediately while a background refresh runs.
|
||||
This prevents model blackouts caused by transient timeouts.
|
||||
|
||||
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
|
||||
set is returned.
|
||||
"""
|
||||
# Check models cache with lock protection
|
||||
async with _models_cache_lock:
|
||||
if endpoint in _models_cache:
|
||||
models, cached_at = _models_cache[endpoint]
|
||||
|
||||
# FRESH: <= 300s old - return immediately
|
||||
if _is_fresh(cached_at, 300):
|
||||
return models
|
||||
|
||||
# STALE: 300-600s old - return stale data and refresh in background
|
||||
if _is_fresh(cached_at, 600):
|
||||
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
|
||||
return models # Return stale data immediately
|
||||
|
||||
# EXPIRED: > 600s old - too stale, must refresh synchronously
|
||||
del _models_cache[endpoint]
|
||||
|
||||
# Check error cache with lock protection
|
||||
async with _available_error_cache_lock:
|
||||
if endpoint in _available_error_cache:
|
||||
err_age = time.time() - _available_error_cache[endpoint]
|
||||
if err_age < 30:
|
||||
# Very fresh error (<30s) – endpoint likely still down, bail fast
|
||||
return set()
|
||||
elif err_age < 300:
|
||||
# Stale error (30-300s) – endpoint may have recovered, probe in background
|
||||
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
|
||||
return set()
|
||||
# Error expired (>300s) – remove and fall through to fresh fetch
|
||||
del _available_error_cache[endpoint]
|
||||
|
||||
# Request coalescing: check if another request is already fetching this endpoint
|
||||
async with _inflight_lock:
|
||||
if endpoint in _inflight_available_models:
|
||||
# Another request is already fetching - wait for it
|
||||
task = _inflight_available_models[endpoint]
|
||||
else:
|
||||
# Create new fetch task
|
||||
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
|
||||
_inflight_available_models[endpoint] = task
|
||||
|
||||
try:
|
||||
# Wait for the fetch to complete (either ours or another request's)
|
||||
result = await task
|
||||
return result
|
||||
finally:
|
||||
# Clean up in-flight tracking (only if we created it)
|
||||
async with _inflight_lock:
|
||||
if _inflight_available_models.get(endpoint) == task:
|
||||
_inflight_available_models.pop(endpoint, None)
|
||||
|
||||
|
||||
async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
|
||||
"""
|
||||
Internal function that performs the actual HTTP request to fetch loaded models.
|
||||
This is called by loaded_models() after checking caches and in-flight requests.
|
||||
|
||||
For Ollama endpoints: queries /api/ps and returns model names
|
||||
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
|
||||
"""
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
|
||||
# Check if this is a llama-server endpoint
|
||||
if endpoint in config.llama_server_endpoints:
|
||||
# Query /v1/models for llama-server
|
||||
try:
|
||||
async with client.get(f"{endpoint}/models") as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
|
||||
# Filter for loaded models only
|
||||
items = data.get("data", [])
|
||||
models = {
|
||||
item.get("id")
|
||||
for item in items
|
||||
if item.get("id") and _is_llama_model_loaded(item)
|
||||
}
|
||||
|
||||
# Update cache with lock protection
|
||||
async with _loaded_models_cache_lock:
|
||||
_loaded_models_cache[endpoint] = (models, time.time())
|
||||
# Probe succeeded — clear any stale error so the endpoint
|
||||
# becomes routable again.
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache.pop(endpoint, None)
|
||||
return models
|
||||
except Exception as e:
|
||||
# If anything goes wrong we simply assume the endpoint has no models
|
||||
message = _format_connection_issue(f"{endpoint}/models", e)
|
||||
print(f"[fetch.loaded_models] {message}")
|
||||
# Record the failure so `choose_endpoint` can avoid routing
|
||||
# to an unhealthy backend and repeated probes short-circuit.
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
else:
|
||||
# Original Ollama /api/ps logic
|
||||
try:
|
||||
async with client.get(f"{endpoint}/api/ps") as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
# The response format is:
|
||||
# {"models": [{"name": "model1"}, {"name": "model2"}]}
|
||||
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
|
||||
|
||||
# Update cache with lock protection
|
||||
async with _loaded_models_cache_lock:
|
||||
_loaded_models_cache[endpoint] = (models, time.time())
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache.pop(endpoint, None)
|
||||
return models
|
||||
except Exception as e:
|
||||
# If anything goes wrong we simply assume the endpoint has no models
|
||||
message = _format_connection_issue(f"{endpoint}/api/ps", e)
|
||||
print(f"[fetch.loaded_models] {message}")
|
||||
async with _loaded_error_cache_lock:
|
||||
_loaded_error_cache[endpoint] = time.time()
|
||||
return set()
|
||||
|
||||
async def _refresh_loaded_models(endpoint: str) -> None:
|
||||
"""
|
||||
Background task to refresh loaded models cache without blocking the caller.
|
||||
Used for stale-while-revalidate pattern.
|
||||
Deduplicates: only one background refresh runs per endpoint at a time.
|
||||
"""
|
||||
async with _bg_refresh_lock:
|
||||
if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
|
||||
return # A refresh is already running for this endpoint
|
||||
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
|
||||
_bg_refresh_loaded[endpoint] = task
|
||||
|
||||
try:
|
||||
await task
|
||||
except Exception as e:
|
||||
# Silently fail - cache will remain stale but functional
|
||||
print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
|
||||
finally:
|
||||
async with _bg_refresh_lock:
|
||||
if _bg_refresh_loaded.get(endpoint) is task:
|
||||
_bg_refresh_loaded.pop(endpoint, None)
|
||||
|
||||
async def loaded_models(endpoint: str) -> Set[str]:
|
||||
"""
|
||||
Query <endpoint>/api/ps and return a set of model names that are currently
|
||||
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
|
||||
set is returned.
|
||||
|
||||
Uses request coalescing to prevent cache stampede and stale-while-revalidate
|
||||
to serve requests immediately even when cache is stale (refreshing in background).
|
||||
"""
|
||||
if is_ext_openai_endpoint(endpoint):
|
||||
return set()
|
||||
|
||||
# Check loaded models cache with lock protection
|
||||
async with _loaded_models_cache_lock:
|
||||
if endpoint in _loaded_models_cache:
|
||||
models, cached_at = _loaded_models_cache[endpoint]
|
||||
|
||||
# FRESH: < 10s old - return immediately
|
||||
if _is_fresh(cached_at, 10):
|
||||
return models
|
||||
|
||||
# STALE: 10-60s old - return stale data and refresh in background
|
||||
if _is_fresh(cached_at, 60):
|
||||
# Kick off background refresh (fire-and-forget)
|
||||
asyncio.create_task(fetch._refresh_loaded_models(endpoint))
|
||||
return models # Return stale data immediately
|
||||
|
||||
# EXPIRED: > 60s old - too stale, must refresh synchronously
|
||||
del _loaded_models_cache[endpoint]
|
||||
|
||||
# Check error cache with lock protection
|
||||
async with _loaded_error_cache_lock:
|
||||
if endpoint in _loaded_error_cache:
|
||||
if _is_fresh(_loaded_error_cache[endpoint], 300):
|
||||
return set()
|
||||
# Error expired - remove it
|
||||
del _loaded_error_cache[endpoint]
|
||||
|
||||
# Request coalescing: check if another request is already fetching this endpoint
|
||||
async with _inflight_lock:
|
||||
if endpoint in _inflight_loaded_models:
|
||||
# Another request is already fetching - wait for it
|
||||
task = _inflight_loaded_models[endpoint]
|
||||
else:
|
||||
# Create new fetch task
|
||||
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
|
||||
_inflight_loaded_models[endpoint] = task
|
||||
|
||||
try:
|
||||
# Wait for the fetch to complete (either ours or another request's)
|
||||
result = await task
|
||||
return result
|
||||
finally:
|
||||
# Clean up in-flight tracking (only if we created it)
|
||||
async with _inflight_lock:
|
||||
if _inflight_loaded_models.get(endpoint) == task:
|
||||
_inflight_loaded_models.pop(endpoint, None)
|
||||
|
||||
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None, skip_error_cache: bool = False, timeout: float = None) -> List[dict]:
|
||||
"""
|
||||
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
|
||||
for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
|
||||
|
||||
When ``skip_error_cache`` is False (the default), the call is short-circuited
|
||||
if the endpoint recently failed (recorded in ``_available_error_cache``).
|
||||
Pass ``skip_error_cache=True`` from health-check routes that must always probe.
|
||||
|
||||
``timeout`` overrides the session default for this single request (seconds, total).
|
||||
"""
|
||||
# Fast-fail if the endpoint is known to be down (unless caller opts out)
|
||||
if not skip_error_cache:
|
||||
async with _available_error_cache_lock:
|
||||
if endpoint in _available_error_cache:
|
||||
if _is_fresh(_available_error_cache[endpoint], 300):
|
||||
return []
|
||||
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
|
||||
request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}"
|
||||
client: aiohttp.ClientSession = get_session(endpoint)
|
||||
req_kwargs = {}
|
||||
if timeout is not None:
|
||||
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
|
||||
try:
|
||||
async with client.get(request_url, headers=headers, **req_kwargs) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
detail = data.get(detail, [])
|
||||
return detail
|
||||
except Exception as e:
|
||||
# If anything goes wrong we cannot reply details
|
||||
message = _format_connection_issue(request_url, e)
|
||||
print(f"[fetch.endpoint_details] {message}")
|
||||
if not skip_error_cache:
|
||||
async with _available_error_cache_lock:
|
||||
_available_error_cache[endpoint] = time.time()
|
||||
return []
|
||||
|
||||
def ep2base(ep):
|
||||
if "/v1" in ep:
|
||||
base_url = ep
|
||||
else:
|
||||
base_url = ep+"/v1"
|
||||
return base_url
|
||||
|
||||
def dedupe_on_keys(dicts, key_fields):
|
||||
"""
|
||||
Helper function to deduplicate endpoint details based on given dict keys.
|
||||
"""
|
||||
seen = set()
|
||||
out = []
|
||||
for d in dicts:
|
||||
# Build a tuple of the values for the chosen keys
|
||||
key = tuple(d.get(k) for k in key_fields)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
out.append(d)
|
||||
return out
|
||||
|
||||
async def increment_usage(endpoint: str, model: str) -> None:
|
||||
async with usage_lock:
|
||||
|
|
@ -2910,80 +2349,7 @@ async def usage_proxy(request: Request):
|
|||
return {"usage_counts": usage_counts,
|
||||
"token_usage_counts": token_usage_counts}
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 20. Endpoint health probes (shared by /api/config and /health)
|
||||
# -------------------------------------------------------------
|
||||
async def _raw_probe(
|
||||
ep: str,
|
||||
route: str,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> tuple[bool, object]:
|
||||
"""Direct HTTP probe that distinguishes success from failure
|
||||
(unlike `fetch.endpoint_details`, which returns [] on either).
|
||||
Returns `(ok, payload_or_error_message)`.
|
||||
"""
|
||||
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
|
||||
if api_key is not None:
|
||||
headers["Authorization"] = "Bearer " + api_key
|
||||
url = f"{ep.rstrip('/')}/{route.lstrip('/')}"
|
||||
req_kwargs = {}
|
||||
if timeout is not None:
|
||||
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
|
||||
try:
|
||||
client: aiohttp.ClientSession = get_session(ep)
|
||||
async with client.get(url, headers=headers, **req_kwargs) as resp:
|
||||
await _ensure_success(resp)
|
||||
data = await resp.json()
|
||||
return True, data
|
||||
except Exception as exc:
|
||||
return False, _format_connection_issue(url, exc)
|
||||
|
||||
|
||||
async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
|
||||
"""Probe an endpoint and return `{status, version?, detail?}`.
|
||||
|
||||
Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so
|
||||
that a daemon which is reachable but has a broken model-introspection
|
||||
path (issue #83) is reported as `error` rather than `ok`.
|
||||
OpenAI-compatible endpoints use a single `/models` probe.
|
||||
"""
|
||||
if is_openai_compatible(ep):
|
||||
ok, payload = await _raw_probe(
|
||||
ep, "/models", config.api_keys.get(ep), timeout=timeout,
|
||||
)
|
||||
if ok:
|
||||
return {"status": "ok", "version": "latest"}
|
||||
return {"status": "error", "detail": str(payload)}
|
||||
|
||||
(version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather(
|
||||
_raw_probe(ep, "/api/version", timeout=timeout),
|
||||
_raw_probe(ep, "/api/ps", timeout=timeout),
|
||||
)
|
||||
|
||||
version_value = (
|
||||
version_payload.get("version")
|
||||
if version_ok and isinstance(version_payload, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
if version_ok and ps_ok:
|
||||
return {"status": "ok", "version": version_value}
|
||||
if not version_ok and not ps_ok:
|
||||
return {"status": "error", "detail": str(version_payload)}
|
||||
# Partial failure — daemon reachable but one probe failed. Report
|
||||
# as "error" so callers can surface the issue; include `version` so
|
||||
# the operator knows the daemon itself is alive.
|
||||
if not ps_ok:
|
||||
return {
|
||||
"status": "error",
|
||||
"version": version_value,
|
||||
"detail": f"/api/ps: {ps_payload}",
|
||||
}
|
||||
return {
|
||||
"status": "error",
|
||||
"detail": f"/api/version: {version_payload}",
|
||||
}
|
||||
from backends.probe import _raw_probe, _endpoint_health
|
||||
|
||||
|
||||
# -------------------------------------------------------------
|
||||
|
|
|
|||
7
state.py
7
state.py
|
|
@ -69,6 +69,13 @@ app_state = {
|
|||
"httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints
|
||||
}
|
||||
|
||||
# Default outbound HTTP headers attached to every backend request.
|
||||
default_headers = {
|
||||
"HTTP-Referer": "https://nomyo.ai",
|
||||
"Referer": "https://nomyo.ai",
|
||||
"X-Title": "NOMYO Router",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Count Buffer (for write-behind pattern)
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue