diff --git a/backends/__init__.py b/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backends/health.py b/backends/health.py new file mode 100644 index 0000000..ff985ca --- /dev/null +++ b/backends/health.py @@ -0,0 +1,136 @@ +"""Backend health probes and error classification helpers. + +Contains: + * cache-freshness check (``_is_fresh``) + * aiohttp response success assertion (``_ensure_success``) + * human-readable connection-issue formatter + * upstream-error detection that distinguishes connection failures from + legitimate 4xx responses (``_is_backend_connection_error``) + * per-(endpoint, model) unhealthy marker that feeds ``choose_endpoint`` + * llama-server status interpretation (``_is_llama_model_loaded`` etc.) +""" +import asyncio +import time +from urllib.parse import urlparse + +import aiohttp +import openai +from fastapi import HTTPException + +from security import _mask_secrets +from state import _completion_error_cache, _completion_error_cache_lock + + +def _is_fresh(cached_at: float, ttl: int) -> bool: + return (time.time() - cached_at) < ttl + + +async def _ensure_success(resp: aiohttp.ClientResponse) -> None: + if resp.status >= 400: + text = await resp.text() + raise HTTPException(status_code=resp.status, detail=_mask_secrets(text)) + + +def _format_connection_issue(url: str, error: Exception) -> str: + """ + Provide a human-friendly error string for connection failures so operators + know which endpoint and address failed from inside the container. + """ + parsed = urlparse(url) + host_hint = parsed.hostname or "" + port_hint = parsed.port or "" + + if isinstance(error, aiohttp.ClientConnectorError): + resolved_host = getattr(error, "host", host_hint) or host_hint or "?" + resolved_port = getattr(error, "port", port_hint) or port_hint or "?" 
+ parts = [ + f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).", + "Ensure the endpoint address is reachable from within the container.", + ] + if resolved_host in {"localhost", "127.0.0.1"}: + parts.append( + "Inside Docker, 'localhost' refers to the container itself; use " + "'host.docker.internal' or a Docker network alias if the service " + "runs on the host machine." + ) + os_error = getattr(error, "os_error", None) + if isinstance(os_error, OSError): + errno = getattr(os_error, "errno", None) + strerror = os_error.strerror or str(os_error) + if errno is not None or strerror: + parts.append(f"OS error [{errno}]: {strerror}.") + elif os_error: + parts.append(f"OS error: {os_error}.") + parts.append(f"Original error: {error}.") + return " ".join(parts) + + if isinstance(error, asyncio.TimeoutError): + return ( + f"Timed out waiting for {url}. " + "The remote endpoint may be offline or slow to respond." + ) + + return f"Error while contacting {url}: {error}" + + +def _is_backend_connection_error(exc: Exception) -> bool: + """True for upstream connection-class failures observed via the OpenAI client. + + Targets the case where a llama-server in router mode keeps answering + /v1/models but its delegated worker for a specific model is dead, so + chat/completions calls return 5xx with 'proxy error: Could not establish + connection' (or the SDK raises APIConnectionError outright). + + Excludes BadRequestError with exceed_context_size_error by design — those + must stay on the reactive-trim path. + """ + if isinstance(exc, openai.APIConnectionError): + return True + if isinstance(exc, openai.InternalServerError): + msg = str(exc).lower() + return ( + "proxy error" in msg + or "could not establish connection" in msg + or "connection refused" in msg + ) + return False + + +async def _mark_backend_unhealthy(endpoint: str, model: str, reason: str = "") -> None: + """Record (endpoint, model) as broken so choose_endpoint avoids it. 
+ + Cleared only by TTL — the dead-worker failure mode is invisible to the + /v1/models / /api/ps probes that clear _loaded_error_cache, so we cannot + rely on a successful probe as a recovery signal. + """ + async with _completion_error_cache_lock: + _completion_error_cache[(endpoint, model)] = time.time() + print(f"[health] marked unhealthy ep={endpoint} model={model} reason={reason[:120]}", flush=True) + + +def _is_llama_model_loaded(item: dict) -> bool: + """Return True if a llama-server /v1/models item has status 'loaded'. + Handles both dict format ({"value": "loaded"}) and plain string ("loaded"). + If no status field is present, the model is always-loaded (not dynamically managed).""" + status = item.get("status") + if status is None: + return True # No status field: model is always loaded (e.g. single-model servers) + if isinstance(status, dict): + return status.get("value") == "loaded" + if isinstance(status, str): + return status == "loaded" + return False + + +def _is_llama_model_loaded_or_sleeping(item: dict) -> bool: + """Return True if status is 'loaded' or 'sleeping'. + Newer llama-server versions report 'sleeping' in /v1/models when a model is idle; + ps_details needs to include these so _fetch_llama_props can detect and unload them.""" + status = item.get("status") + if status is None: + return True + if isinstance(status, dict): + return status.get("value") in ("loaded", "sleeping") + if isinstance(status, str): + return status in ("loaded", "sleeping") + return False diff --git a/backends/normalize.py b/backends/normalize.py new file mode 100644 index 0000000..6603f9d --- /dev/null +++ b/backends/normalize.py @@ -0,0 +1,113 @@ +"""Endpoint URL, model-name, and endpoint-classification helpers. + +The endpoint classifiers read live config via ``get_config()`` so that the +startup-time rebind of ``config`` in router.py is picked up at call time. 
+""" +from config import get_config + + +def _normalize_llama_model_name(name: str) -> str: + """Extract the model name from a huggingface-style identifier. + e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF' + """ + if "/" in name: + name = name.rsplit("/", 1)[1] + if ":" in name: + name = name.split(":")[0] + return name + + +def _extract_llama_quant(name: str) -> str: + """Extract the quantization level from a huggingface-style identifier. + e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0' + Returns empty string if no quant suffix is present. + """ + if ":" in name: + return name.rsplit(":", 1)[1] + return "" + + +def ep2base(ep): + if "/v1" in ep: + base_url = ep + else: + base_url = ep + "/v1" + return base_url + + +def dedupe_on_keys(dicts, key_fields): + """ + Helper function to deduplicate endpoint details based on given dict keys. + """ + seen = set() + out = [] + for d in dicts: + # Build a tuple of the values for the chosen keys + key = tuple(d.get(k) for k in key_fields) + if key not in seen: + seen.add(key) + out.append(d) + return out + + +def is_ext_openai_endpoint(endpoint: str) -> bool: + """ + Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server). + + Returns True for: + - External services like OpenAI.com, Groq, etc. 
+ + Returns False for: + - Ollama endpoints (without /v1, or with /v1 but default port 11434) + - llama-server endpoints (explicitly configured in llama_server_endpoints) + """ + cfg = get_config() + # Check if it's a llama-server endpoint (has /v1 and is in the configured list) + if endpoint in cfg.llama_server_endpoints: + return False + + if "/v1" not in endpoint: + return False + + base_endpoint = endpoint.replace('/v1', '') + if base_endpoint in cfg.endpoints: + return False # It's Ollama's /v1 + + # Check for default Ollama port + if ':11434' in endpoint: + return False # It's Ollama + + return True # It's an external OpenAI endpoint + + +def is_openai_compatible(endpoint: str) -> bool: + """ + Return True if the endpoint speaks the OpenAI API (not native Ollama). + This includes external OpenAI endpoints AND llama-server endpoints. + """ + return "/v1" in endpoint or endpoint in get_config().llama_server_endpoints + + +def get_tracking_model(endpoint: str, model: str) -> str: + """ + Normalize model name for tracking purposes so it matches the PS table key. + + - For llama-server endpoints: strips HF prefix and quantization suffix + - For Ollama endpoints: appends ":latest" if no version suffix is present + - For external OpenAI endpoints: returns as-is (not shown in PS) + + This ensures consistent model naming across all routes for usage tracking. + """ + # External OpenAI endpoints are not shown in PS, keep as-is + if is_ext_openai_endpoint(endpoint): + return model + + # llama-server endpoints use normalized names in PS + if endpoint in get_config().llama_server_endpoints: + return _normalize_llama_model_name(model) + + # Ollama endpoints: append ":latest" if no version suffix + if ":" not in model: + return model + ":latest" + + return model diff --git a/backends/probe.py b/backends/probe.py new file mode 100644 index 0000000..2fd3a60 --- /dev/null +++ b/backends/probe.py @@ -0,0 +1,449 @@ +"""Backend probe / discovery primitives. 
+ +The ``fetch`` class wraps the three discovery paths the router uses: + * ``available_models`` — what the endpoint advertises (Ollama ``/api/tags`` + or OpenAI-style ``/v1/models``) + * ``loaded_models`` — what is currently resident (Ollama ``/api/ps`` or + llama-server ``/v1/models`` filtered on ``status == "loaded"``) + * ``endpoint_details`` — arbitrary detail fetch used by management routes + +Each path goes through three layers of cache: success cache, error cache, +and an in-flight request map. Stale-while-revalidate refreshes happen in +background tasks tracked by the ``_bg_refresh_*`` maps in ``state``. + +``_raw_probe`` and ``_endpoint_health`` are the lower-level dual probes +used by ``/health`` and ``/api/config`` to distinguish a healthy daemon +with a broken model-introspection path from a dead daemon. +""" +import asyncio +import time +from typing import List, Optional, Set + +import aiohttp + +from config import get_config +from state import ( + _models_cache, + _models_cache_lock, + _loaded_models_cache, + _loaded_models_cache_lock, + _available_error_cache, + _available_error_cache_lock, + _loaded_error_cache, + _loaded_error_cache_lock, + _inflight_available_models, + _inflight_loaded_models, + _inflight_lock, + _bg_refresh_available, + _bg_refresh_loaded, + _bg_refresh_lock, + default_headers, +) +from backends.sessions import get_session +from backends.health import ( + _is_fresh, + _ensure_success, + _format_connection_issue, + _is_llama_model_loaded, +) +from backends.normalize import is_ext_openai_endpoint, is_openai_compatible + + +class fetch: + async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]: + """ + Internal function that performs the actual HTTP request to fetch available models. + This is called by available_models() after checking caches and in-flight requests. 
+ """ + cfg = get_config() + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + + ep_base = endpoint.rstrip("/") + if endpoint in cfg.llama_server_endpoints and "/v1" not in endpoint: + endpoint_url = f"{ep_base}/v1/models" + key = "data" + elif "/v1" in endpoint or endpoint in cfg.llama_server_endpoints: + endpoint_url = f"{ep_base}/models" + key = "data" + else: + endpoint_url = f"{ep_base}/api/tags" + key = "models" + + client: aiohttp.ClientSession = get_session(endpoint) + try: + async with client.get(endpoint_url, headers=headers) as resp: + await _ensure_success(resp) + data = await resp.json() + + items = data.get(key, []) + models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")} + + async with _models_cache_lock: + _models_cache[endpoint] = (models, time.time()) + return models + except Exception as e: + # Treat any error as if the endpoint offers no models + message = _format_connection_issue(endpoint_url, e) + print(f"[fetch.available_models] {message}") + # Update error cache with lock protection + async with _available_error_cache_lock: + _available_error_cache[endpoint] = time.time() + return set() + + async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None: + """ + Background task to refresh available models cache without blocking the caller. + Used for stale-while-revalidate pattern. + Deduplicates: only one background refresh runs per endpoint at a time. 
+ """ + async with _bg_refresh_lock: + if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done(): + return # A refresh is already running for this endpoint + task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key)) + _bg_refresh_available[endpoint] = task + + try: + await task + except Exception as e: + # Silently fail - cache will remain stale but functional + print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}") + finally: + async with _bg_refresh_lock: + if _bg_refresh_available.get(endpoint) is task: + _bg_refresh_available.pop(endpoint, None) + + async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: + """ + Query /api/tags and return a set of all model names that the + endpoint *advertises* (i.e. is capable of serving). This endpoint lists + every model that is installed on the Ollama instance, regardless of + whether the model is currently loaded into memory. + + Uses request coalescing to prevent cache stampede: if multiple requests + arrive when cache is expired, only one actual HTTP request is made. + + Uses stale-while-revalidate: when the cache is between 300-600s old, + the stale data is returned immediately while a background refresh runs. + This prevents model blackouts caused by transient timeouts. + + If the request fails (e.g. timeout, 5xx, or malformed response), an empty + set is returned. 
+ """ + # Check models cache with lock protection + async with _models_cache_lock: + if endpoint in _models_cache: + models, cached_at = _models_cache[endpoint] + + # FRESH: <= 300s old - return immediately + if _is_fresh(cached_at, 300): + return models + + # STALE: 300-600s old - return stale data and refresh in background + if _is_fresh(cached_at, 600): + asyncio.create_task(fetch._refresh_available_models(endpoint, api_key)) + return models # Return stale data immediately + + # EXPIRED: > 600s old - too stale, must refresh synchronously + del _models_cache[endpoint] + + # Check error cache with lock protection + async with _available_error_cache_lock: + if endpoint in _available_error_cache: + err_age = time.time() - _available_error_cache[endpoint] + if err_age < 30: + # Very fresh error (<30s) – endpoint likely still down, bail fast + return set() + elif err_age < 300: + # Stale error (30-300s) – endpoint may have recovered, probe in background + asyncio.create_task(fetch._refresh_available_models(endpoint, api_key)) + return set() + # Error expired (>300s) – remove and fall through to fresh fetch + del _available_error_cache[endpoint] + + # Request coalescing: check if another request is already fetching this endpoint + async with _inflight_lock: + if endpoint in _inflight_available_models: + # Another request is already fetching - wait for it + task = _inflight_available_models[endpoint] + else: + # Create new fetch task + task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key)) + _inflight_available_models[endpoint] = task + + try: + # Wait for the fetch to complete (either ours or another request's) + result = await task + return result + finally: + # Clean up in-flight tracking (only if we created it) + async with _inflight_lock: + if _inflight_available_models.get(endpoint) == task: + _inflight_available_models.pop(endpoint, None) + + + async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]: + """ + Internal 
function that performs the actual HTTP request to fetch loaded models. + This is called by loaded_models() after checking caches and in-flight requests. + + For Ollama endpoints: queries /api/ps and returns model names + For llama-server endpoints: queries /v1/models and filters for status.value == "loaded" + """ + client: aiohttp.ClientSession = get_session(endpoint) + + # Check if this is a llama-server endpoint + if endpoint in get_config().llama_server_endpoints: + # Query /v1/models for llama-server + try: + async with client.get(f"{endpoint}/models") as resp: + await _ensure_success(resp) + data = await resp.json() + + # Filter for loaded models only + items = data.get("data", []) + models = { + item.get("id") + for item in items + if item.get("id") and _is_llama_model_loaded(item) + } + + # Update cache with lock protection + async with _loaded_models_cache_lock: + _loaded_models_cache[endpoint] = (models, time.time()) + # Probe succeeded — clear any stale error so the endpoint + # becomes routable again. + async with _loaded_error_cache_lock: + _loaded_error_cache.pop(endpoint, None) + return models + except Exception as e: + # If anything goes wrong we simply assume the endpoint has no models + message = _format_connection_issue(f"{endpoint}/models", e) + print(f"[fetch.loaded_models] {message}") + # Record the failure so `choose_endpoint` can avoid routing + # to an unhealthy backend and repeated probes short-circuit. 
+ async with _loaded_error_cache_lock: + _loaded_error_cache[endpoint] = time.time() + return set() + else: + # Original Ollama /api/ps logic + try: + async with client.get(f"{endpoint}/api/ps") as resp: + await _ensure_success(resp) + data = await resp.json() + # The response format is: + # {"models": [{"name": "model1"}, {"name": "model2"}]} + models = {m.get("name") for m in data.get("models", []) if m.get("name")} + + # Update cache with lock protection + async with _loaded_models_cache_lock: + _loaded_models_cache[endpoint] = (models, time.time()) + async with _loaded_error_cache_lock: + _loaded_error_cache.pop(endpoint, None) + return models + except Exception as e: + # If anything goes wrong we simply assume the endpoint has no models + message = _format_connection_issue(f"{endpoint}/api/ps", e) + print(f"[fetch.loaded_models] {message}") + async with _loaded_error_cache_lock: + _loaded_error_cache[endpoint] = time.time() + return set() + + async def _refresh_loaded_models(endpoint: str) -> None: + """ + Background task to refresh loaded models cache without blocking the caller. + Used for stale-while-revalidate pattern. + Deduplicates: only one background refresh runs per endpoint at a time. + """ + async with _bg_refresh_lock: + if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done(): + return # A refresh is already running for this endpoint + task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint)) + _bg_refresh_loaded[endpoint] = task + + try: + await task + except Exception as e: + # Silently fail - cache will remain stale but functional + print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}") + finally: + async with _bg_refresh_lock: + if _bg_refresh_loaded.get(endpoint) is task: + _bg_refresh_loaded.pop(endpoint, None) + + async def loaded_models(endpoint: str) -> Set[str]: + """ + Query /api/ps and return a set of model names that are currently + loaded on that endpoint. 
If the request fails (e.g. timeout, 5xx), an empty + set is returned. + + Uses request coalescing to prevent cache stampede and stale-while-revalidate + to serve requests immediately even when cache is stale (refreshing in background). + """ + if is_ext_openai_endpoint(endpoint): + return set() + + # Check loaded models cache with lock protection + async with _loaded_models_cache_lock: + if endpoint in _loaded_models_cache: + models, cached_at = _loaded_models_cache[endpoint] + + # FRESH: < 10s old - return immediately + if _is_fresh(cached_at, 10): + return models + + # STALE: 10-60s old - return stale data and refresh in background + if _is_fresh(cached_at, 60): + # Kick off background refresh (fire-and-forget) + asyncio.create_task(fetch._refresh_loaded_models(endpoint)) + return models # Return stale data immediately + + # EXPIRED: > 60s old - too stale, must refresh synchronously + del _loaded_models_cache[endpoint] + + # Check error cache with lock protection + async with _loaded_error_cache_lock: + if endpoint in _loaded_error_cache: + if _is_fresh(_loaded_error_cache[endpoint], 300): + return set() + # Error expired - remove it + del _loaded_error_cache[endpoint] + + # Request coalescing: check if another request is already fetching this endpoint + async with _inflight_lock: + if endpoint in _inflight_loaded_models: + # Another request is already fetching - wait for it + task = _inflight_loaded_models[endpoint] + else: + # Create new fetch task + task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint)) + _inflight_loaded_models[endpoint] = task + + try: + # Wait for the fetch to complete (either ours or another request's) + result = await task + return result + finally: + # Clean up in-flight tracking (only if we created it) + async with _inflight_lock: + if _inflight_loaded_models.get(endpoint) == task: + _inflight_loaded_models.pop(endpoint, None) + + async def endpoint_details(endpoint: str, route: str, detail: str, api_key: 
Optional[str] = None, skip_error_cache: bool = False, timeout: float = None) -> List[dict]: + """ + Query ``{endpoint}/{route}`` and return the value stored under the ``detail`` key + of the JSON response. If the request fails (or the key is absent) an empty list is returned. + + When ``skip_error_cache`` is False (the default), the call is short-circuited + if the endpoint recently failed (recorded in ``_available_error_cache``). + Pass ``skip_error_cache=True`` from health-check routes that must always probe. + + ``timeout`` overrides the session default for this single request (seconds, total). + """ + # Fast-fail if the endpoint is known to be down (unless caller opts out) + if not skip_error_cache: + async with _available_error_cache_lock: + if endpoint in _available_error_cache: + if _is_fresh(_available_error_cache[endpoint], 300): + return [] + + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + + request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}" + client: aiohttp.ClientSession = get_session(endpoint) + req_kwargs = {} + if timeout is not None: + req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout) + try: + async with client.get(request_url, headers=headers, **req_kwargs) as resp: + await _ensure_success(resp) + data = await resp.json() + detail = data.get(detail, []) + return detail + except Exception as e: + # If anything goes wrong we cannot reply details + message = _format_connection_issue(request_url, e) + print(f"[fetch.endpoint_details] {message}") + if not skip_error_cache: + async with _available_error_cache_lock: + _available_error_cache[endpoint] = time.time() + return [] + + +# ------------------------------------------------------------- +# Endpoint health probes (shared by /api/config and /health) +# ------------------------------------------------------------- +async def _raw_probe( + ep: str, + route: str, + api_key: Optional[str] = None, +
timeout: Optional[float] = None, +) -> tuple[bool, object]: + """Direct HTTP probe that distinguishes success from failure + (unlike `fetch.endpoint_details`, which returns [] on either). + Returns `(ok, payload_or_error_message)`. + """ + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + url = f"{ep.rstrip('/')}/{route.lstrip('/')}" + req_kwargs = {} + if timeout is not None: + req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout) + try: + client: aiohttp.ClientSession = get_session(ep) + async with client.get(url, headers=headers, **req_kwargs) as resp: + await _ensure_success(resp) + data = await resp.json() + return True, data + except Exception as exc: + return False, _format_connection_issue(url, exc) + + +async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict: + """Probe an endpoint and return `{status, version?, detail?}`. + + Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so + that a daemon which is reachable but has a broken model-introspection + path (issue #83) is reported as `error` rather than `ok`. + OpenAI-compatible endpoints use a single `/models` probe. 
+ """ + if is_openai_compatible(ep): + ok, payload = await _raw_probe( + ep, "/models", get_config().api_keys.get(ep), timeout=timeout, + ) + if ok: + return {"status": "ok", "version": "latest"} + return {"status": "error", "detail": str(payload)} + + (version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather( + _raw_probe(ep, "/api/version", timeout=timeout), + _raw_probe(ep, "/api/ps", timeout=timeout), + ) + + version_value = ( + version_payload.get("version") + if version_ok and isinstance(version_payload, dict) + else None + ) + + if version_ok and ps_ok: + return {"status": "ok", "version": version_value} + if not version_ok and not ps_ok: + return {"status": "error", "detail": str(version_payload)} + # Partial failure — daemon reachable but one probe failed. Report + # as "error" so callers can surface the issue; include `version` so + # the operator knows the daemon itself is alive. + if not ps_ok: + return { + "status": "error", + "version": version_value, + "detail": f"/api/ps: {ps_payload}", + } + return { + "status": "error", + "detail": f"/api/version: {version_payload}", + } diff --git a/backends/sessions.py b/backends/sessions.py new file mode 100644 index 0000000..a7fa2b9 --- /dev/null +++ b/backends/sessions.py @@ -0,0 +1,72 @@ +"""aiohttp / OpenAI client factories aware of Unix-socket endpoints. + +Unix socket endpoints follow the ``.sock`` hostname convention (e.g. +``http://192.168.0.52.sock/v1``) and resolve to ``/run/user/{uid}/{host}``. +Their sessions/clients live in ``state.app_state`` so that startup can +populate them once and routes can reuse them. +""" +import os + +import aiohttp +import openai + +from state import app_state +from backends.normalize import ep2base + + +def _is_unix_socket_endpoint(endpoint: str) -> bool: + """Return True if endpoint uses Unix socket (.sock hostname convention).
+ + Detects URLs like http://192.168.0.52.sock/v1 where the host ends with + .sock, indicating the connection should use a Unix domain socket at + /run/user/{uid}/{host} instead of TCP. + """ + try: + host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0] + return host.endswith(".sock") + except IndexError: + return False + + +def _get_socket_path(endpoint: str) -> str: + """Derive Unix socket file path from a .sock endpoint URL. + + http://192.168.0.52.sock/v1 -> /run/user/{uid}/192.168.0.52.sock + """ + host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0] + return f"/run/user/{os.getuid()}/{host}" + + +def get_session(endpoint: str) -> aiohttp.ClientSession: + """Return the appropriate aiohttp session for the given endpoint. + + Unix socket endpoints (.sock) get their own UnixConnector session. + All other endpoints share the main TCP session. + """ + if _is_unix_socket_endpoint(endpoint): + sess = app_state["socket_sessions"].get(endpoint) + if sess is not None: + return sess + return app_state["session"] + + +def _make_openai_client( + endpoint: str, + default_headers: dict | None = None, + api_key: str = "no-key", +) -> openai.AsyncOpenAI: + """Return an AsyncOpenAI client configured for the given endpoint. + + For Unix socket endpoints, injects a pre-created httpx UDS transport + so the OpenAI SDK connects via the socket instead of TCP.
+ """ + base_url = ep2base(endpoint) + kwargs: dict = {"api_key": api_key} + if default_headers is not None: + kwargs["default_headers"] = default_headers + if _is_unix_socket_endpoint(endpoint): + http_client = app_state["httpx_clients"].get(endpoint) + if http_client is not None: + kwargs["http_client"] = http_client + base_url = "http://localhost/v1" + return openai.AsyncOpenAI(base_url=base_url, **kwargs) diff --git a/config.py b/config.py index e3b5ee6..143a2f9 100644 --- a/config.py +++ b/config.py @@ -124,3 +124,16 @@ def _config_path_from_env() -> Path: if candidate: return Path(candidate).expanduser() return Path("config.yaml") + + +# ------------------------------------------------------------------ +# Shared config accessor +# ------------------------------------------------------------------ +# Submodules read config at call time via get_config() instead of importing +# a bound name. The single source of truth is ``router.config`` — the lazy +# import below resolves it after router.py has finished loading, and lets +# tests that ``patch.object(router, "config", cfg)`` flow through. +def get_config() -> "Config": + """Return the currently active Config from router.py.""" + import router # lazy to avoid module-load circular import + return router.config diff --git a/router.py b/router.py index ecbad68..514590e 100644 --- a/router.py +++ b/router.py @@ -75,7 +75,8 @@ from db import TokenDatabase from cache import init_llm_cache, get_llm_cache, openai_nonstream_to_sse -# Create the global config object – it will be overwritten on startup +# Create the global config object – it will be overwritten on startup. +# Submodules read it lazily via config.get_config(). 
config = Config.from_yaml(_config_path_from_env()) # ------------------------------------------------------------- @@ -90,11 +91,7 @@ app.add_middleware( allow_methods=["GET", "POST", "DELETE"], allow_headers=["Authorization", "Content-Type"], ) -default_headers={ - "HTTP-Referer": "https://nomyo.ai", - "Referer": "https://nomyo.ai", - "X-Title": "NOMYO Router", - } +from state import default_headers # ------------------------------------------------------------- # Router-level authentication (optional) @@ -205,254 +202,36 @@ from fingerprint import _conversation_fingerprint db: "TokenDatabase" = None # ------------------------------------------------------------- -# 4. Helperfunctions +# 4. Helperfunctions # ------------------------------------------------------------- -def _is_fresh(cached_at: float, ttl: int) -> bool: - return (time.time() - cached_at) < ttl - -async def _ensure_success(resp: aiohttp.ClientResponse) -> None: - if resp.status >= 400: - text = await resp.text() - raise HTTPException(status_code=resp.status, detail=_mask_secrets(text)) - -def _format_connection_issue(url: str, error: Exception) -> str: - """ - Provide a human-friendly error string for connection failures so operators - know which endpoint and address failed from inside the container. - """ - parsed = urlparse(url) - host_hint = parsed.hostname or "" - port_hint = parsed.port or "" - - if isinstance(error, aiohttp.ClientConnectorError): - resolved_host = getattr(error, "host", host_hint) or host_hint or "?" - resolved_port = getattr(error, "port", port_hint) or port_hint or "?" - parts = [ - f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).", - "Ensure the endpoint address is reachable from within the container.", - ] - if resolved_host in {"localhost", "127.0.0.1"}: - parts.append( - "Inside Docker, 'localhost' refers to the container itself; use " - "'host.docker.internal' or a Docker network alias if the service " - "runs on the host machine." 
- ) - os_error = getattr(error, "os_error", None) - if isinstance(os_error, OSError): - errno = getattr(os_error, "errno", None) - strerror = os_error.strerror or str(os_error) - if errno is not None or strerror: - parts.append(f"OS error [{errno}]: {strerror}.") - elif os_error: - parts.append(f"OS error: {os_error}.") - parts.append(f"Original error: {error}.") - return " ".join(parts) - - if isinstance(error, asyncio.TimeoutError): - return ( - f"Timed out waiting for {url}. " - "The remote endpoint may be offline or slow to respond." - ) - - return f"Error while contacting {url}: {error}" - -def _normalize_llama_model_name(name: str) -> str: - """Extract the model name from a huggingface-style identifier. - e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF' - """ - if "/" in name: - name = name.rsplit("/", 1)[1] - if ":" in name: - name = name.split(":")[0] - return name - -def _extract_llama_quant(name: str) -> str: - """Extract the quantization level from a huggingface-style identifier. - e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0' - Returns empty string if no quant suffix is present. - """ - if ":" in name: - return name.rsplit(":", 1)[1] - return "" +from backends.normalize import ( + _normalize_llama_model_name, + _extract_llama_quant, + ep2base, + dedupe_on_keys, +) +from backends.sessions import ( + _is_unix_socket_endpoint, + _get_socket_path, + get_session, + _make_openai_client, +) +from backends.health import ( + _is_fresh, + _ensure_success, + _format_connection_issue, + _is_backend_connection_error, + _mark_backend_unhealthy, + _is_llama_model_loaded, + _is_llama_model_loaded_or_sleeping, +) -def _is_unix_socket_endpoint(endpoint: str) -> bool: - """Return True if endpoint uses Unix socket (.sock hostname convention). - - Detects URLs like http://192.168.0.52.sock/v1 where the host ends with - .sock, indicating the connection should use a Unix domain socket at - /tmp/ instead of TCP. 
- """ - try: - host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0] - return host.endswith(".sock") - except IndexError: - return False - - -def _get_socket_path(endpoint: str) -> str: - """Derive Unix socket file path from a .sock endpoint URL. - - http://192.168.0.52.sock/v1 -> /run/user//192.168.0.52.sock - """ - host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0] - return f"/run/user/{os.getuid()}/{host}" - - -def get_session(endpoint: str) -> aiohttp.ClientSession: - """Return the appropriate aiohttp session for the given endpoint. - - Unix socket endpoints (.sock) get their own UnixConnector session. - All other endpoints share the main TCP session. - """ - if _is_unix_socket_endpoint(endpoint): - sess = app_state["socket_sessions"].get(endpoint) - if sess is not None: - return sess - return app_state["session"] - - -def _make_openai_client( - endpoint: str, - default_headers: dict | None = None, - api_key: str = "no-key", -) -> openai.AsyncOpenAI: - """Return an AsyncOpenAI client configured for the given endpoint. - - For Unix socket endpoints, injects a pre-created httpx UDS transport - so the OpenAI SDK connects via the socket instead of TCP. - """ - base_url = ep2base(endpoint) - kwargs: dict = {"api_key": api_key} - if default_headers is not None: - kwargs["default_headers"] = default_headers - if _is_unix_socket_endpoint(endpoint): - http_client = app_state["httpx_clients"].get(endpoint) - if http_client is not None: - kwargs["http_client"] = http_client - base_url = "http://localhost/v1" - return openai.AsyncOpenAI(base_url=base_url, **kwargs) - - -def _is_backend_connection_error(exc: Exception) -> bool: - """True for upstream connection-class failures observed via the OpenAI client. 
- - Targets the case where a llama-server in router mode keeps answering - /v1/models but its delegated worker for a specific model is dead, so - chat/completions calls return 5xx with 'proxy error: Could not establish - connection' (or the SDK raises APIConnectionError outright). - - Excludes BadRequestError with exceed_context_size_error by design — those - must stay on the reactive-trim path. - """ - if isinstance(exc, openai.APIConnectionError): - return True - if isinstance(exc, openai.InternalServerError): - msg = str(exc).lower() - return ( - "proxy error" in msg - or "could not establish connection" in msg - or "connection refused" in msg - ) - return False - - -async def _mark_backend_unhealthy(endpoint: str, model: str, reason: str = "") -> None: - """Record (endpoint, model) as broken so choose_endpoint avoids it. - - Cleared only by TTL — the dead-worker failure mode is invisible to the - /v1/models / /api/ps probes that clear _loaded_error_cache, so we cannot - rely on a successful probe as a recovery signal. - """ - async with _completion_error_cache_lock: - _completion_error_cache[(endpoint, model)] = time.time() - print(f"[health] marked unhealthy ep={endpoint} model={model} reason={reason[:120]}", flush=True) - - -def _is_llama_model_loaded(item: dict) -> bool: - """Return True if a llama-server /v1/models item has status 'loaded'. - Handles both dict format ({"value": "loaded"}) and plain string ("loaded"). - If no status field is present, the model is always-loaded (not dynamically managed).""" - status = item.get("status") - if status is None: - return True # No status field: model is always loaded (e.g. single-model servers) - if isinstance(status, dict): - return status.get("value") == "loaded" - if isinstance(status, str): - return status == "loaded" - return False - -def _is_llama_model_loaded_or_sleeping(item: dict) -> bool: - """Return True if status is 'loaded' or 'sleeping'. 
- Newer llama-server versions report 'sleeping' in /v1/models when a model is idle; - ps_details needs to include these so _fetch_llama_props can detect and unload them.""" - status = item.get("status") - if status is None: - return True - if isinstance(status, dict): - return status.get("value") in ("loaded", "sleeping") - if isinstance(status, str): - return status in ("loaded", "sleeping") - return False - -def is_ext_openai_endpoint(endpoint: str) -> bool: - """ - Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server). - - Returns True for: - - External services like OpenAI.com, Groq, etc. - - Returns False for: - - Ollama endpoints (without /v1, or with /v1 but default port 11434) - - llama-server endpoints (explicitly configured in llama_server_endpoints) - """ - # Check if it's a llama-server endpoint (has /v1 and is in the configured list) - if endpoint in config.llama_server_endpoints: - return False - - if "/v1" not in endpoint: - return False - - base_endpoint = endpoint.replace('/v1', '') - if base_endpoint in config.endpoints: - return False # It's Ollama's /v1 - - # Check for default Ollama port - if ':11434' in endpoint: - return False # It's Ollama - - return True # It's an external OpenAI endpoint - -def is_openai_compatible(endpoint: str) -> bool: - """ - Return True if the endpoint speaks the OpenAI API (not native Ollama). - This includes external OpenAI endpoints AND llama-server endpoints. - """ - return "/v1" in endpoint or endpoint in config.llama_server_endpoints - -def get_tracking_model(endpoint: str, model: str) -> str: - """ - Normalize model name for tracking purposes so it matches the PS table key. - - - For llama-server endpoints: strips HF prefix and quantization suffix - - For Ollama endpoints: appends ":latest" if no version suffix is present - - For external OpenAI endpoints: returns as-is (not shown in PS) - - This ensures consistent model naming across all routes for usage tracking. 
- """ - # External OpenAI endpoints are not shown in PS, keep as-is - if is_ext_openai_endpoint(endpoint): - return model - - # llama-server endpoints use normalized names in PS - if endpoint in config.llama_server_endpoints: - return _normalize_llama_model_name(model) - - # Ollama endpoints: append ":latest" if no version suffix - if ":" not in model: - return model + ":latest" - - return model +from backends.normalize import ( + is_ext_openai_endpoint, + is_openai_compatible, + get_tracking_model, +) async def token_worker() -> None: try: @@ -601,348 +380,8 @@ async def flush_remaining_buffers() -> None: # Do not raise during shutdown – log and continue teardown print(f"[shutdown] Error flushing remaining buffers: {e}") -class fetch: - async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]: - """ - Internal function that performs the actual HTTP request to fetch available models. - This is called by available_models() after checking caches and in-flight requests. 
- """ - headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} - if api_key is not None: - headers["Authorization"] = "Bearer " + api_key +from backends.probe import fetch - ep_base = endpoint.rstrip("/") - if endpoint in config.llama_server_endpoints and "/v1" not in endpoint: - endpoint_url = f"{ep_base}/v1/models" - key = "data" - elif "/v1" in endpoint or endpoint in config.llama_server_endpoints: - endpoint_url = f"{ep_base}/models" - key = "data" - else: - endpoint_url = f"{ep_base}/api/tags" - key = "models" - - client: aiohttp.ClientSession = get_session(endpoint) - try: - async with client.get(endpoint_url, headers=headers) as resp: - await _ensure_success(resp) - data = await resp.json() - - items = data.get(key, []) - models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")} - - async with _models_cache_lock: - _models_cache[endpoint] = (models, time.time()) - return models - except Exception as e: - # Treat any error as if the endpoint offers no models - message = _format_connection_issue(endpoint_url, e) - print(f"[fetch.available_models] {message}") - # Update error cache with lock protection - async with _available_error_cache_lock: - _available_error_cache[endpoint] = time.time() - return set() - - async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None: - """ - Background task to refresh available models cache without blocking the caller. - Used for stale-while-revalidate pattern. - Deduplicates: only one background refresh runs per endpoint at a time. 
- """ - async with _bg_refresh_lock: - if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done(): - return # A refresh is already running for this endpoint - task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key)) - _bg_refresh_available[endpoint] = task - - try: - await task - except Exception as e: - # Silently fail - cache will remain stale but functional - print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}") - finally: - async with _bg_refresh_lock: - if _bg_refresh_available.get(endpoint) is task: - _bg_refresh_available.pop(endpoint, None) - - async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]: - """ - Query /api/tags and return a set of all model names that the - endpoint *advertises* (i.e. is capable of serving). This endpoint lists - every model that is installed on the Ollama instance, regardless of - whether the model is currently loaded into memory. - - Uses request coalescing to prevent cache stampede: if multiple requests - arrive when cache is expired, only one actual HTTP request is made. - - Uses stale-while-revalidate: when the cache is between 300-600s old, - the stale data is returned immediately while a background refresh runs. - This prevents model blackouts caused by transient timeouts. - - If the request fails (e.g. timeout, 5xx, or malformed response), an empty - set is returned. 
- """ - # Check models cache with lock protection - async with _models_cache_lock: - if endpoint in _models_cache: - models, cached_at = _models_cache[endpoint] - - # FRESH: <= 300s old - return immediately - if _is_fresh(cached_at, 300): - return models - - # STALE: 300-600s old - return stale data and refresh in background - if _is_fresh(cached_at, 600): - asyncio.create_task(fetch._refresh_available_models(endpoint, api_key)) - return models # Return stale data immediately - - # EXPIRED: > 600s old - too stale, must refresh synchronously - del _models_cache[endpoint] - - # Check error cache with lock protection - async with _available_error_cache_lock: - if endpoint in _available_error_cache: - err_age = time.time() - _available_error_cache[endpoint] - if err_age < 30: - # Very fresh error (<30s) – endpoint likely still down, bail fast - return set() - elif err_age < 300: - # Stale error (30-300s) – endpoint may have recovered, probe in background - asyncio.create_task(fetch._refresh_available_models(endpoint, api_key)) - return set() - # Error expired (>300s) – remove and fall through to fresh fetch - del _available_error_cache[endpoint] - - # Request coalescing: check if another request is already fetching this endpoint - async with _inflight_lock: - if endpoint in _inflight_available_models: - # Another request is already fetching - wait for it - task = _inflight_available_models[endpoint] - else: - # Create new fetch task - task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key)) - _inflight_available_models[endpoint] = task - - try: - # Wait for the fetch to complete (either ours or another request's) - result = await task - return result - finally: - # Clean up in-flight tracking (only if we created it) - async with _inflight_lock: - if _inflight_available_models.get(endpoint) == task: - _inflight_available_models.pop(endpoint, None) - - - async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]: - """ - Internal 
function that performs the actual HTTP request to fetch loaded models. - This is called by loaded_models() after checking caches and in-flight requests. - - For Ollama endpoints: queries /api/ps and returns model names - For llama-server endpoints: queries /v1/models and filters for status.value == "loaded" - """ - client: aiohttp.ClientSession = get_session(endpoint) - - # Check if this is a llama-server endpoint - if endpoint in config.llama_server_endpoints: - # Query /v1/models for llama-server - try: - async with client.get(f"{endpoint}/models") as resp: - await _ensure_success(resp) - data = await resp.json() - - # Filter for loaded models only - items = data.get("data", []) - models = { - item.get("id") - for item in items - if item.get("id") and _is_llama_model_loaded(item) - } - - # Update cache with lock protection - async with _loaded_models_cache_lock: - _loaded_models_cache[endpoint] = (models, time.time()) - # Probe succeeded — clear any stale error so the endpoint - # becomes routable again. - async with _loaded_error_cache_lock: - _loaded_error_cache.pop(endpoint, None) - return models - except Exception as e: - # If anything goes wrong we simply assume the endpoint has no models - message = _format_connection_issue(f"{endpoint}/models", e) - print(f"[fetch.loaded_models] {message}") - # Record the failure so `choose_endpoint` can avoid routing - # to an unhealthy backend and repeated probes short-circuit. 
- async with _loaded_error_cache_lock: - _loaded_error_cache[endpoint] = time.time() - return set() - else: - # Original Ollama /api/ps logic - try: - async with client.get(f"{endpoint}/api/ps") as resp: - await _ensure_success(resp) - data = await resp.json() - # The response format is: - # {"models": [{"name": "model1"}, {"name": "model2"}]} - models = {m.get("name") for m in data.get("models", []) if m.get("name")} - - # Update cache with lock protection - async with _loaded_models_cache_lock: - _loaded_models_cache[endpoint] = (models, time.time()) - async with _loaded_error_cache_lock: - _loaded_error_cache.pop(endpoint, None) - return models - except Exception as e: - # If anything goes wrong we simply assume the endpoint has no models - message = _format_connection_issue(f"{endpoint}/api/ps", e) - print(f"[fetch.loaded_models] {message}") - async with _loaded_error_cache_lock: - _loaded_error_cache[endpoint] = time.time() - return set() - - async def _refresh_loaded_models(endpoint: str) -> None: - """ - Background task to refresh loaded models cache without blocking the caller. - Used for stale-while-revalidate pattern. - Deduplicates: only one background refresh runs per endpoint at a time. - """ - async with _bg_refresh_lock: - if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done(): - return # A refresh is already running for this endpoint - task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint)) - _bg_refresh_loaded[endpoint] = task - - try: - await task - except Exception as e: - # Silently fail - cache will remain stale but functional - print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}") - finally: - async with _bg_refresh_lock: - if _bg_refresh_loaded.get(endpoint) is task: - _bg_refresh_loaded.pop(endpoint, None) - - async def loaded_models(endpoint: str) -> Set[str]: - """ - Query /api/ps and return a set of model names that are currently - loaded on that endpoint. 
If the request fails (e.g. timeout, 5xx), an empty - set is returned. - - Uses request coalescing to prevent cache stampede and stale-while-revalidate - to serve requests immediately even when cache is stale (refreshing in background). - """ - if is_ext_openai_endpoint(endpoint): - return set() - - # Check loaded models cache with lock protection - async with _loaded_models_cache_lock: - if endpoint in _loaded_models_cache: - models, cached_at = _loaded_models_cache[endpoint] - - # FRESH: < 10s old - return immediately - if _is_fresh(cached_at, 10): - return models - - # STALE: 10-60s old - return stale data and refresh in background - if _is_fresh(cached_at, 60): - # Kick off background refresh (fire-and-forget) - asyncio.create_task(fetch._refresh_loaded_models(endpoint)) - return models # Return stale data immediately - - # EXPIRED: > 60s old - too stale, must refresh synchronously - del _loaded_models_cache[endpoint] - - # Check error cache with lock protection - async with _loaded_error_cache_lock: - if endpoint in _loaded_error_cache: - if _is_fresh(_loaded_error_cache[endpoint], 300): - return set() - # Error expired - remove it - del _loaded_error_cache[endpoint] - - # Request coalescing: check if another request is already fetching this endpoint - async with _inflight_lock: - if endpoint in _inflight_loaded_models: - # Another request is already fetching - wait for it - task = _inflight_loaded_models[endpoint] - else: - # Create new fetch task - task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint)) - _inflight_loaded_models[endpoint] = task - - try: - # Wait for the fetch to complete (either ours or another request's) - result = await task - return result - finally: - # Clean up in-flight tracking (only if we created it) - async with _inflight_lock: - if _inflight_loaded_models.get(endpoint) == task: - _inflight_loaded_models.pop(endpoint, None) - - async def endpoint_details(endpoint: str, route: str, detail: str, api_key: 
Optional[str] = None, skip_error_cache: bool = False, timeout: float = None) -> List[dict]: - """ - Query / to fetch and return a List of dicts with details - for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail. - - When ``skip_error_cache`` is False (the default), the call is short-circuited - if the endpoint recently failed (recorded in ``_available_error_cache``). - Pass ``skip_error_cache=True`` from health-check routes that must always probe. - - ``timeout`` overrides the session default for this single request (seconds, total). - """ - # Fast-fail if the endpoint is known to be down (unless caller opts out) - if not skip_error_cache: - async with _available_error_cache_lock: - if endpoint in _available_error_cache: - if _is_fresh(_available_error_cache[endpoint], 300): - return [] - - headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} - if api_key is not None: - headers["Authorization"] = "Bearer " + api_key - - request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}" - client: aiohttp.ClientSession = get_session(endpoint) - req_kwargs = {} - if timeout is not None: - req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout) - try: - async with client.get(request_url, headers=headers, **req_kwargs) as resp: - await _ensure_success(resp) - data = await resp.json() - detail = data.get(detail, []) - return detail - except Exception as e: - # If anything goes wrong we cannot reply details - message = _format_connection_issue(request_url, e) - print(f"[fetch.endpoint_details] {message}") - if not skip_error_cache: - async with _available_error_cache_lock: - _available_error_cache[endpoint] = time.time() - return [] - -def ep2base(ep): - if "/v1" in ep: - base_url = ep - else: - base_url = ep+"/v1" - return base_url - -def dedupe_on_keys(dicts, key_fields): - """ - Helper function to deduplicate endpoint details based on given dict keys. 
- """ - seen = set() - out = [] - for d in dicts: - # Build a tuple of the values for the chosen keys - key = tuple(d.get(k) for k in key_fields) - if key not in seen: - seen.add(key) - out.append(d) - return out async def increment_usage(endpoint: str, model: str) -> None: async with usage_lock: @@ -2910,80 +2349,7 @@ async def usage_proxy(request: Request): return {"usage_counts": usage_counts, "token_usage_counts": token_usage_counts} -# ------------------------------------------------------------- -# 20. Endpoint health probes (shared by /api/config and /health) -# ------------------------------------------------------------- -async def _raw_probe( - ep: str, - route: str, - api_key: Optional[str] = None, - timeout: Optional[float] = None, -) -> tuple[bool, object]: - """Direct HTTP probe that distinguishes success from failure - (unlike `fetch.endpoint_details`, which returns [] on either). - Returns `(ok, payload_or_error_message)`. - """ - headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} - if api_key is not None: - headers["Authorization"] = "Bearer " + api_key - url = f"{ep.rstrip('/')}/{route.lstrip('/')}" - req_kwargs = {} - if timeout is not None: - req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout) - try: - client: aiohttp.ClientSession = get_session(ep) - async with client.get(url, headers=headers, **req_kwargs) as resp: - await _ensure_success(resp) - data = await resp.json() - return True, data - except Exception as exc: - return False, _format_connection_issue(url, exc) - - -async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict: - """Probe an endpoint and return `{status, version?, detail?}`. - - Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so - that a daemon which is reachable but has a broken model-introspection - path (issue #83) is reported as `error` rather than `ok`. - OpenAI-compatible endpoints use a single `/models` probe. 
- """ - if is_openai_compatible(ep): - ok, payload = await _raw_probe( - ep, "/models", config.api_keys.get(ep), timeout=timeout, - ) - if ok: - return {"status": "ok", "version": "latest"} - return {"status": "error", "detail": str(payload)} - - (version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather( - _raw_probe(ep, "/api/version", timeout=timeout), - _raw_probe(ep, "/api/ps", timeout=timeout), - ) - - version_value = ( - version_payload.get("version") - if version_ok and isinstance(version_payload, dict) - else None - ) - - if version_ok and ps_ok: - return {"status": "ok", "version": version_value} - if not version_ok and not ps_ok: - return {"status": "error", "detail": str(version_payload)} - # Partial failure — daemon reachable but one probe failed. Report - # as "error" so callers can surface the issue; include `version` so - # the operator knows the daemon itself is alive. - if not ps_ok: - return { - "status": "error", - "version": version_value, - "detail": f"/api/ps: {ps_payload}", - } - return { - "status": "error", - "detail": f"/api/version: {version_payload}", - } +from backends.probe import _raw_probe, _endpoint_health # ------------------------------------------------------------- diff --git a/state.py b/state.py index 9f2b3cd..301cc26 100644 --- a/state.py +++ b/state.py @@ -69,6 +69,13 @@ app_state = { "httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints } +# Default outbound HTTP headers attached to every backend request. +default_headers = { + "HTTP-Referer": "https://nomyo.ai", + "Referer": "https://nomyo.ai", + "X-Title": "NOMYO Router", +} + # ------------------------------------------------------------------ # Token Count Buffer (for write-behind pattern) # ------------------------------------------------------------------