From 27dfc07889970f2446e158d70f4e1b0c33023be8 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Tue, 12 May 2026 18:33:47 +0200 Subject: [PATCH 01/22] feat: add conversation-endpoint affinity to benefit from hot kv-caches if possible --- config.yaml | 10 +++ router.py | 199 ++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 157 insertions(+), 52 deletions(-) diff --git a/config.yaml b/config.yaml index 76fbbe1..6fc57e6 100644 --- a/config.yaml +++ b/config.yaml @@ -26,6 +26,16 @@ max_concurrent_connections: 2 # When false (default), equally-idle endpoints are chosen at random. # priority_routing: true +# Conversation affinity (optional, default: false). +# Routes follow-up requests back to the endpoint that previously served the +# same conversation so the llama.cpp / Ollama prompt cache (KV cache) stays +# warm — first turn does a cold prefill, follow-ups skip it. Soft preference: +# falls back to the standard algorithm when the affine endpoint no longer has +# the model loaded or has no free slot. Conversations are fingerprinted by +# (model, first system + first user turn). 
+# conversation_affinity: true +# conversation_affinity_ttl: 300 # seconds; matches Ollama's default keep_alive + # Optional router-level API key that gates router/API/web UI access (leave empty to disable) nomyo-router-api-key: "" diff --git a/router.py b/router.py index 603387e..34e02da 100644 --- a/router.py +++ b/router.py @@ -6,7 +6,7 @@ version: 0.7 license: AGPL """ # ------------------------------------------------------------- -import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx +import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx, hashlib try: import truststore; truststore.inject_into_ssl() except ImportError: @@ -223,6 +223,15 @@ class Config(BaseSettings): # When True, config order = priority; routes by utilization ratio + config index (WRR) priority_routing: bool = Field(default=False) + # Conversation affinity: route the same conversation back to the endpoint that + # previously served it, to keep the llama.cpp / Ollama prompt cache (KV cache) warm. + # Soft preference — falls back to the standard algorithm when the affine endpoint + # is saturated or no longer has the model loaded. + conversation_affinity: bool = Field(default=False) + # TTL (seconds) for affinity entries. Defaults to Ollama's default keep_alive (5 min): + # if the backend has already evicted the model, the KV cache is cold anyway. 
+ conversation_affinity_ttl: int = Field(default=300) + api_keys: Dict[str, str] = Field(default_factory=dict) # Optional router-level API key used to gate access to this service and dashboard router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY") @@ -436,6 +445,45 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict( usage_lock = asyncio.Lock() # protects access to usage_counts token_usage_lock = asyncio.Lock() +# Conversation affinity map: fingerprint -> (endpoint, expires_at_monotonic). +# Keeps the same conversation pinned to the endpoint that already has its +# KV-cache prefix warm. Never held together with usage_lock. +_affinity_map: Dict[str, tuple[str, float]] = {} +_affinity_lock = asyncio.Lock() +_AFFINITY_MAX_ENTRIES = 10000 + + +def _conversation_fingerprint(model: str, messages: Optional[list], + prompt: Optional[str]) -> Optional[str]: + """ + Stable hash over (model, first system + first user turn). That prefix + determines whether the backend's prompt cache is reusable; later turns + don't influence the routing decision because they extend the same prefix. + Returns None when there is no usable prefix. 
+ """ + parts: list[str] = [model or "_"] + if messages: + for m in messages: + role = m.get("role") if isinstance(m, dict) else None + if role not in ("system", "user"): + continue + content = m.get("content") + if isinstance(content, list): # OpenAI multimodal parts + content = "".join( + p.get("text", "") for p in content + if isinstance(p, dict) and p.get("type") == "text" + ) + if not isinstance(content, str): + continue + parts.append(f"{role}:{content}") + if role == "user": + break + elif prompt: + parts.append(f"user:{prompt}") + else: + return None + return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest() + # Database instance db: "TokenDatabase" = None @@ -1738,7 +1786,8 @@ def get_max_connections(ep: str) -> int: "max_concurrent_connections", config.max_concurrent_connections ) -async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: +async def choose_endpoint(model: str, reserve: bool = True, + affinity_key: Optional[str] = None) -> tuple[str, str]: """ Determine which endpoint to use for the given model while respecting the `max_concurrent_connections` per endpoint‑model pair **and** @@ -1748,10 +1797,14 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: 1️⃣ Query every endpoint for its advertised models (`/api/tags`). 2️⃣ Build a list of endpoints that contain the requested model. + 2️⃣.5 If conversation affinity is enabled and the caller passes + ``affinity_key``, prefer the endpoint that previously served the + same conversation — but only when it still has the model loaded + and a free slot. Otherwise fall through to the standard logic. 3️⃣ For those endpoints, find those that have the model loaded (`/api/ps`) *and* still have a free slot. 4️⃣ If none are both loaded and free, fall back to any endpoint - from the filtered list that simply has a free slot and randomly + from the filtered list that simply has a free slot and randomly select one. 
5️⃣ If all are saturated, pick any endpoint from the filtered list (the request will queue on that endpoint). @@ -1799,6 +1852,19 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) + # Look up a possible affinity hint *before* taking usage_lock. The two + # locks are never held together to avoid lock-ordering issues. + affine_ep: Optional[str] = None + if config.conversation_affinity and affinity_key: + async with _affinity_lock: + entry = _affinity_map.get(affinity_key) + if entry is not None: + ep, expires_at = entry + if expires_at < time.monotonic(): + _affinity_map.pop(affinity_key, None) + else: + affine_ep = ep + # Protect all reads/writes of usage_counts with the lock so that selection # and reservation are atomic — concurrent callers see each other's pending load. async with usage_lock: @@ -1814,59 +1880,75 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: # Priority map: position in all_endpoints list (lower = higher priority) ep_priority = {ep: i for i, ep in enumerate(all_endpoints)} - # 3️⃣ Endpoints that have the model loaded *and* a free slot - loaded_and_free = [ - ep for ep, models in zip(candidate_endpoints, loaded_sets) - if model in models and tracking_usage(ep) < get_max_connections(ep) - ] + selected: Optional[str] = None - if loaded_and_free: - if config.priority_routing: - # WRR: sort by config order first (stable), then by utilization ratio. - # Stable sort preserves priority for equal-ratio endpoints. - loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999)) - loaded_and_free.sort(key=utilization_ratio) - selected = loaded_and_free[0] - else: - # Sort ascending for load balancing — all endpoints here already have the - # model loaded, so there is no model-switching cost to optimise for. 
- loaded_and_free.sort(key=tracking_usage) - # When all candidates are equally idle, randomise to avoid always picking - # the first entry in a stable sort. - if all(tracking_usage(ep) == 0 for ep in loaded_and_free): - selected = random.choice(loaded_and_free) - else: - selected = loaded_and_free[0] - else: - # 4️⃣ Endpoints among the candidates that simply have a free slot - endpoints_with_free_slot = [ - ep for ep in candidate_endpoints - if tracking_usage(ep) < get_max_connections(ep) + # 2️⃣.5 Conversation affinity preference — only honour the hint when + # the affine endpoint still advertises the model loaded *and* has a + # free slot. Otherwise fall back to the standard algorithm. + if affine_ep: + ep_loaded = { + ep: set(models) + for ep, models in zip(candidate_endpoints, loaded_sets) + } + if (affine_ep in candidate_endpoints + and model in ep_loaded.get(affine_ep, set()) + and tracking_usage(affine_ep) < get_max_connections(affine_ep)): + selected = affine_ep + + if selected is None: + # 3️⃣ Endpoints that have the model loaded *and* a free slot + loaded_and_free = [ + ep for ep, models in zip(candidate_endpoints, loaded_sets) + if model in models and tracking_usage(ep) < get_max_connections(ep) ] - if endpoints_with_free_slot: + if loaded_and_free: if config.priority_routing: - endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999)) - endpoints_with_free_slot.sort(key=utilization_ratio) - selected = endpoints_with_free_slot[0] + # WRR: sort by config order first (stable), then by utilization ratio. + # Stable sort preserves priority for equal-ratio endpoints. + loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999)) + loaded_and_free.sort(key=utilization_ratio) + selected = loaded_and_free[0] else: - # Sort by total endpoint load (ascending) to prefer idle endpoints. 
- endpoints_with_free_slot.sort( - key=lambda ep: sum(usage_counts.get(ep, {}).values()) - ) - if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): - selected = random.choice(endpoints_with_free_slot) + # Sort ascending for load balancing — all endpoints here already have the + # model loaded, so there is no model-switching cost to optimise for. + loaded_and_free.sort(key=tracking_usage) + # When all candidates are equally idle, randomise to avoid always picking + # the first entry in a stable sort. + if all(tracking_usage(ep) == 0 for ep in loaded_and_free): + selected = random.choice(loaded_and_free) else: - selected = endpoints_with_free_slot[0] + selected = loaded_and_free[0] else: - # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) - if config.priority_routing: - selected = min( - candidate_endpoints, - key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)), - ) + # 4️⃣ Endpoints among the candidates that simply have a free slot + endpoints_with_free_slot = [ + ep for ep in candidate_endpoints + if tracking_usage(ep) < get_max_connections(ep) + ] + + if endpoints_with_free_slot: + if config.priority_routing: + endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999)) + endpoints_with_free_slot.sort(key=utilization_ratio) + selected = endpoints_with_free_slot[0] + else: + # Sort by total endpoint load (ascending) to prefer idle endpoints. 
+ endpoints_with_free_slot.sort( + key=lambda ep: sum(usage_counts.get(ep, {}).values()) + ) + if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot): + selected = random.choice(endpoints_with_free_slot) + else: + selected = endpoints_with_free_slot[0] else: - selected = min(candidate_endpoints, key=tracking_usage) + # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue) + if config.priority_routing: + selected = min( + candidate_endpoints, + key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)), + ) + else: + selected = min(candidate_endpoints, key=tracking_usage) tracking_model = get_tracking_model(selected, model) snapshot = None @@ -1875,6 +1957,15 @@ async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]: snapshot = _capture_snapshot() if snapshot is not None: await _distribute_snapshot(snapshot) + # Record / refresh affinity *after* releasing usage_lock. + if reserve and config.conversation_affinity and affinity_key: + expires_at = time.monotonic() + config.conversation_affinity_ttl + async with _affinity_lock: + _affinity_map[affinity_key] = (selected, expires_at) + if len(_affinity_map) > _AFFINITY_MAX_ENTRIES: + now = time.monotonic() + for k in [k for k, v in _affinity_map.items() if v[1] < now]: + _affinity_map.pop(k, None) return selected, tracking_model # ------------------------------------------------------------- @@ -1925,7 +2016,8 @@ async def proxy(request: Request): yield _cached return StreamingResponse(_serve_cached_generate(), media_type="application/json") - endpoint, tracking_model = await choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, None, prompt) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) use_openai = is_openai_compatible(endpoint) if use_openai: if ":latest" in model: @@ -2095,7 +2187,8 @@ async def chat_proxy(request: Request): opt = True else: opt = False - endpoint, tracking_model = await 
choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) use_openai = is_openai_compatible(endpoint) if use_openai: if ":latest" in model: @@ -3228,7 +3321,8 @@ async def openai_chat_completions_proxy(request: Request): return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json") # 2. Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. Helpers and API call — done in handler scope so try/except works reliably async def _normalize_images_in_messages(msgs: list) -> list: @@ -3538,7 +3632,8 @@ async def openai_completions_proxy(request: Request): return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json") # 2. Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) + _affinity_key = _conversation_fingerprint(model, None, prompt) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) # 3. 
Async generator that streams completions data and decrements the counter From 4acbaeb29c53a2c516a6b6e5f02b9f09170bb378 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 11:05:34 +0200 Subject: [PATCH 02/22] fix: stopping background task properly on shutdown --- router.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/router.py b/router.py index 34e02da..e007f83 100644 --- a/router.py +++ b/router.py @@ -4123,6 +4123,16 @@ async def startup_event() -> None: @app.on_event("shutdown") async def shutdown_event() -> None: await close_all_sse_queues() + + # Stop background tasks first so they stop touching the DB before we close it. + for t in (token_worker_task, flush_task): + if t is not None: + t.cancel() + try: + await t + except (asyncio.CancelledError, Exception): + pass + await flush_remaining_buffers() await app_state["session"].close() @@ -4142,7 +4152,11 @@ async def shutdown_event() -> None: except Exception as e: print(f"[shutdown] Error closing httpx client {ep}: {e}") - if token_worker_task is not None: - token_worker_task.cancel() - if flush_task is not None: - flush_task.cancel() + # Close the aiosqlite connection last — its worker thread is non-daemon + # and would otherwise keep the interpreter alive after lifespan completes. 
+ if db is not None: + try: + await db.close() + print("[shutdown] Closed token DB connection.") + except Exception as e: + print(f"[shutdown] Error closing DB: {e}") From aa7ec6354a72a5fd18a0fb8800c9d3b684109d39 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 13:38:37 +0200 Subject: [PATCH 03/22] feat: visualization of conversation affinity in dashboard --- config.yaml | 26 ++++++--- doc/configuration.md | 85 +++++++++++++++++++++++++++ doc/monitoring.md | 33 +++++++++++ router.py | 47 +++++++++++++-- static/index.html | 134 +++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 306 insertions(+), 19 deletions(-) diff --git a/config.yaml b/config.yaml index 6fc57e6..2107a3c 100644 --- a/config.yaml +++ b/config.yaml @@ -27,14 +27,24 @@ max_concurrent_connections: 2 # priority_routing: true # Conversation affinity (optional, default: false). -# Routes follow-up requests back to the endpoint that previously served the -# same conversation so the llama.cpp / Ollama prompt cache (KV cache) stays -# warm — first turn does a cold prefill, follow-ups skip it. Soft preference: -# falls back to the standard algorithm when the affine endpoint no longer has -# the model loaded or has no free slot. Conversations are fingerprinted by -# (model, first system + first user turn). -# conversation_affinity: true -# conversation_affinity_ttl: 300 # seconds; matches Ollama's default keep_alive +# Pins a conversation to the endpoint that served its first turn so the +# llama.cpp / Ollama prompt cache (KV cache) stays warm — first turn pays +# the cold prefill, every follow-up turn reuses the same prefix. +# +# Fingerprint = sha1(model + leading system messages + first user turn). +# Same chat → same fingerprint on every follow-up turn → same pin, TTL +# refreshed on each reuse. Soft preference: if the pinned endpoint no +# longer has the model loaded or has no free slot, the standard algorithm +# takes over (no failure, just a cache miss). 
+# +# Heads-up: most chat UIs (Open WebUI, LibreChat, …) fire side requests for +# title / tag / follow-up generation. Those have their own first turn and +# therefore their own pin, so a single visible "chat" may show several dots +# in the dashboard's Affinity column. That is correct — each pin matches a +# real warm KV prefix on the backend. See doc/configuration.md for details. +conversation_affinity: true +conversation_affinity_ttl: 300 # seconds of inactivity before a pin expires; + # bumped on every reuse. Matches Ollama's default keep_alive. # Optional router-level API key that gates router/API/web UI access (leave empty to disable) nomyo-router-api-key: "" diff --git a/doc/configuration.md b/doc/configuration.md index 7d9986a..1addd66 100644 --- a/doc/configuration.md +++ b/doc/configuration.md @@ -166,6 +166,91 @@ With this config the primary handles up to 4 concurrent requests before the seco --- +### `conversation_affinity` + +**Type**: `bool` (optional) + +**Default**: `false` + +**Companion setting**: [`conversation_affinity_ttl`](#conversation_affinity_ttl) + +**Description**: When enabled, the router prefers to send follow-up requests of the same conversation back to the endpoint that already served the first turn. This keeps the backend's prompt cache (the llama.cpp / Ollama **KV cache**) warm: the first user turn pays the cold prefill cost, every later turn reuses the same prefix and only generates new tokens. It is a **soft preference** — when the previously-chosen endpoint is no longer eligible (model unloaded, no free slot), the router falls back to the standard selection algorithm (`priority_routing` or random). + +#### How a conversation is identified + +The router does **not** track session IDs or auth tokens. 
It computes a stable fingerprint per request from: + +``` +SHA1( model + + every leading message with role="system" + + the first message with role="user" ) +``` + +Anything after the first user turn is ignored — those later messages extend the same KV prefix, so they don't change the cache identity. + +**What this means in practice** + +| You send… | Fingerprint behaves like… | +|---|---| +| Turn 2 of the same chat (history grows but first system+user are unchanged) | **Same** as turn 1 → pin is reused and TTL refreshed | +| Turn 1 of a fresh chat | **New** fingerprint → new pin | +| Same first user prompt but a different model | **New** fingerprint (model is part of the hash) | +| Same chat but the client mutates the system prompt between turns (e.g. injects a fresh timestamp) | **New** fingerprint — the affinity will not stick | + +#### TTL and refresh + +Every time `choose_endpoint` returns a pinned endpoint, the entry's expiry is bumped to `now + conversation_affinity_ttl`. An idle conversation drops out of the map once that window elapses without traffic. Default 300 s matches Ollama's default `keep_alive` — once the backend has unloaded the model, the KV cache is gone too, so a stale pin would be pointless anyway. + +#### Why the dashboard may show more than one dot per visible conversation + +The fingerprint is computed per **HTTP request**, not per chat-window. Most chat UIs (Open WebUI in particular) fire several **auxiliary** requests alongside the real conversation: + +- *Title generation* — synthetic system prompt + the user message as content +- *Follow-up question suggestion* — synthetic system prompt + the conversation as content +- *Tag generation*, *memory extraction*, *retrieval query rewriting*, etc. + +Each of those has its own `(system + first user turn)` and therefore its own fingerprint and its own pin in [the affinity dot matrix](monitoring.md#affinity-stats-conversation-affinity). 
They all *correctly* refer to a real warm KV-cache prefix on the backend, so the routing they drive is right — they just don't visually map 1:1 to a user-perceived "conversation." + +#### Example + +```yaml +endpoints: + - http://gpu-primary:11434 + - http://gpu-secondary:11434 + +conversation_affinity: true +conversation_affinity_ttl: 300 +``` + +With this configuration, a chat that starts on `gpu-primary` will keep returning to `gpu-primary` for follow-up turns as long as the model is still loaded there and a slot is free, even if `gpu-secondary` happens to be more idle at that moment. Cold-prefill cost is paid once instead of once per turn. + +#### When to enable + +- ✅ Interactive chat workloads with long histories — the prefill savings on every follow-up turn are substantial. +- ✅ Multi-endpoint deployments where models are loaded on more than one node. +- ❌ Pure one-shot / single-turn workloads (no KV-cache to keep warm). +- ❌ When you specifically want strict load-balancing parity — affinity intentionally biases against perfect balance. + +--- + +### `conversation_affinity_ttl` + +**Type**: `int` (seconds, optional) + +**Default**: `300` + +**Description**: How long a conversation stays pinned to its endpoint after the last request that touched it. Refreshed on every reuse — so an actively-used conversation keeps its pin indefinitely; an abandoned one expires after `conversation_affinity_ttl` seconds of silence. + +**Recommendation**: leave this aligned with the backend's `keep_alive` window. If the model is unloaded by the backend, the KV cache is gone and there is no benefit to keeping the pin. 
+ +**Example**: +```yaml +conversation_affinity: true +conversation_affinity_ttl: 600 # ten minutes of inactivity before un-pinning +``` + +--- + ### `router_api_key` **Type**: `str` (optional) diff --git a/doc/monitoring.md b/doc/monitoring.md index b5bcbff..ab75d25 100644 --- a/doc/monitoring.md +++ b/doc/monitoring.md @@ -166,6 +166,39 @@ curl -X POST http://localhost:12434/api/cache/invalidate Clears all cached entries and resets hit/miss counters. +### Affinity Stats (Conversation Affinity) + +```bash +curl http://localhost:12434/api/affinity_stats +``` + +Response when [`conversation_affinity`](configuration.md#conversation_affinity) is enabled: + +```json +{ + "enabled": true, + "ttl": 300, + "entries": [ + { "endpoint": "http://gpu-primary:11434", "model": "llama3.2:latest", "remaining": 287.4 }, + { "endpoint": "http://gpu-primary:11434", "model": "llama3.2:latest", "remaining": 113.0 }, + { "endpoint": "http://gpu-secondary:11434", "model": "qwen2.5-coder:7b", "remaining": 44.8 } + ] +} +``` + +Response when the feature is disabled: +```json +{ "enabled": false, "ttl": 300, "entries": [] } +``` + +- One element per **live pinned conversation** (no fingerprints or content — just the endpoint/model the pin points to and how many seconds it has left before expiry). +- Aggregation by `(endpoint, model)` is left to the consumer: the dashboard does this client-side. +- The endpoint is gated by the same `nomyo-router-api-key` middleware as the rest of `/api/*`. + +The dashboard's **Running Models (PS) → Affinity** column is rendered from this data. The column auto-hides when `enabled: false`. Each row shows one dot per live pin against that `(endpoint, model)` pair; dot opacity = `remaining / ttl` (floor 0.15), so freshly-routed pins are solid and pins close to expiry fade out. A `+N` overflow badge appears once a single (endpoint, model) holds more than 12 active pins; an em-dash (`—`) marks an `(endpoint, model)` with no live pins. 
+ +> Multiple dots for what looks like "one chat window" is normal — most chat UIs (Open WebUI, LibreChat, …) fire auxiliary requests (title generation, follow-up suggestions, tag extraction) that have their own first-turn fingerprint and therefore their own pin. See [Conversation Affinity → Why the dashboard may show more than one dot per visible conversation](configuration.md#conversation_affinity) for the details. + ### Real-time Usage Stream ```bash diff --git a/router.py b/router.py index e007f83..c075e20 100644 --- a/router.py +++ b/router.py @@ -445,10 +445,12 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict( usage_lock = asyncio.Lock() # protects access to usage_counts token_usage_lock = asyncio.Lock() -# Conversation affinity map: fingerprint -> (endpoint, expires_at_monotonic). +# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic). # Keeps the same conversation pinned to the endpoint that already has its -# KV-cache prefix warm. Never held together with usage_lock. -_affinity_map: Dict[str, tuple[str, float]] = {} +# KV-cache prefix warm. Model is stored so the dashboard can aggregate live +# entries per (endpoint, model) without recomputing fingerprints. +# Never held together with usage_lock. 
+_affinity_map: Dict[str, tuple[str, str, float]] = {} _affinity_lock = asyncio.Lock() _AFFINITY_MAX_ENTRIES = 10000 @@ -1859,7 +1861,7 @@ async def choose_endpoint(model: str, reserve: bool = True, async with _affinity_lock: entry = _affinity_map.get(affinity_key) if entry is not None: - ep, expires_at = entry + ep, _stored_model, expires_at = entry if expires_at < time.monotonic(): _affinity_map.pop(affinity_key, None) else: @@ -1961,10 +1963,10 @@ async def choose_endpoint(model: str, reserve: bool = True, if reserve and config.conversation_affinity and affinity_key: expires_at = time.monotonic() + config.conversation_affinity_ttl async with _affinity_lock: - _affinity_map[affinity_key] = (selected, expires_at) + _affinity_map[affinity_key] = (selected, model, expires_at) if len(_affinity_map) > _AFFINITY_MAX_ENTRIES: now = time.monotonic() - for k in [k for k, v in _affinity_map.items() if v[1] < now]: + for k in [k for k, v in _affinity_map.items() if v[2] < now]: _affinity_map.pop(k, None) return selected, tracking_model @@ -3103,6 +3105,39 @@ async def ps_details_proxy(request: Request): return JSONResponse(content={"models": models}, status_code=200) +# ------------------------------------------------------------- +# 18b. Conversation-affinity stats – feeds the PS-table dot matrix +# ------------------------------------------------------------- +@app.get("/api/affinity_stats") +async def affinity_stats(request: Request): + """ + Aggregate live conversation-affinity pins, one entry per pinned conversation. + Each entry exposes only the endpoint, model, and remaining TTL in seconds — + no fingerprints or content. When conversation_affinity is disabled the + `entries` list is always empty. 
+ """ + if not config.conversation_affinity: + return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []} + + now = time.monotonic() + entries: list[dict] = [] + async with _affinity_lock: + for fp, (ep, mdl, expires_at) in list(_affinity_map.items()): + remaining = expires_at - now + if remaining <= 0: + _affinity_map.pop(fp, None) + continue + entries.append({ + "endpoint": ep, + "model": mdl, + "remaining": round(remaining, 2), + }) + return { + "enabled": True, + "ttl": config.conversation_affinity_ttl, + "entries": entries, + } + # ------------------------------------------------------------- # 19. Proxy usage route – for monitoring # ------------------------------------------------------------- diff --git a/static/index.html b/static/index.html index 419d7bb..b29f22b 100644 --- a/static/index.html +++ b/static/index.html @@ -121,6 +121,45 @@ .ps-subrow + .ps-subrow { margin-top: 2px; } + #ps-table .affinity-col, + #ps-table .affinity-cell { + display: none; + } + #ps-table.affinity-on .affinity-col, + #ps-table.affinity-on .affinity-cell { + display: table-cell; + width: 90px; + text-align: center; + padding-left: 6px; + padding-right: 6px; + } + #ps-table.affinity-on .affinity-dots { + max-width: 78px; + } + .affinity-dots { + display: inline-flex; + flex-wrap: wrap; + gap: 3px; + align-items: center; + line-height: 1; + } + .affinity-dot { + width: 8px; + height: 8px; + border-radius: 50%; + background: #2e7d32; + display: inline-block; + transition: opacity 1s linear; + } + .affinity-overflow { + font-size: 10px; + color: #555; + margin-left: 2px; + } + .affinity-empty { + color: #bbb; + font-size: 11px; + } #ps-table { width: max-content; min-width: 100%; @@ -131,13 +170,13 @@ max-width: 300px; white-space: nowrap; } - /* Optimize narrow columns */ - #ps-table th:nth-child(3), - #ps-table td:nth-child(3), + /* Optimize narrow columns (Params / Quant / Ctx) */ #ps-table th:nth-child(4), #ps-table td:nth-child(4), #ps-table 
th:nth-child(5), - #ps-table td:nth-child(5) { + #ps-table td:nth-child(5), + #ps-table th:nth-child(6), + #ps-table td:nth-child(6) { width: 80px; text-align: center; } @@ -395,6 +434,7 @@ Model Endpoint + Affinity Params Quant Ctx @@ -406,7 +446,7 @@ - Loading… + Loading… @@ -932,6 +972,14 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { return items.map((item) => `
${item || ""}
`).join(""); }; + const escapeAttr = (s) => String(s).replace(/&/g, "&").replace(/"/g, """).replace(//g, ">"); + const renderAffinitySlots = (endpoints, modelName) => { + if (!endpoints.length) return ""; + return endpoints + .map((ep) => `
`) + .join(""); + }; + body.innerHTML = Array.from(grouped.entries()) .map(([modelName, modelInstances]) => { const existingRow = psRows.get(modelName); @@ -955,6 +1003,7 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { return ` ${modelName} stats ${renderInstanceList(endpoints)} + ${renderAffinitySlots(endpoints, modelName)} ${params} ${quant} ${ctx} @@ -972,11 +1021,83 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { const model = row.dataset.model; if (model) psRows.set(model, row); }); + renderAffinityDots(); } catch (e) { console.error(e); } } + /* ---------- Conversation-affinity dots ---------- */ + const AFFINITY_MAX_DOTS = 12; + let affinityIndex = new Map(); // `${endpoint}|${model}` -> array of {expiresAt} + let affinityTtl = 300; + let affinityEnabled = false; + + async function loadAffinity() { + try { + const data = await fetchJSON("/api/affinity_stats"); + affinityEnabled = !!data.enabled; + affinityTtl = Number(data.ttl) || 300; + const now = Date.now() / 1000; + const idx = new Map(); + for (const e of data.entries || []) { + const key = `${e.endpoint}|${e.model}`; + if (!idx.has(key)) idx.set(key, []); + idx.get(key).push({ expiresAt: now + Number(e.remaining) }); + } + affinityIndex = idx; + applyAffinityColumnVisibility(); + renderAffinityDots(); + } catch (err) { + // Endpoint may 404 on older deployments — silently degrade. 
+ affinityEnabled = false; + affinityIndex = new Map(); + applyAffinityColumnVisibility(); + renderAffinityDots(); + } + } + + function applyAffinityColumnVisibility() { + const table = document.getElementById("ps-table"); + if (!table) return; + table.classList.toggle("affinity-on", affinityEnabled); + } + + function renderAffinityDots() { + const spans = document.querySelectorAll(".affinity-dots"); + if (!spans.length) return; + const now = Date.now() / 1000; + spans.forEach((span) => { + const ep = span.dataset.endpoint; + const mdl = span.dataset.model; + const key = `${ep}|${mdl}`; + const pins = (affinityIndex.get(key) || []).filter((p) => p.expiresAt > now); + if (pins.length !== (affinityIndex.get(key) || []).length) { + if (pins.length) affinityIndex.set(key, pins); + else affinityIndex.delete(key); + } + if (!pins.length) { + span.innerHTML = affinityEnabled + ? `` + : ""; + return; + } + // Sort freshest first so visible dots are the most "recent". + pins.sort((a, b) => b.expiresAt - a.expiresAt); + const visible = pins.slice(0, AFFINITY_MAX_DOTS); + const overflow = pins.length - visible.length; + const dotsHtml = visible + .map((p) => { + const remaining = Math.max(0, p.expiresAt - now); + const opacity = Math.max(0.15, Math.min(1, remaining / affinityTtl)); + const secs = Math.round(remaining); + return ``; + }) + .join(""); + span.innerHTML = dotsHtml + (overflow > 0 ? 
`+${overflow}` : ""); + }); + } + /* ---------- Usage Chart (stacked‑percentage) ---------- */ function getColor(seed) { const h = Math.abs(hashString(seed) % 360); @@ -1173,10 +1294,13 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { loadEndpoints(); loadTags(); loadPS(); + loadAffinity(); loadUsage(); initHeaderChart(); setInterval(tickTpsChart, 1000); setInterval(loadPS, 60_000); + setInterval(loadAffinity, 15_000); + setInterval(renderAffinityDots, 2_000); setInterval(loadEndpoints, 300_000); /* show logic */ From ad0be90a70cdcdc512b8ec17c4c83db04eaf15f8 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 14:35:45 +0200 Subject: [PATCH 04/22] fix: model naming for affinity status for llama endpoints --- router.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/router.py b/router.py index c075e20..6656403 100644 --- a/router.py +++ b/router.py @@ -3121,15 +3121,19 @@ async def affinity_stats(request: Request): now = time.monotonic() entries: list[dict] = [] + llama_eps = set(config.llama_server_endpoints) async with _affinity_lock: for fp, (ep, mdl, expires_at) in list(_affinity_map.items()): remaining = expires_at - now if remaining <= 0: _affinity_map.pop(fp, None) continue + # Mirror the normalisation used by /api/ps_details so the dashboard + # can join affinity entries to PS rows by (endpoint, model). 
+ display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl entries.append({ "endpoint": ep, - "model": mdl, + "model": display_model, "remaining": round(remaining, 2), }) return { From 84e3b30f2f948a63e7d363d62bddcb4396bd6491 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 14:59:05 +0200 Subject: [PATCH 05/22] fix: removed dead config key --- router.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/router.py b/router.py index 6656403..111d638 100644 --- a/router.py +++ b/router.py @@ -2,7 +2,7 @@ title: NOMYO Router - an (O)llama and OpenAI API v1 Proxy with Endpoint:Model aware routing author: alpha-nerd-nomyo author_url: https://github.com/nomyo-ai -version: 0.7 +version: 0.9 license: AGPL """ # ------------------------------------------------------------- @@ -256,9 +256,8 @@ class Config(BaseSettings): cache_history_weight: float = Field(default=0.3) class Config: - # Load from `config.yaml` first, then from env variables + # YAML loading is handled manually via Config.from_yaml(); env vars use this prefix. 
env_prefix = "NOMYO_ROUTER_" - yaml_file = Path("config.yaml") # relative to cwd @classmethod def _expand_env_refs(cls, obj): From 6c869aa3052b3b98ca486c42ad4db2dacd240619 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 16:22:48 +0200 Subject: [PATCH 06/22] =?UTF-8?q?fix:=20mitigating=20WARNING=20router.py:1?= =?UTF-8?q?421:9=20[cfg-resource-leak]=20=E2=80=94=20Image.open=20acquires?= =?UTF-8?q?=20file=20handle=20but=20not=20all=20exit=20paths=20release=20i?= =?UTF-8?q?t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- router.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/router.py b/router.py index 111d638..08225fb 100644 --- a/router.py +++ b/router.py @@ -1418,30 +1418,30 @@ def resize_image_if_needed(image_data): pass # Decode the base64 image data image_bytes = base64.b64decode(image_data) - image = Image.open(io.BytesIO(image_bytes)) - if image.mode not in ("RGB", "L"): - image = image.convert("RGB") + with Image.open(io.BytesIO(image_bytes)) as image: + if image.mode not in ("RGB", "L"): + image = image.convert("RGB") - # Get current size - width, height = image.size + # Get current size + width, height = image.size - # Calculate the new dimensions while maintaining aspect ratio - if width > 512 or height > 512: - aspect_ratio = width / height - if aspect_ratio > 1: # Width is larger - new_width = 512 - new_height = int(512 / aspect_ratio) - else: # Height is larger - new_height = 512 - new_width = int(512 * aspect_ratio) + # Calculate the new dimensions while maintaining aspect ratio + if width > 512 or height > 512: + aspect_ratio = width / height + if aspect_ratio > 1: # Width is larger + new_width = 512 + new_height = int(512 / aspect_ratio) + else: # Height is larger + new_height = 512 + new_width = int(512 * aspect_ratio) - image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + image = 
image.resize((new_width, new_height), Image.Resampling.LANCZOS) - # Encode the resized image back to base64 - buffered = io.BytesIO() - image.save(buffered, format="PNG") - resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8") - return resized_image_data + # Encode the resized image back to base64 + buffered = io.BytesIO() + image.save(buffered, format="PNG") + resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8") + return resized_image_data except Exception as e: print(f"Error processing image: {e}") From 85f6f780efa4362b1351ff552be3989dbef5a350 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 17:05:33 +0200 Subject: [PATCH 07/22] feat: nyx triage --- .nyx/triage.json | 84 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 .nyx/triage.json diff --git a/.nyx/triage.json b/.nyx/triage.json new file mode 100644 index 0000000..e549fde --- /dev/null +++ b/.nyx/triage.json @@ -0,0 +1,84 @@ +{ + "version": 1, + "decisions": [ + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": 
"py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", + "state": "false_positive", + "rule_id": "py.auth.token_override_without_validation", + "path": "router.py" + }, + { + "fingerprint": "bf20a694d2f903f3e8cc287ad42f4291558cc769a0dccabea8ecf8c0eaa12eee", + "state": "false_positive", + "rule_id": "state-resource-leak", + "path": "db.py" + }, + { + "fingerprint": "52591ad15f1d27aa6394afbb7d150a8246c3d8032ca705fc39f541e2e71bbf7f", + "state": "false_positive", + "rule_id": "state-resource-leak", + "path": "router.py" + }, + { + "fingerprint": "8f3331f28c2b1839039946ba5eb7fe45b0fa19e2b8d4ecb60c3907ec19e8330b", + "state": "accepted_risk", + "rule_id": "py.crypto.sha1", + "path": "router.py" + } + ], + "suppression_rules": [] +} \ No newline at end of file From 5ce4eed0ad6d139c6bbfc64dfcd12688fbfba1cf Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 17:32:15 +0200 Subject: [PATCH 08/22] chore: deduplicate nyx triage entries --- .nyx/triage.json | 54 ------------------------------------------------ 1 file changed, 54 deletions(-) diff --git 
a/.nyx/triage.json b/.nyx/triage.json index e549fde..86dc21c 100644 --- a/.nyx/triage.json +++ b/.nyx/triage.json @@ -1,60 +1,6 @@ { "version": 1, "decisions": [ - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, - { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" - }, { 
"fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", "state": "false_positive", From e484f122281ed59258b5faf1e923b1917178d465 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Wed, 13 May 2026 19:07:08 +0200 Subject: [PATCH 09/22] fix: triage by suppression_rules for CI --- .nyx/triage.json | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/.nyx/triage.json b/.nyx/triage.json index 86dc21c..2c4dc31 100644 --- a/.nyx/triage.json +++ b/.nyx/triage.json @@ -1,30 +1,24 @@ { "version": 1, - "decisions": [ + "decisions": [], + "suppression_rules": [ { - "fingerprint": "2ef3f80c7bf8282ea4dcf8fa88a2df1aa7fcb5d9009f7a49b4a9eb0379bd592e", - "state": "false_positive", - "rule_id": "py.auth.token_override_without_validation", - "path": "router.py" + "by": "rule", + "value": "py.auth.token_override_without_validation", + "state": "suppressed", + "note": "false_positive: token validation handled upstream by middleware" }, { - "fingerprint": "bf20a694d2f903f3e8cc287ad42f4291558cc769a0dccabea8ecf8c0eaa12eee", - "state": "false_positive", - "rule_id": "state-resource-leak", - "path": "db.py" + "by": "rule", + "value": "state-resource-leak", + "state": "suppressed", + "note": "false_positive: resource lifecycle managed externally" }, { - "fingerprint": "52591ad15f1d27aa6394afbb7d150a8246c3d8032ca705fc39f541e2e71bbf7f", - "state": "false_positive", - "rule_id": "state-resource-leak", - "path": "router.py" - }, - { - "fingerprint": "8f3331f28c2b1839039946ba5eb7fe45b0fa19e2b8d4ecb60c3907ec19e8330b", - "state": "accepted_risk", - "rule_id": "py.crypto.sha1", - "path": "router.py" + "by": "rule", + "value": "py.crypto.sha1", + "state": "suppressed", + "note": "accepted_risk: used for non-security checksum only" } - ], - "suppression_rules": [] + ] } \ No newline at end of file From e7b2c7e33339898f30712b3b12ebbbda1455fbea Mon Sep 17 00:00:00 2001 From: Renovate Bot Date: Fri, 15 May 2026 07:17:06 +0000 
Subject: [PATCH 10/22] chore(deps): update dependency idna to v3.15 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 159c062..f9ef22e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ frozenlist==1.8.0 h11==0.16.0 httpcore==1.0.9 httpx==0.28.1 -idna==3.14 +idna==3.15 jiter==0.14.0 multidict==6.7.1 ollama==0.6.2 From 5bd7d4d04de0d1ffcb69e54d7744af77bb9664e3 Mon Sep 17 00:00:00 2001 From: alpha-nerd Date: Fri, 15 May 2026 09:23:59 +0200 Subject: [PATCH 11/22] .forgejo/workflows/docker-publish-semantic.yml aktualisiert --- .forgejo/workflows/docker-publish-semantic.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.forgejo/workflows/docker-publish-semantic.yml b/.forgejo/workflows/docker-publish-semantic.yml index 2fa59d5..ebbfd36 100644 --- a/.forgejo/workflows/docker-publish-semantic.yml +++ b/.forgejo/workflows/docker-publish-semantic.yml @@ -87,8 +87,6 @@ jobs: build-args: | SEMANTIC_CACHE=true tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }} - cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }} - cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }},mode=max merge: runs-on: docker-amd64 From e0d78012207ebb86d0e6456101efd1959021cac0 Mon Sep 17 00:00:00 2001 From: Renovate Bot Date: Fri, 15 May 2026 07:52:17 +0000 Subject: [PATCH 12/22] chore(deps): update dependency tiktoken to v0.13.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f9ef22e..df0f020 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ PyYAML==6.0.3 sniffio==1.3.1 starlette==0.52.1 truststore==0.10.4 -tiktoken==0.12.0 +tiktoken==0.13.0 tqdm==4.67.3 typing-inspection==0.4.2 typing_extensions==4.15.0 From 5b938302d6eab98c7bde909a6fea1c4e137b46b3 Mon 
Sep 17 00:00:00 2001 From: Renovate Bot Date: Fri, 15 May 2026 08:17:23 +0000 Subject: [PATCH 13/22] chore(deps): update dependency uvicorn to v0.47.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index df0f020..c187485 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,7 +36,7 @@ tiktoken==0.13.0 tqdm==4.67.3 typing-inspection==0.4.2 typing_extensions==4.15.0 -uvicorn==0.46.0 +uvicorn==0.47.0 uvloop yarl==1.23.0 aiosqlite From 29ee3600822e88a10474a49c7089233ab480ef40 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Fri, 15 May 2026 16:43:12 +0200 Subject: [PATCH 14/22] feat: adding automated tests --- .forgejo/workflows/pr-tests.yml | 39 ++++ .gitignore | 5 +- test/config_test.yaml | 15 ++ test/conftest.py | 226 +++++++++++++++++++++ test/pytest.ini | 7 + test/requirements_test.txt | 4 + test/test.md | 60 ++++++ test/test_api_integration.py | 304 ++++++++++++++++++++++++++++ test/test_api_validation.py | 230 +++++++++++++++++++++ test/test_cache.py | 329 ++++++++++++++++++++++++++++++ test/test_choose_endpoint.py | 345 ++++++++++++++++++++++++++++++++ test/test_db.py | 197 ++++++++++++++++++ test/test_fetch.py | 180 +++++++++++++++++ test/test_openai_proxies.py | 181 +++++++++++++++++ test/test_unit_context.py | 116 +++++++++++ test/test_unit_helpers.py | 279 ++++++++++++++++++++++++++ test/test_unit_rechunk.py | 173 ++++++++++++++++ test/test_unit_transforms.py | 200 ++++++++++++++++++ 18 files changed, 2886 insertions(+), 4 deletions(-) create mode 100644 .forgejo/workflows/pr-tests.yml create mode 100644 test/config_test.yaml create mode 100644 test/conftest.py create mode 100644 test/pytest.ini create mode 100644 test/requirements_test.txt create mode 100644 test/test.md create mode 100644 test/test_api_integration.py create mode 100644 test/test_api_validation.py create mode 100644 test/test_cache.py create mode 100644 test/test_choose_endpoint.py create mode 100644 
test/test_db.py create mode 100644 test/test_fetch.py create mode 100644 test/test_openai_proxies.py create mode 100644 test/test_unit_context.py create mode 100644 test/test_unit_helpers.py create mode 100644 test/test_unit_rechunk.py create mode 100644 test/test_unit_transforms.py diff --git a/.forgejo/workflows/pr-tests.yml b/.forgejo/workflows/pr-tests.yml new file mode 100644 index 0000000..aa96b84 --- /dev/null +++ b/.forgejo/workflows/pr-tests.yml @@ -0,0 +1,39 @@ +name: PR Tests +on: [pull_request] +jobs: + test: + runs-on: docker-arm64 + container: + image: python:3.12-slim + env: + CMAKE_BUILD_PARALLEL_LEVEL: "4" + steps: + - name: Install system deps + run: | + apt-get update + apt-get install -y --no-install-recommends \ + git ca-certificates \ + build-essential pkg-config + rm -rf /var/lib/apt/lists/* + - name: Checkout + run: | + git config --global --add safe.directory "$PWD" + git clone --depth=1 \ + "https://oauth2:${{ github.token }}@bitfreedom.net/code/${{ github.repository }}.git" . 
+ git fetch --depth=1 origin "+${{ github.event.pull_request.head.sha }}:pr" + git checkout pr + - name: Fetch action source + run: | + git clone --depth=1 --branch master \ + "https://oauth2:${{ github.token }}@bitfreedom.net/code/nomyo-ai/actions.git" \ + ./.run-tests + - uses: ./.run-tests/run-tests + with: + setup: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r test/requirements_test.txt + command: pytest test/ -m "not integration" --cov=router --cov=cache --cov=db --cov=enhance --cov-fail-under=45 --cov-report=term-missing --cov-report=xml --junitxml=report.xml + artifacts-path: | + report.xml + coverage.xml diff --git a/.gitignore b/.gitignore index cfce37c..7cd8431 100644 --- a/.gitignore +++ b/.gitignore @@ -66,7 +66,4 @@ config.yaml # SQLite *.db* -*settings.json - -# Test suite (local only, not committed yet) -test/ \ No newline at end of file +*settings.json \ No newline at end of file diff --git a/test/config_test.yaml b/test/config_test.yaml new file mode 100644 index 0000000..f05cce7 --- /dev/null +++ b/test/config_test.yaml @@ -0,0 +1,15 @@ +endpoints: + - http://192.168.0.51:12434 + +llama_server_endpoints: + - http://192.168.0.51:12434/v1 + +max_concurrent_connections: 2 + +api_keys: + "http://192.168.0.51:12434": "ollama" + "http://192.168.0.51:12434/v1": "llama" + +db_path: "/tmp/nomyo_test_tokens.db" + +cache_enabled: false diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..c95fa2d --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,226 @@ +""" +Test configuration for nomyo-router. 
+ +Run from project root: + pytest test/ -v + pytest test/ -m "not integration" # skip real-server tests + pytest test/ -m integration -v # only real-server tests + +Environment variables: + NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.50:12434) + NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.50:12434/v1) + NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset) + NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset) +""" +import asyncio +import os +import ssl +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import aiohttp +import httpx +import pytest + +_TEST_DIR = Path(__file__).parent +# Must be set before importing router so the module-level Config.from_yaml uses test config +os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml")) + +sys.path.insert(0, str(_TEST_DIR.parent)) + +import router # noqa: E402 + +TEST_OLLAMA = os.getenv("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434") +TEST_LLAMA = os.getenv("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1") + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + "integration: tests that require a real backend at 192.168.0.50:12434", + ) + + +# ── Config mocks ───────────────────────────────────────────────────────────── + +@pytest.fixture +def mock_config(): + """Minimal config pointing at TEST_OLLAMA / TEST_LLAMA.""" + cfg = MagicMock() + cfg.endpoints = [TEST_OLLAMA] + cfg.llama_server_endpoints = [TEST_LLAMA] + cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"} + cfg.max_concurrent_connections = 2 + cfg.router_api_key = None + cfg.cache_enabled = False + return cfg + + +@pytest.fixture +def mock_config_no_llama(): + """Config with Ollama only, no llama-server.""" + cfg = MagicMock() + cfg.endpoints = [TEST_OLLAMA] + cfg.llama_server_endpoints = [] + cfg.api_keys = {TEST_OLLAMA: "ollama"} + cfg.max_concurrent_connections = 2 + cfg.router_api_key = None + 
cfg.cache_enabled = False + return cfg + + +@pytest.fixture +def mock_config_with_key(): + """Config with router_api_key set (enables auth middleware).""" + cfg = MagicMock() + cfg.endpoints = [TEST_OLLAMA] + cfg.llama_server_endpoints = [] + cfg.api_keys = {} + cfg.max_concurrent_connections = 2 + cfg.router_api_key = "test-secret-key" + cfg.cache_enabled = False + return cfg + + +# ── aiohttp session (used by fetch tests + choose_endpoint tests) ───────────── + +@pytest.fixture +async def aio_session(): + """Real aiohttp session stored in app_state; intercepted by aioresponses.""" + ssl_ctx = ssl.create_default_context() + conn = aiohttp.TCPConnector(ssl=ssl_ctx) + session = aiohttp.ClientSession(connector=conn) + router.app_state["session"] = session + + # Clear caches to prevent test bleed + router._models_cache.clear() + router._loaded_models_cache.clear() + router._available_error_cache.clear() + router._loaded_error_cache.clear() + router._inflight_available_models.clear() + router._inflight_loaded_models.clear() + router._bg_refresh_available.clear() + router._bg_refresh_loaded.clear() + + yield session + + await session.close() + router.app_state["session"] = None + + +# ── Validation-only HTTP client (no real backend needed) ────────────────────── + +@pytest.fixture +async def client(mock_config, tmp_path): + """httpx client for validation/auth tests — no real backend calls made.""" + from db import TokenDatabase + + ssl_ctx = ssl.create_default_context() + conn = aiohttp.TCPConnector(ssl=ssl_ctx) + session = aiohttp.ClientSession(connector=conn) + + db_inst = TokenDatabase(str(tmp_path / "test.db")) + await db_inst.init_db() + + old_session = router.app_state.get("session") + old_db = router.db + + router.app_state["session"] = session + router.db = db_inst + + with patch.object(router, "config", mock_config): + transport = httpx.ASGITransport(app=router.app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test", timeout=10.0 + ) 
as c: + yield c + + await session.close() + router.app_state["session"] = old_session + router.db = old_db + + +@pytest.fixture +async def client_auth(mock_config_with_key, tmp_path): + """httpx client with router_api_key configured (for auth middleware tests).""" + from db import TokenDatabase + + ssl_ctx = ssl.create_default_context() + conn = aiohttp.TCPConnector(ssl=ssl_ctx) + session = aiohttp.ClientSession(connector=conn) + + db_inst = TokenDatabase(str(tmp_path / "test_auth.db")) + await db_inst.init_db() + + old_session = router.app_state.get("session") + old_db = router.db + + router.app_state["session"] = session + router.db = db_inst + + with patch.object(router, "config", mock_config_with_key): + transport = httpx.ASGITransport(app=router.app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test", timeout=10.0 + ) as c: + yield c + + await session.close() + router.app_state["session"] = old_session + router.db = old_db + + +# ── Integration client (full startup with real backend) ────────────────────── + +@pytest.fixture(scope="module") +async def integration_client(): + """Full app startup pointing at the real test server.""" + await router.startup_event() + transport = httpx.ASGITransport(app=router.app) + async with httpx.AsyncClient( + transport=transport, + base_url="http://test", + timeout=httpx.Timeout(60.0), + ) as c: + yield c + await router.shutdown_event() + + +# ── Model discovery fixtures ────────────────────────────────────────────────── + +@pytest.fixture(scope="module") +async def chat_model(integration_client): + """Return a chat/generation model name available on the test server.""" + env_model = os.getenv("NOMYO_TEST_MODEL_CHAT") + if env_model: + return env_model + resp = await integration_client.get("/api/tags") + if resp.status_code != 200: + pytest.skip("Cannot reach test server") + models = resp.json().get("models", []) + # Prefer small models for faster tests + for m in models: + name = m.get("name", "") 
+ if any(x in name.lower() for x in ["0.5b", "1b", "3b", "1.5b", "2b"]): + return name + if models: + return models[0]["name"] + pytest.skip("No chat models available on test server") + + +@pytest.fixture(scope="module") +async def embed_model(integration_client): + """Return an embedding model name available on the test server.""" + env_model = os.getenv("NOMYO_TEST_EMBED_MODEL") + if env_model: + return env_model + resp = await integration_client.get("/api/tags") + if resp.status_code != 200: + pytest.skip("Cannot reach test server") + models = resp.json().get("models", []) + for m in models: + name = m.get("name", "") + if any(x in name.lower() for x in ["embed", "nomic", "minilm", "bge", "e5"]): + return name + pytest.skip("No embedding model available on test server") diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 0000000..1d05e6d --- /dev/null +++ b/test/pytest.ini @@ -0,0 +1,7 @@ +[pytest] +asyncio_mode = auto +markers = + integration: tests that require a real backend at 192.168.0.51:12434 +testpaths = . 
+filterwarnings = + ignore::pytest.PytestUnhandledThreadExceptionWarning diff --git a/test/requirements_test.txt b/test/requirements_test.txt new file mode 100644 index 0000000..8c7c53f --- /dev/null +++ b/test/requirements_test.txt @@ -0,0 +1,4 @@ +pytest>=8.0 +pytest-asyncio>=0.24 +pytest-cov>=5.0 +aioresponses>=0.7 diff --git a/test/test.md b/test/test.md new file mode 100644 index 0000000..533a3ee --- /dev/null +++ b/test/test.md @@ -0,0 +1,60 @@ +# Testing nomyo-router + +## Setup + +Install test dependencies (from the project root): + +```bash +pip install -r test/requirements_test.txt +``` + +## Running tests + +All commands run from the `test/` directory: + +```bash +cd test +``` + +**All non-integration tests** (no backend required): +```bash +pytest -m "not integration" -v +``` + +**Integration tests only** (requires backend at `192.168.0.51:12434`): +```bash +pytest -m integration -v +``` + +**Everything:** +```bash +pytest -v +``` + +## Test structure + +| File | What it covers | Backend needed | +|---|---|---| +| `test_unit_helpers.py` | Pure helper functions (`_mask_secrets`, `_is_fresh`, `ep2base`, etc.) | No | +| `test_unit_transforms.py` | Message transform functions (tool calls, image stripping, etc.) | No | +| `test_unit_context.py` | Context window trimming logic | No | +| `test_fetch.py` | `fetch.available_models` / `fetch.loaded_models` with mocked HTTP | No | +| `test_choose_endpoint.py` | `choose_endpoint` routing logic with mocked fetch layer | No | +| `test_api_validation.py` | HTTP 400/401/403 validation and auth middleware (in-process app) | No | +| `test_api_integration.py` | Full request/response against a real Ollama/llama-server backend | **Yes** | + +## Integration test backend + +Integration tests start the router in-process via `startup_event()` and route traffic +through `httpx.ASGITransport` — no separately running router instance is needed. + +They do require a reachable Ollama or llama-server backend. 
Override the defaults via +environment variables: + +```bash +export NOMYO_TEST_OLLAMA=http://192.168.0.51:12434 +export NOMYO_TEST_EMBED_MODEL=nomic-embed-text # optional, auto-discovered otherwise +export NOMYO_TEST_MODEL_CHAT=llama3.2 # optional, auto-discovered otherwise +``` + +If the backend is unreachable, integration tests are automatically skipped. diff --git a/test/test_api_integration.py b/test/test_api_integration.py new file mode 100644 index 0000000..6c40fdc --- /dev/null +++ b/test/test_api_integration.py @@ -0,0 +1,304 @@ +""" +Integration tests against the real backend at 192.168.0.50:12434. + +Run with: + pytest test/test_api_integration.py -v -m integration + +All tests in this file are marked @pytest.mark.integration. +They require the test server to be reachable and to have at least one +chat model and one embedding model available. + +Env vars to pin specific models: + NOMYO_TEST_MODEL_CHAT e.g. qwen2.5:1.5b + NOMYO_TEST_EMBED_MODEL e.g. nomic-embed-text:latest +""" +import json + +import pytest + + +pytestmark = pytest.mark.integration + + +# ── Health / discovery routes ───────────────────────────────────────────────── + +class TestDiscoveryRoutes: + async def test_version(self, integration_client): + resp = await integration_client.get("/api/version") + assert resp.status_code == 200 + data = resp.json() + assert "version" in data + assert isinstance(data["version"], str) + + async def test_tags_returns_models(self, integration_client): + resp = await integration_client.get("/api/tags") + assert resp.status_code == 200 + data = resp.json() + assert "models" in data + assert isinstance(data["models"], list) + assert len(data["models"]) > 0 + + async def test_ps_returns_list(self, integration_client): + resp = await integration_client.get("/api/ps") + assert resp.status_code == 200 + data = resp.json() + assert "models" in data + assert isinstance(data["models"], list) + + async def test_v1_models_returns_data(self, integration_client): + 
resp = await integration_client.get("/v1/models") + assert resp.status_code == 200 + data = resp.json() + assert "data" in data + assert isinstance(data["data"], list) + + async def test_usage_returns_counts(self, integration_client): + resp = await integration_client.get("/api/usage") + assert resp.status_code == 200 + data = resp.json() + assert "usage_counts" in data + assert "token_usage_counts" in data + + async def test_config_returns_endpoints(self, integration_client): + resp = await integration_client.get("/api/config") + assert resp.status_code == 200 + data = resp.json() + assert "endpoints" in data + + async def test_hostname(self, integration_client): + resp = await integration_client.get("/api/hostname") + assert resp.status_code == 200 + assert "hostname" in resp.json() + + async def test_health(self, integration_client): + resp = await integration_client.get("/health") + assert resp.status_code in (200, 503) + data = resp.json() + assert data["status"] in ("ok", "error") + assert "endpoints" in data + + async def test_cache_stats(self, integration_client): + resp = await integration_client.get("/api/cache/stats") + assert resp.status_code == 200 + data = resp.json() + assert "enabled" in data + + +# ── /api/chat ───────────────────────────────────────────────────────────────── + +class TestApiChat: + async def test_non_streaming(self, integration_client, chat_model): + resp = await integration_client.post( + "/api/chat", + json={ + "model": chat_model, + "stream": False, + "messages": [{"role": "user", "content": "Reply with exactly: OK"}], + "options": {"num_predict": 10}, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "message" in data + assert "content" in data["message"] + + async def test_streaming_ndjson(self, integration_client, chat_model): + resp = await integration_client.post( + "/api/chat", + json={ + "model": chat_model, + "stream": True, + "messages": [{"role": "user", "content": "Say hi"}], + "options": 
{"num_predict": 5}, + }, + ) + assert resp.status_code == 200 + lines = [l for l in resp.text.strip().split("\n") if l.strip()] + assert len(lines) >= 1 + for line in lines: + obj = json.loads(line) + assert "model" in obj + + async def test_non_streaming_has_token_counts(self, integration_client, chat_model): + resp = await integration_client.post( + "/api/chat", + json={ + "model": chat_model, + "stream": False, + "messages": [{"role": "user", "content": "Count to 3"}], + "options": {"num_predict": 20}, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data.get("done") is True + # Token counts are expected in the final chunk, but a missing prompt_eval_count is tolerated (defaults to 0) + assert data.get("prompt_eval_count", 0) >= 0 + + async def test_system_message_honoured(self, integration_client, chat_model): + resp = await integration_client.post( + "/api/chat", + json={ + "model": chat_model, + "stream": False, + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Always reply with exactly: PONG"}, + {"role": "user", "content": "PING"}, + ], + "options": {"num_predict": 10}, + }, + ) + assert resp.status_code == 200 + content = resp.json()["message"]["content"] + assert isinstance(content, str) + assert len(content) > 0 + + +# ── /api/generate ───────────────────────────────────────────────────────────── + +class TestApiGenerate: + async def test_non_streaming(self, integration_client, chat_model): + resp = await integration_client.post( + "/api/generate", + json={ + "model": chat_model, + "prompt": "Complete: The sky is", + "stream": False, + "options": {"num_predict": 5}, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "response" in data + + async def test_streaming(self, integration_client, chat_model): + resp = await integration_client.post( + "/api/generate", + json={ + "model": chat_model, + "prompt": "One plus one equals", + "stream": True, + "options": {"num_predict": 5}, + }, + ) + assert resp.status_code == 200 + lines = [l for l in 
resp.text.strip().split("\n") if l.strip()] + assert len(lines) >= 1 + + +# ── /api/embed ──────────────────────────────────────────────────────────────── + +class TestApiEmbed: + async def test_embed_single_string(self, integration_client, embed_model): + resp = await integration_client.post( + "/api/embed", + json={"model": embed_model, "input": "The quick brown fox"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "embeddings" in data + assert isinstance(data["embeddings"], list) + assert len(data["embeddings"]) == 1 + assert len(data["embeddings"][0]) > 0 + + async def test_embed_multiple_inputs(self, integration_client, embed_model): + resp = await integration_client.post( + "/api/embed", + json={"model": embed_model, "input": ["sentence one", "sentence two"]}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "embeddings" in data + assert len(data["embeddings"]) == 2 + + +# ── /v1/chat/completions ────────────────────────────────────────────────────── + +class TestOpenAIChatCompletions: + async def test_non_streaming(self, integration_client, chat_model): + model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model + resp = await integration_client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": [{"role": "user", "content": "Reply OK"}], + "max_tokens": 10, + "stream": False, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert "choices" in data + assert len(data["choices"]) > 0 + assert "message" in data["choices"][0] + + async def test_streaming_sse(self, integration_client, chat_model): + model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model + resp = await integration_client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": [{"role": "user", "content": "Hi"}], + "max_tokens": 5, + "stream": True, + }, + ) + assert resp.status_code == 200 + # Response should be SSE format + assert "data:" in 
resp.text or "[DONE]" in resp.text + + async def test_non_streaming_has_usage(self, integration_client, chat_model): + model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model + resp = await integration_client.post( + "/v1/chat/completions", + json={ + "model": model, + "messages": [{"role": "user", "content": "Say yes"}], + "max_tokens": 5, + "stream": False, + }, + ) + assert resp.status_code == 200 + data = resp.json() + if "usage" in data and data["usage"]: + assert data["usage"].get("prompt_tokens", 0) >= 0 + + +# ── /v1/embeddings ──────────────────────────────────────────────────────────── + +class TestOpenAIEmbeddings: + async def test_single_input(self, integration_client, embed_model): + model = embed_model.replace(":latest", "") if ":latest" in embed_model else embed_model + resp = await integration_client.post( + "/v1/embeddings", + json={"model": model, "input": "Test sentence"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "data" in data + assert len(data["data"]) > 0 + embedding = data["data"][0].get("embedding") + assert isinstance(embedding, list) + assert len(embedding) > 0 + + +# ── Token counts (database-backed) ─────────────────────────────────────────── + +class TestTokenCounts: + async def test_token_counts_endpoint(self, integration_client): + resp = await integration_client.get("/api/token_counts") + assert resp.status_code == 200 + data = resp.json() + assert "total_tokens" in data + assert "breakdown" in data + + +# ── ps_details (extended ps) ───────────────────────────────────────────────── + +class TestPsDetails: + async def test_ps_details_returns_models(self, integration_client): + resp = await integration_client.get("/api/ps_details") + assert resp.status_code == 200 + data = resp.json() + assert "models" in data + assert isinstance(data["models"], list) diff --git a/test/test_api_validation.py b/test/test_api_validation.py new file mode 100644 index 0000000..5d2b52d --- /dev/null 
+++ b/test/test_api_validation.py @@ -0,0 +1,230 @@ +""" +HTTP-level validation and auth middleware tests. + +These tests use an in-process httpx client and never reach a real backend: +all requests are rejected at the validation or auth layer before any +endpoint-selection or upstream HTTP calls occur. +""" +import pytest + + +class TestChatValidation: + async def test_missing_model_returns_400(self, client): + resp = await client.post( + "/api/chat", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert resp.status_code == 400 + assert "model" in resp.json()["detail"].lower() + + async def test_missing_messages_returns_400(self, client): + resp = await client.post("/api/chat", json={"model": "llama3.2"}) + assert resp.status_code == 400 + + async def test_invalid_json_returns_400(self, client): + resp = await client.post( + "/api/chat", + content=b"not-json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + async def test_messages_not_list_returns_400(self, client): + resp = await client.post( + "/api/chat", + json={"model": "m", "messages": "not-a-list"}, + ) + assert resp.status_code == 400 + + async def test_options_not_dict_returns_400(self, client): + resp = await client.post( + "/api/chat", + json={"model": "m", "messages": [{"role": "user", "content": "hi"}], "options": "bad"}, + ) + assert resp.status_code == 400 + + +class TestGenerateValidation: + async def test_missing_model_returns_400(self, client): + resp = await client.post("/api/generate", json={"prompt": "hello"}) + assert resp.status_code == 400 + assert "model" in resp.json()["detail"].lower() + + async def test_missing_prompt_returns_400(self, client): + resp = await client.post("/api/generate", json={"model": "m"}) + assert resp.status_code == 400 + assert "prompt" in resp.json()["detail"].lower() + + async def test_invalid_json_returns_400(self, client): + resp = await client.post( + "/api/generate", + content=b"{bad-json", + 
headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +class TestEmbedValidation: + async def test_missing_model_returns_400(self, client): + resp = await client.post("/api/embed", json={"input": "hello"}) + assert resp.status_code == 400 + + async def test_missing_input_returns_400(self, client): + resp = await client.post("/api/embed", json={"model": "nomic-embed-text"}) + assert resp.status_code == 400 + + +class TestEmbeddingsValidation: + async def test_missing_model_returns_400(self, client): + resp = await client.post("/api/embeddings", json={"prompt": "hello"}) + assert resp.status_code == 400 + + async def test_missing_prompt_returns_400(self, client): + resp = await client.post("/api/embeddings", json={"model": "nomic-embed-text"}) + assert resp.status_code == 400 + + +class TestOpenAIChatValidation: + async def test_missing_model_returns_400(self, client): + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert resp.status_code == 400 + + async def test_missing_messages_returns_400(self, client): + resp = await client.post( + "/v1/chat/completions", + json={"model": "gpt-4o"}, + ) + assert resp.status_code == 400 + + async def test_invalid_json_returns_400(self, client): + resp = await client.post( + "/v1/chat/completions", + content=b"}{", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + async def test_svg_image_rejected(self, client): + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "vision-model", + "messages": [{ + "role": "user", + "content": [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "data:image/svg+xml;base64,abc"}}, + ], + }], + }, + ) + assert resp.status_code == 400 + assert "svg" in resp.json()["detail"].lower() + + +class TestOpenAICompletionsValidation: + async def test_missing_model_returns_400(self, client): + resp = await 
client.post("/v1/completions", json={"prompt": "hello"}) + assert resp.status_code == 400 + + async def test_missing_prompt_returns_400(self, client): + resp = await client.post("/v1/completions", json={"model": "m"}) + assert resp.status_code == 400 + + +class TestRerankValidation: + async def test_missing_model_returns_400(self, client): + resp = await client.post( + "/v1/rerank", + json={"query": "search query", "documents": ["doc1"]}, + ) + assert resp.status_code == 400 + + async def test_missing_query_returns_400(self, client): + resp = await client.post( + "/v1/rerank", + json={"model": "reranker", "documents": ["doc1"]}, + ) + assert resp.status_code == 400 + + async def test_empty_documents_returns_400(self, client): + resp = await client.post( + "/v1/rerank", + json={"model": "reranker", "query": "search", "documents": []}, + ) + assert resp.status_code == 400 + + +class TestShowValidation: + async def test_missing_model_returns_400(self, client): + resp = await client.post("/api/show", json={}) + assert resp.status_code == 400 + + +class TestCopyValidation: + async def test_missing_source_returns_400(self, client): + resp = await client.post("/api/copy", json={"destination": "dst"}) + assert resp.status_code == 400 + + async def test_missing_destination_returns_400(self, client): + resp = await client.post("/api/copy", json={"source": "src"}) + assert resp.status_code == 400 + + +class TestDeleteValidation: + async def test_missing_model_returns_400(self, client): + import json as _json + resp = await client.request( + "DELETE", + "/api/delete", + content=_json.dumps({}).encode(), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 400 + + +class TestAuthMiddleware: + async def test_no_key_returns_401(self, client_auth): + resp = await client_auth.post( + "/api/chat", + json={"model": "m", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status_code == 401 + assert "Missing" in resp.json()["detail"] + + 
async def test_invalid_key_returns_403(self, client_auth): + resp = await client_auth.post( + "/api/chat", + headers={"Authorization": "Bearer wrong-key"}, + json={"model": "m", "messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status_code == 403 + assert "Invalid" in resp.json()["detail"] + + async def test_valid_key_passes_middleware(self, client_auth): + # /api/usage reads in-memory counters only — no backend call needed + resp = await client_auth.get( + "/api/usage", + headers={"Authorization": "Bearer test-secret-key"}, + ) + assert resp.status_code == 200 + + async def test_key_via_query_param(self, client_auth): + resp = await client_auth.get("/api/usage?api_key=test-secret-key") + assert resp.status_code == 200 + + async def test_options_bypasses_auth(self, client_auth): + resp = await client_auth.options("/api/chat") + assert resp.status_code not in (401, 403) + + async def test_root_path_bypasses_auth(self, client_auth): + resp = await client_auth.get("/") + assert resp.status_code not in (401, 403) + + async def test_favicon_bypasses_auth(self, client_auth): + resp = await client_auth.get("/favicon.ico") + # Should not be blocked by auth (may 404 in test but not 401/403) + assert resp.status_code not in (401, 403) diff --git a/test/test_cache.py b/test/test_cache.py new file mode 100644 index 0000000..cd37688 --- /dev/null +++ b/test/test_cache.py @@ -0,0 +1,329 @@ +"""Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed).""" +from types import SimpleNamespace + +import orjson +import pytest + +import cache as cache_mod +from cache import ( + LLMCache, + _bm25_weighted_text, + get_llm_cache, + init_llm_cache, + openai_nonstream_to_sse, +) + + +def _exact_cfg(backend: str = "memory") -> SimpleNamespace: + """Config for exact-match mode — similarity=1.0 avoids embedding deps.""" + return SimpleNamespace( + cache_enabled=True, + cache_backend=backend, + cache_similarity=1.0, + cache_history_weight=0.3, + 
cache_ttl=300, + cache_db_path="/tmp/test_cache.db", + cache_redis_url="redis://localhost:6379", + ) + + +# ────────────────────────────────────────────────────────────────────────────── +# Pure helpers +# ────────────────────────────────────────────────────────────────────────────── + +class TestBM25WeightedText: + def test_empty_history(self): + assert _bm25_weighted_text([]) == "" + + def test_history_without_content(self): + assert _bm25_weighted_text([{"role": "user"}, {"role": "assistant"}]) == "" + + def test_repeats_high_idf_terms(self): + history = [ + {"role": "user", "content": "Tell me about quantum entanglement"}, + {"role": "assistant", "content": "Quantum entanglement is a phenomenon"}, + {"role": "user", "content": "How does entanglement work?"}, + ] + out = _bm25_weighted_text(history) + # Rare/domain term ("entanglement") should appear; short stopwords (<=2 chars) dropped + assert "entanglement" in out + assert "is" not in out.split() + + +# ────────────────────────────────────────────────────────────────────────────── +# openai_nonstream_to_sse +# ────────────────────────────────────────────────────────────────────────────── + +class TestOpenAINonstreamToSSE: + def test_valid_chat_completion(self): + chat = { + "id": "x1", + "created": 123, + "model": "gpt-4o", + "choices": [{"message": {"role": "assistant", "content": "hello"}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + } + out = openai_nonstream_to_sse(orjson.dumps(chat), "gpt-4o") + text = out.decode() + assert text.startswith("data: ") + assert text.endswith("data: [DONE]\n\n") + # First chunk contains the original content + first = text.split("\n\n")[0][len("data: "):] + parsed = orjson.loads(first) + assert parsed["choices"][0]["delta"]["content"] == "hello" + assert parsed["usage"]["total_tokens"] == 3 + + def test_corrupt_bytes_return_done_only(self): + out = openai_nonstream_to_sse(b"not-json", "m") + assert out == b"data: [DONE]\n\n" + + +# 
────────────────────────────────────────────────────────────────────────────── +# LLMCache internal helpers +# ────────────────────────────────────────────────────────────────────────────── + +class TestLLMCacheParsing: + def test_namespace_is_stable_and_isolated(self): + c = LLMCache(_exact_cfg()) + a = c._namespace("chat", "m1", "system A") + b = c._namespace("chat", "m1", "system A") + assert a == b + assert c._namespace("chat", "m1", "system B") != a + assert c._namespace("generate", "m1", "system A") != a + assert len(a) == 16 + + def test_parse_messages_flat_strings(self): + c = LLMCache(_exact_cfg()) + sys, hist, last = c._parse_messages([ + {"role": "system", "content": "be helpful"}, + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "what is 2+2?"}, + ]) + assert sys == "be helpful" + assert last == "what is 2+2?" + assert hist == [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + + def test_parse_messages_multimodal_content(self): + c = LLMCache(_exact_cfg()) + sys, _hist, last = c._parse_messages([ + {"role": "system", "content": "sys"}, + {"role": "user", "content": [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "data:..."}}, + ]}, + ]) + assert sys == "sys" + assert last == "describe" + + def test_parse_messages_no_user_message(self): + c = LLMCache(_exact_cfg()) + sys, hist, last = c._parse_messages([ + {"role": "system", "content": "sys only"}, + ]) + assert sys == "sys only" + assert last == "" + assert hist == [] + + +class TestPersonalTokenExtraction: + def test_email_extracted(self): + c = LLMCache(_exact_cfg()) + toks = c._extract_personal_tokens("Reach me at alice@example.com please") + assert "alice@example.com" in toks + + def test_numeric_id_after_keyword(self): + c = LLMCache(_exact_cfg()) + toks = c._extract_personal_tokens("User id: 123456") + assert "123456" in toks + + def 
test_identity_tag_names_extracted(self): + c = LLMCache(_exact_cfg()) + toks = c._extract_personal_tokens( + "[Tags: identity] User's name is Andreas Schwibbe" + ) + # Both name tokens should be extracted lowercased; stopwords dropped + assert "andreas" in toks + assert "schwibbe" in toks + assert "name" not in toks # in _IDENTITY_STOPWORDS + assert "user" not in toks + + def test_empty_system_returns_empty_set(self): + c = LLMCache(_exact_cfg()) + assert c._extract_personal_tokens("") == frozenset() + + +class TestResponseIsPersonalized: + def _resp(self, content: str) -> bytes: + return orjson.dumps({"choices": [{"message": {"content": content}}]}) + + def test_email_in_response_is_personalized(self): + c = LLMCache(_exact_cfg()) + assert c._response_is_personalized(self._resp("contact bob@x.com"), "") + + def test_uuid_in_response_is_personalized(self): + c = LLMCache(_exact_cfg()) + uuid = "550e8400-e29b-41d4-a716-446655440000" + assert c._response_is_personalized(self._resp(f"id={uuid}"), "") + + def test_long_numeric_id_in_response_is_personalized(self): + c = LLMCache(_exact_cfg()) + assert c._response_is_personalized(self._resp("account 12345678"), "") + + def test_identity_token_from_system_echoed_in_response(self): + c = LLMCache(_exact_cfg()) + system = "[Tags: identity] Andreas works here" + assert c._response_is_personalized( + self._resp("Yes, Andreas is logged in"), system + ) + + def test_generic_response_not_personalized(self): + c = LLMCache(_exact_cfg()) + assert not c._response_is_personalized( + self._resp("The capital of France is Paris."), "be helpful" + ) + + def test_ollama_message_format_parsed(self): + c = LLMCache(_exact_cfg()) + body = orjson.dumps({"message": {"content": "alice@example.com"}}) + assert c._response_is_personalized(body, "") + + def test_unparseable_body_with_bytes_is_conservative(self): + c = LLMCache(_exact_cfg()) + # Can't parse → returns True (err on the side of privacy) + assert 
c._response_is_personalized(b"binary-junk", "") + + def test_empty_response_not_personalized(self): + c = LLMCache(_exact_cfg()) + assert not c._response_is_personalized(b"", "anything") + + +# ────────────────────────────────────────────────────────────────────────────── +# End-to-end exact-match cache with the memory backend +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture +async def memcache(): + """LLMCache wired up with the in-memory backend (no external deps).""" + c = LLMCache(_exact_cfg("memory")) + await c.init() + return c + + +class TestExactMatchCache: + async def test_miss_then_set_then_hit(self, memcache): + msgs = [ + {"role": "system", "content": "be helpful"}, + {"role": "user", "content": "what is 2+2?"}, + ] + resp = orjson.dumps({"choices": [{"message": {"content": "4"}}]}) + + assert await memcache.get_chat("chat", "m1", msgs) is None + await memcache.set_chat("chat", "m1", msgs, resp) + hit = await memcache.get_chat("chat", "m1", msgs) + assert hit == resp + + async def test_namespace_isolation_by_system(self, memcache): + resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]}) + msgs_a = [ + {"role": "system", "content": "system A"}, + {"role": "user", "content": "same question"}, + ] + msgs_b = [ + {"role": "system", "content": "system B"}, + {"role": "user", "content": "same question"}, + ] + await memcache.set_chat("chat", "m", msgs_a, resp) + # Same question + different system prompt = different namespace = miss + assert await memcache.get_chat("chat", "m", msgs_b) is None + + async def test_namespace_isolation_by_route(self, memcache): + resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]}) + msgs = [{"role": "user", "content": "ping"}] + await memcache.set_chat("chat", "m", msgs, resp) + assert await memcache.get_chat("openai_chat", "m", msgs) is None + + async def test_no_user_message_is_noop(self, memcache): + msgs = [{"role": "system", "content": "sys only"}] 
+ resp = orjson.dumps({"choices": [{"message": {"content": "x"}}]}) + # Both get and set should silently no-op + assert await memcache.get_chat("chat", "m", msgs) is None + await memcache.set_chat("chat", "m", msgs, resp) + assert await memcache.get_chat("chat", "m", msgs) is None + + async def test_personalized_response_generic_system_not_stored(self, memcache): + msgs = [ + {"role": "system", "content": "be helpful"}, # generic + {"role": "user", "content": "give me an email"}, + ] + # Response contains an email → would leak across users sharing the + # generic namespace → must NOT be stored at all + resp = orjson.dumps({"choices": [{"message": {"content": "bob@x.com"}}]}) + await memcache.set_chat("chat", "m", msgs, resp) + assert await memcache.get_chat("chat", "m", msgs) is None + + async def test_personalized_response_user_specific_system_stored(self, memcache): + msgs = [ + {"role": "system", "content": "User id: 998877 prefers concise answers"}, + {"role": "user", "content": "what is my id?"}, + ] + resp = orjson.dumps({"choices": [{"message": {"content": "Your id is 998877"}}]}) + await memcache.set_chat("chat", "m", msgs, resp) + # User-specific namespace → exact-match within this user is OK + assert await memcache.get_chat("chat", "m", msgs) == resp + + async def test_generate_convenience_wrappers(self, memcache): + resp = orjson.dumps({"response": "blue"}) + await memcache.set_generate("m", "what color is the sky?", "", resp) + assert await memcache.get_generate("m", "what color is the sky?") == resp + + +class TestStatsAndClear: + async def test_stats_tracks_hits_and_misses(self, memcache): + msgs = [{"role": "user", "content": "hello"}] + await memcache.get_chat("chat", "m", msgs) # miss + resp = orjson.dumps({"choices": [{"message": {"content": "hi"}}]}) + await memcache.set_chat("chat", "m", msgs, resp) + await memcache.get_chat("chat", "m", msgs) # hit + s = memcache.stats() + assert s["hits"] == 1 + assert s["misses"] == 1 + assert s["hit_rate"] 
== 0.5 + assert s["semantic"] is False + assert s["backend"] == "memory" + + async def test_clear_resets_counters_and_storage(self, memcache): + msgs = [{"role": "user", "content": "hi"}] + resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]}) + await memcache.set_chat("chat", "m", msgs, resp) + await memcache.get_chat("chat", "m", msgs) + await memcache.clear() + s = memcache.stats() + assert s["hits"] == 0 + assert s["misses"] == 0 + assert await memcache.get_chat("chat", "m", msgs) is None + + +# ────────────────────────────────────────────────────────────────────────────── +# Module-level helpers +# ────────────────────────────────────────────────────────────────────────────── + +class TestInitLLMCache: + async def test_disabled_returns_none(self): + cfg = _exact_cfg() + cfg.cache_enabled = False + result = await init_llm_cache(cfg) + assert result is None + + async def test_enabled_returns_initialized_cache(self): + cfg = _exact_cfg() + try: + result = await init_llm_cache(cfg) + assert result is not None + assert get_llm_cache() is result + finally: + # Reset singleton between tests + cache_mod._cache = None diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py new file mode 100644 index 0000000..b94fdf1 --- /dev/null +++ b/test/test_choose_endpoint.py @@ -0,0 +1,345 @@ +"""Tests for choose_endpoint routing logic with mocked fetch calls.""" +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import router + +EP1 = "http://ep1:11434" +EP2 = "http://ep2:11434" +EP3 = "http://ep3:11434" +LLAMA_EP = "http://llama:8080/v1" + + +def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False): + cfg = MagicMock() + cfg.endpoints = endpoints + cfg.llama_server_endpoints = llama_eps or [] + cfg.api_keys = {} + cfg.max_concurrent_connections = max_conn + cfg.endpoint_config = endpoint_config or {} + cfg.priority_routing = priority_routing + cfg.router_api_key = None + return cfg + 
+ +@pytest.fixture(autouse=True) +def reset_usage(): + """Clear usage_counts between tests to prevent bleed.""" + router.usage_counts.clear() + yield + router.usage_counts.clear() + + +class TestChooseEndpointBasic: + async def test_selects_single_candidate(self): + cfg = _make_cfg([EP1]) + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(return_value={"llama3.2:latest"})), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})), + ): + ep, tracking = await router.choose_endpoint("llama3.2:latest") + assert ep == EP1 + assert tracking == "llama3.2:latest" + + async def test_raises_when_no_endpoint_has_model(self): + cfg = _make_cfg([EP1, EP2]) + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(return_value=set())), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), + ): + with pytest.raises(RuntimeError, match="advertise the model"): + await router.choose_endpoint("unknown-model:latest") + + async def test_prefers_loaded_endpoint(self): + cfg = _make_cfg([EP1, EP2]) + async def available(ep, *_): + return {"llama3.2:latest"} + + async def loaded(ep): + return {"llama3.2:latest"} if ep == EP2 else set() + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", side_effect=loaded), + ): + ep, _ = await router.choose_endpoint("llama3.2:latest") + assert ep == EP2 + + async def test_falls_back_to_free_slot(self): + cfg = _make_cfg([EP1, EP2]) + async def available(ep, *_): + return {"llama3.2:latest"} + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), + ): + ep, _ = await router.choose_endpoint("llama3.2:latest") + assert ep in (EP1, EP2) + + 
async def test_saturated_picks_least_busy(self): + cfg = _make_cfg([EP1, EP2]) + cfg.max_concurrent_connections = 1 + + async def available(ep, *_): + return {"llama3.2:latest"} + + # Saturate EP1 with 2 active connections, EP2 with 1 + router.usage_counts[EP1]["llama3.2:latest"] = 2 + router.usage_counts[EP2]["llama3.2:latest"] = 1 + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), + ): + ep, _ = await router.choose_endpoint("llama3.2:latest") + # Least-busy is EP2 + assert ep == EP2 + + async def test_reserve_increments_usage(self): + cfg = _make_cfg([EP1]) + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(return_value={"model:latest"})), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"model:latest"})), + ): + ep, tracking = await router.choose_endpoint("model:latest", reserve=True) + assert router.usage_counts[ep][tracking] == 1 + + +class TestChooseEndpointModelNaming: + async def test_strips_latest_for_openai_endpoints(self): + cfg = _make_cfg(endpoints=[], llama_eps=[LLAMA_EP]) + cfg.endpoints = [] + + async def available(ep, *_): + # llama-server advertises without :latest + return {"gpt-4o"} + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"gpt-4o"})), + ): + ep, _ = await router.choose_endpoint("gpt-4o:latest") + assert ep == LLAMA_EP + + async def test_adds_latest_for_ollama_when_bare_name(self): + cfg = _make_cfg([EP1]) + + async def available(ep, *_): + return {"llama3.2:latest"} + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", 
AsyncMock(return_value={"llama3.2:latest"})), + ): + ep, _ = await router.choose_endpoint("llama3.2") + assert ep == EP1 + + +class TestChooseEndpointLoadBalancing: + async def test_random_selection_among_idle(self): + cfg = _make_cfg([EP1, EP2, EP3]) + selected = set() + + async def available(ep, *_): + return {"model:latest"} + + async def loaded(ep): + return {"model:latest"} + + for _ in range(20): + router.usage_counts.clear() + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", side_effect=loaded), + ): + ep, _ = await router.choose_endpoint("model:latest", reserve=False) + selected.add(ep) + + # With 20 draws from 3 idle endpoints, all three should appear + assert len(selected) > 1 + + async def test_sort_by_load_ascending(self): + cfg = _make_cfg([EP1, EP2]) + router.usage_counts[EP1]["model:latest"] = 1 + router.usage_counts[EP2]["model:latest"] = 0 + + async def available(ep, *_): + return {"model:latest"} + + async def loaded(ep): + return {"model:latest"} + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", side_effect=loaded), + ): + ep, _ = await router.choose_endpoint("model:latest", reserve=False) + # EP2 has fewer active connections → should be selected + assert ep == EP2 + + +# --------------------------------------------------------------------------- +# get_max_connections unit tests +# --------------------------------------------------------------------------- + +class TestGetMaxConnections: + def test_returns_global_default_when_no_override(self): + cfg = _make_cfg([EP1, EP2], max_conn=3) + with patch.object(router, "config", cfg): + assert router.get_max_connections(EP1) == 3 + assert router.get_max_connections(EP2) == 3 + + def test_returns_per_endpoint_override(self): + cfg = _make_cfg( + [EP1, EP2], + max_conn=2, 
+ endpoint_config={EP1: {"max_concurrent_connections": 5}}, + ) + with patch.object(router, "config", cfg): + assert router.get_max_connections(EP1) == 5 + assert router.get_max_connections(EP2) == 2 # falls back to global + + def test_unrecognised_endpoint_falls_back_to_global(self): + cfg = _make_cfg([EP1], max_conn=4, endpoint_config={EP2: {"max_concurrent_connections": 1}}) + with patch.object(router, "config", cfg): + assert router.get_max_connections(EP3) == 4 + + +# --------------------------------------------------------------------------- +# Priority / WRR routing tests +# --------------------------------------------------------------------------- + +MODEL = "model:latest" + + +def _all_loaded(ep): + """Side-effect: every endpoint advertises and has MODEL loaded.""" + return {MODEL} + + +class TestPriorityRouting: + """Tests for priority_routing=True (WRR + config-order tiebreaking).""" + + async def test_idle_picks_first_in_config_order(self): + """When all endpoints are idle, priority picks the first listed endpoint.""" + cfg = _make_cfg([EP1, EP2, EP3], priority_routing=True) + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), + patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), + ): + ep, _ = await router.choose_endpoint(MODEL, reserve=False) + assert ep == EP1 + + async def test_lower_utilization_preferred_over_priority(self): + """An endpoint with lower ratio is preferred even if it has lower priority.""" + cfg = _make_cfg([EP1, EP2], priority_routing=True) + # EP1 (priority 0) is busier: 1/2 = 0.5; EP2 (priority 1) is idle: 0/2 = 0.0 + router.usage_counts[EP1][MODEL] = 1 + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), + patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), + ): + ep, _ = await router.choose_endpoint(MODEL, 
reserve=False) + assert ep == EP2 + + async def test_wrr_distribution_matches_expected_sequence(self): + """ + Full WRR sequence with heterogeneous capacities, mirroring the issue example: + EP1 max=2, EP2 max=2, EP3 max=1 + + Expected routing order for 5 sequential requests: + EP1, EP2, EP3, EP1, EP2 + """ + cfg = _make_cfg( + [EP1, EP2, EP3], + max_conn=2, + endpoint_config={EP3: {"max_concurrent_connections": 1}}, + priority_routing=True, + ) + + expected = [EP1, EP2, EP3, EP1, EP2] + actual = [] + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), + patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), + ): + for _ in expected: + ep, _ = await router.choose_endpoint(MODEL, reserve=True) + actual.append(ep) + + assert actual == expected + + async def test_saturated_picks_lowest_ratio_then_priority(self): + """When all endpoints are saturated, pick lowest utilization ratio; break ties by priority.""" + cfg = _make_cfg( + [EP1, EP2, EP3], + max_conn=1, + endpoint_config={EP3: {"max_concurrent_connections": 2}}, + priority_routing=True, + ) + # EP1 usage=1/1=1.0, EP2 usage=1/1=1.0, EP3 usage=1/2=0.5 → EP3 wins + router.usage_counts[EP1][MODEL] = 1 + router.usage_counts[EP2][MODEL] = 1 + router.usage_counts[EP3][MODEL] = 1 + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), + patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), + ): + ep, _ = await router.choose_endpoint(MODEL, reserve=False) + assert ep == EP3 + + async def test_saturated_ties_broken_by_priority(self): + """When all are saturated with equal ratio, config order wins.""" + cfg = _make_cfg([EP1, EP2, EP3], max_conn=1, priority_routing=True) + router.usage_counts[EP1][MODEL] = 1 + router.usage_counts[EP2][MODEL] = 1 + router.usage_counts[EP3][MODEL] = 1 + + with ( + 
patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), + patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), + ): + ep, _ = await router.choose_endpoint(MODEL, reserve=False) + assert ep == EP1 + + +class TestPriorityRoutingDisabled: + """Verify that priority_routing=False keeps the original random behaviour.""" + + async def test_idle_endpoints_are_randomised(self): + """Without priority routing, all-idle selection must eventually pick each endpoint.""" + cfg = _make_cfg([EP1, EP2, EP3], priority_routing=False) + selected = set() + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)), + patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)), + ): + for _ in range(30): + router.usage_counts.clear() + ep, _ = await router.choose_endpoint(MODEL, reserve=False) + selected.add(ep) + + # With 30 draws from 3 equally-idle endpoints, all three must appear + assert selected == {EP1, EP2, EP3} diff --git a/test/test_db.py b/test/test_db.py new file mode 100644 index 0000000..833b375 --- /dev/null +++ b/test/test_db.py @@ -0,0 +1,197 @@ +"""Direct unit tests for db.TokenDatabase — no router/app dependency.""" +from datetime import datetime, timezone + +import pytest + +from db import TokenDatabase + + +@pytest.fixture +async def db(tmp_path): + inst = TokenDatabase(str(tmp_path / "tokens.db")) + await inst.init_db() + yield inst + await inst.close() + + +class TestInit: + async def test_init_creates_tables(self, db): + # Re-init must be idempotent + await db.init_db() + # Insert + read confirms tables exist + await db.update_token_counts("http://ep", "m", 1, 2) + rows = [r async for r in db.load_token_counts()] + assert len(rows) == 1 + + async def test_creates_parent_directory(self, tmp_path): + nested = tmp_path / "nested" / "subdir" / "x.db" + inst = 
TokenDatabase(str(nested)) + await inst.init_db() + try: + assert nested.parent.exists() + finally: + await inst.close() + + +class TestUpdateTokenCounts: + async def test_insert_then_update_aggregates(self, db): + await db.update_token_counts("http://ep", "m1", 10, 20) + await db.update_token_counts("http://ep", "m1", 5, 7) + rows = [r async for r in db.load_token_counts()] + assert len(rows) == 1 + r = rows[0] + assert r["endpoint"] == "http://ep" + assert r["model"] == "m1" + assert r["input_tokens"] == 15 + assert r["output_tokens"] == 27 + assert r["total_tokens"] == 42 + + async def test_independent_endpoint_model_pairs(self, db): + await db.update_token_counts("http://ep1", "m1", 1, 1) + await db.update_token_counts("http://ep1", "m2", 2, 2) + await db.update_token_counts("http://ep2", "m1", 3, 3) + rows = [r async for r in db.load_token_counts()] + assert len(rows) == 3 + totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows} + assert totals == { + ("http://ep1", "m1"): 2, + ("http://ep1", "m2"): 4, + ("http://ep2", "m1"): 6, + } + + +class TestBatchedCounts: + async def test_update_batched_counts(self, db): + counts = { + "http://a": {"m": (4, 6)}, + "http://b": {"m": (1, 1), "n": (10, 0)}, + } + await db.update_batched_counts(counts) + rows = [r async for r in db.load_token_counts()] + totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows} + assert totals == { + ("http://a", "m"): 10, + ("http://b", "m"): 2, + ("http://b", "n"): 10, + } + + async def test_empty_batch_is_noop(self, db): + await db.update_batched_counts({}) + rows = [r async for r in db.load_token_counts()] + assert rows == [] + + +class TestTimeSeries: + async def test_add_time_series_entry(self, db): + # The aggregate FK requires the (endpoint,model) row to exist first + await db.update_token_counts("http://ep", "m", 0, 0) + await db.add_time_series_entry("http://ep", "m", 3, 4) + await db.add_time_series_entry("http://ep", "m", 1, 1) + rows = [r async 
for r in db.get_latest_time_series(limit=10)] + assert len(rows) == 2 + # Newest-first ordering; both timestamps are within the same minute, + # so just check totals are present and well-formed + for r in rows: + assert r["endpoint"] == "http://ep" + assert r["model"] == "m" + assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"] + + async def test_add_batched_time_series(self, db): + await db.update_token_counts("http://ep", "m", 0, 0) + now = int(datetime.now(tz=timezone.utc).timestamp()) + entries = [ + {"endpoint": "http://ep", "model": "m", "input_tokens": 1, + "output_tokens": 2, "total_tokens": 3, "timestamp": now - 60}, + {"endpoint": "http://ep", "model": "m", "input_tokens": 4, + "output_tokens": 5, "total_tokens": 9, "timestamp": now}, + ] + await db.add_batched_time_series(entries) + rows = [r async for r in db.get_latest_time_series(limit=10)] + assert len(rows) == 2 + assert rows[0]["timestamp"] >= rows[1]["timestamp"] + + async def test_get_time_series_for_model_filters(self, db): + await db.update_token_counts("http://ep", "m1", 0, 0) + await db.update_token_counts("http://ep", "m2", 0, 0) + now = int(datetime.now(tz=timezone.utc).timestamp()) + await db.add_batched_time_series([ + {"endpoint": "http://ep", "model": "m1", "input_tokens": 1, + "output_tokens": 1, "total_tokens": 2, "timestamp": now}, + {"endpoint": "http://ep", "model": "m2", "input_tokens": 9, + "output_tokens": 9, "total_tokens": 18, "timestamp": now}, + ]) + rows = [r async for r in db.get_time_series_for_model("m1")] + assert len(rows) == 1 + assert rows[0]["total_tokens"] == 2 + + async def test_endpoint_distribution_for_model(self, db): + await db.update_token_counts("http://a", "m", 0, 0) + await db.update_token_counts("http://b", "m", 0, 0) + now = int(datetime.now(tz=timezone.utc).timestamp()) + await db.add_batched_time_series([ + {"endpoint": "http://a", "model": "m", "input_tokens": 1, + "output_tokens": 1, "total_tokens": 2, "timestamp": now}, + {"endpoint": 
"http://a", "model": "m", "input_tokens": 1, + "output_tokens": 1, "total_tokens": 2, "timestamp": now}, + {"endpoint": "http://b", "model": "m", "input_tokens": 5, + "output_tokens": 5, "total_tokens": 10, "timestamp": now}, + ]) + dist = await db.get_endpoint_distribution_for_model("m") + assert dist == {"http://a": 4, "http://b": 10} + + +class TestGetTokenCountsForModel: + async def test_aggregates_across_endpoints(self, db): + await db.update_token_counts("http://a", "m", 1, 2) + await db.update_token_counts("http://b", "m", 3, 4) + result = await db.get_token_counts_for_model("m") + assert result is not None + assert result["endpoint"] == "aggregated" + assert result["model"] == "m" + assert result["input_tokens"] == 4 + assert result["output_tokens"] == 6 + assert result["total_tokens"] == 10 + + async def test_unknown_model_returns_zero_aggregate(self, db): + # SUM(...) WHERE no-match returns one row with NULLs — exposed as zeros + result = await db.get_token_counts_for_model("nope") + assert result is not None + assert result["input_tokens"] in (0, None) + + +class TestAggregateTimeSeriesOlderThan: + async def test_aggregates_old_entries_by_day(self, db): + await db.update_token_counts("http://ep", "m", 0, 0) + now = int(datetime.now(tz=timezone.utc).timestamp()) + old = now - (40 * 86400) # 40 days ago + await db.add_batched_time_series([ + {"endpoint": "http://ep", "model": "m", "input_tokens": 1, + "output_tokens": 1, "total_tokens": 2, "timestamp": old}, + {"endpoint": "http://ep", "model": "m", "input_tokens": 3, + "output_tokens": 3, "total_tokens": 6, "timestamp": old + 60}, + {"endpoint": "http://ep", "model": "m", "input_tokens": 99, + "output_tokens": 99, "total_tokens": 198, "timestamp": now}, + ]) + n = await db.aggregate_time_series_older_than(30, trim_old=False) + assert n == 1 # one (endpoint, model, day) group rolled up + + async def test_invalid_days_falls_back_to_30(self, db): + # Just ensure it doesn't blow up with a bogus value + n = 
await db.aggregate_time_series_older_than(0) + assert n == 0 + + async def test_trim_old_removes_aggregated_rows(self, db): + await db.update_token_counts("http://ep", "m", 0, 0) + now = int(datetime.now(tz=timezone.utc).timestamp()) + old = now - (40 * 86400) + await db.add_batched_time_series([ + {"endpoint": "http://ep", "model": "m", "input_tokens": 1, + "output_tokens": 1, "total_tokens": 2, "timestamp": old}, + {"endpoint": "http://ep", "model": "m", "input_tokens": 99, + "output_tokens": 99, "total_tokens": 198, "timestamp": now}, + ]) + await db.aggregate_time_series_older_than(30, trim_old=True) + remaining = [r async for r in db.get_latest_time_series(limit=10)] + # Only the recent (within-cutoff) row should remain + assert len(remaining) == 1 + assert remaining[0]["total_tokens"] == 198 diff --git a/test/test_fetch.py b/test/test_fetch.py new file mode 100644 index 0000000..9b542a1 --- /dev/null +++ b/test/test_fetch.py @@ -0,0 +1,180 @@ +"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking.""" +import time +from unittest.mock import patch, MagicMock + +import pytest +from aioresponses import aioresponses + +import router +from conftest import TEST_OLLAMA, TEST_LLAMA + +MOCK_OLLAMA_EP = "http://mock-ollama:11434" +MOCK_LLAMA_EP = "http://mock-llama:8080/v1" + + +def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None): + cfg = MagicMock() + cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP] + cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP] + cfg.api_keys = api_keys or {} + cfg.max_concurrent_connections = 2 + cfg.router_api_key = None + return cfg + + +@pytest.fixture(autouse=True) +def clear_caches(aio_session): + """aio_session fixture already clears caches and sets up app_state.""" + yield + + +class TestFetchAvailableModels: + async def test_ollama_tags(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + 
f"{MOCK_OLLAMA_EP}/api/tags", + payload={"models": [ + {"name": "llama3.2:latest"}, + {"name": "qwen2.5:7b"}, + ]}, + ) + models = await router.fetch.available_models(MOCK_OLLAMA_EP) + assert models == {"llama3.2:latest", "qwen2.5:7b"} + + async def test_openai_compatible_models_endpoint(self): + cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_LLAMA_EP}/models", + payload={"data": [{"id": "unsloth/model:Q8_0"}]}, + ) + models = await router.fetch.available_models(MOCK_LLAMA_EP, api_key="tok") + assert "unsloth/model:Q8_0" in models + + async def test_caches_successful_result(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_OLLAMA_EP}/api/tags", + payload={"models": [{"name": "llama3.2:latest"}]}, + ) + first = await router.fetch.available_models(MOCK_OLLAMA_EP) + second = await router.fetch.available_models(MOCK_OLLAMA_EP) + # second call must be served from cache without a second HTTP request + assert first == second == {"llama3.2:latest"} + + async def test_returns_empty_on_http_500(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"}) + models = await router.fetch.available_models(MOCK_OLLAMA_EP) + assert models == set() + + async def test_returns_empty_on_connection_error(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + import aiohttp + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_OLLAMA_EP}/api/tags", + exception=aiohttp.ClientConnectorError( + connection_key=MagicMock(host="mock-ollama", port=11434), + os_error=OSError(111, "refused"), + ), + ) + models = await router.fetch.available_models(MOCK_OLLAMA_EP) + assert models == set() + + async def 
test_stale_cache_returned_while_refresh_runs(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_OLLAMA_EP}/api/tags", + payload={"models": [{"name": "llama3.2:latest"}]}, + ) + await router.fetch.available_models(MOCK_OLLAMA_EP) + + # Manually age cache into stale-but-valid window (300-600s) + async with router._models_cache_lock: + models, _ = router._models_cache[MOCK_OLLAMA_EP] + router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400) + + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_OLLAMA_EP}/api/tags", + payload={"models": [{"name": "llama3.2:latest"}]}, + ) + # Should return stale data immediately + stale = await router.fetch.available_models(MOCK_OLLAMA_EP) + assert "llama3.2:latest" in stale + + async def test_error_cache_short_circuits(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + # Seed error cache with a very recent error + async with router._available_error_cache_lock: + router._available_error_cache[MOCK_OLLAMA_EP] = time.time() + + with patch.object(router, "config", cfg), aioresponses(): + # No HTTP mock registered — if a call happens it will raise + models = await router.fetch.available_models(MOCK_OLLAMA_EP) + assert models == set() + + +class TestFetchLoadedModels: + async def test_ollama_ps(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_OLLAMA_EP}/api/ps", + payload={"models": [{"name": "llama3.2:latest"}]}, + ) + models = await router.fetch.loaded_models(MOCK_OLLAMA_EP) + assert models == {"llama3.2:latest"} + + async def test_llama_server_filters_loaded(self): + cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_LLAMA_EP}/models", + payload={"data": [ + {"id": "model-a", "status": {"value": 
"loaded"}}, + {"id": "model-b", "status": {"value": "unloaded"}}, + ]}, + ) + models = await router.fetch.loaded_models(MOCK_LLAMA_EP) + assert models == {"model-a"} + + async def test_llama_server_no_status_field_always_loaded(self): + cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_LLAMA_EP}/models", + payload={"data": [{"id": "always-on-model"}]}, + ) + models = await router.fetch.loaded_models(MOCK_LLAMA_EP) + assert "always-on-model" in models + + async def test_returns_empty_on_error(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={}) + models = await router.fetch.loaded_models(MOCK_OLLAMA_EP) + assert models == set() + + async def test_ext_openai_always_empty(self): + ext_ep = "https://api.openai.com/v1" + cfg = _make_cfg(ollama_eps=[ext_ep], llama_eps=[]) + with patch.object(router, "config", cfg): + models = await router.fetch.loaded_models(ext_ep) + assert models == set() + + async def test_caches_result(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_OLLAMA_EP}/api/ps", + payload={"models": [{"name": "qwen:7b"}]}, + ) + first = await router.fetch.loaded_models(MOCK_OLLAMA_EP) + second = await router.fetch.loaded_models(MOCK_OLLAMA_EP) + assert first == second diff --git a/test/test_openai_proxies.py b/test/test_openai_proxies.py new file mode 100644 index 0000000..8a56c91 --- /dev/null +++ b/test/test_openai_proxies.py @@ -0,0 +1,181 @@ +"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes. + +These tests verify that when the LLM cache reports a hit, the route returns +the cached payload *without* selecting an endpoint or contacting any backend. 
+""" +from unittest.mock import AsyncMock, patch + +import orjson +import pytest +from fastapi import HTTPException + +import router + + +_BYPASS = HTTPException(status_code=599, detail="bypassed") + + +class _FakeCache: + """Minimal stand-in for cache.LLMCache.get_chat.""" + def __init__(self, response_bytes: bytes | None): + self._resp = response_bytes + self.calls: list[tuple] = [] + + async def get_chat(self, route, model, messages): + self.calls.append((route, model, messages)) + return self._resp + + +@pytest.fixture +def cache_hit_payload(): + return orjson.dumps({ + "id": "cmpl-xyz", + "created": 1, + "model": "test-model", + "choices": [{"message": {"role": "assistant", "content": "from-cache"}}], + "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + }) + + +# ────────────────────────────────────────────────────────────────────────────── +# /v1/chat/completions +# ────────────────────────────────────────────────────────────────────────────── + +class TestOpenAIChatCompletionsCacheHit: + async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload): + fake = _FakeCache(cache_hit_payload) + # Patch the route's references to both helpers — they're imported by name + # into router's namespace at module load time. 
+ with ( + patch.object(router, "get_llm_cache", return_value=fake), + patch.object(router, "choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached"))), + ): + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "ping"}], + "stream": False, + "nomyo": {"cache": True}, + }, + ) + assert resp.status_code == 200 + # Body is streamed; collect it + body = resp.content + parsed = orjson.loads(body) + assert parsed["choices"][0]["message"]["content"] == "from-cache" + assert fake.calls and fake.calls[0][0] == "openai_chat" + + async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload): + fake = _FakeCache(cache_hit_payload) + with ( + patch.object(router, "get_llm_cache", return_value=fake), + patch.object(router, "choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached"))), + ): + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "ping"}], + "stream": True, + "nomyo": {"cache": True}, + }, + ) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + text = resp.content.decode() + # First SSE frame contains the cached content as a delta + first_frame = text.split("\n\n")[0] + assert first_frame.startswith("data: ") + chunk = orjson.loads(first_frame[len("data: "):]) + assert chunk["choices"][0]["delta"]["content"] == "from-cache" + # Stream is terminated with [DONE] + assert "data: [DONE]" in text + + async def test_cache_disabled_in_payload_bypasses_cache_check(self, client): + """When nomyo.cache=False, get_chat is never called even if a cache exists.""" + fake = _FakeCache(b"") # has a response, but should never be consulted + with ( + patch.object(router, "get_llm_cache", return_value=fake), + patch.object(router, "choose_endpoint", + AsyncMock(side_effect=_BYPASS)), + ): + resp = await 
client.post( + "/v1/chat/completions", + json={ + "model": "m", + "messages": [{"role": "user", "content": "hi"}], + "nomyo": {"cache": False}, + }, + ) + # Got past the cache short-circuit → endpoint selection invoked + assert resp.status_code == 599 + assert fake.calls == [] + + async def test_no_cache_configured_bypasses_cache_check(self, client): + """get_llm_cache() returning None should not break the route.""" + with ( + patch.object(router, "get_llm_cache", return_value=None), + patch.object(router, "choose_endpoint", + AsyncMock(side_effect=_BYPASS)), + ): + resp = await client.post( + "/v1/chat/completions", + json={ + "model": "m", + "messages": [{"role": "user", "content": "hi"}], + "nomyo": {"cache": True}, + }, + ) + assert resp.status_code == 599 + + +# ────────────────────────────────────────────────────────────────────────────── +# /v1/completions +# ────────────────────────────────────────────────────────────────────────────── + +class TestOpenAICompletionsCacheHit: + async def test_nonstream_cache_hit(self, client, cache_hit_payload): + fake = _FakeCache(cache_hit_payload) + with ( + patch.object(router, "get_llm_cache", return_value=fake), + patch.object(router, "choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached"))), + ): + resp = await client.post( + "/v1/completions", + json={ + "model": "test-model", + "prompt": "Tell me a joke", + "stream": False, + "nomyo": {"cache": True}, + }, + ) + assert resp.status_code == 200 + # Prompt-style cache lookup is namespaced under "openai_completions" + assert fake.calls[0][0] == "openai_completions" + # Cache lookup receives the prompt as a single user message + cached_msgs = fake.calls[0][2] + assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}] + + async def test_stream_cache_hit(self, client, cache_hit_payload): + fake = _FakeCache(cache_hit_payload) + with ( + patch.object(router, "get_llm_cache", return_value=fake), + patch.object(router, 
"choose_endpoint", + AsyncMock(side_effect=AssertionError("backend must not be reached"))), + ): + resp = await client.post( + "/v1/completions", + json={ + "model": "test-model", + "prompt": "What is 2+2?", + "stream": True, + "nomyo": {"cache": True}, + }, + ) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + assert "data: [DONE]" in resp.content.decode() diff --git a/test/test_unit_context.py b/test/test_unit_context.py new file mode 100644 index 0000000..de2b98a --- /dev/null +++ b/test/test_unit_context.py @@ -0,0 +1,116 @@ +"""Unit tests for context-window trimming logic.""" +import pytest +import router + + +def _msgs(roles_contents): + return [{"role": r, "content": c} for r, c in roles_contents] + + +class TestCountMessageTokens: + def test_returns_int(self): + msgs = _msgs([("user", "hello")]) + assert isinstance(router._count_message_tokens(msgs), int) + + def test_empty_list(self): + assert router._count_message_tokens([]) >= 0 + + def test_longer_content_more_tokens(self): + short = _msgs([("user", "hi")]) + long_ = _msgs([("user", "a " * 500)]) + assert router._count_message_tokens(long_) > router._count_message_tokens(short) + + def test_list_content(self): + msgs = [{"role": "user", "content": [ + {"type": "text", "text": "what do you see?"}, + ]}] + tokens = router._count_message_tokens(msgs) + assert tokens > 0 + + def test_multiple_messages(self): + msgs = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")]) + assert router._count_message_tokens(msgs) > 10 + + +class TestTrimMessagesForContext: + def test_short_history_unchanged(self): + msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")]) + result = router._trim_messages_for_context(msgs, n_ctx=4096) + assert result == msgs + + def test_system_messages_always_kept(self): + msgs = ( + _msgs([("system", "you are helpful")]) + + _msgs([("user", f"msg {i}") for i in range(50)]) + + _msgs([("user", 
"final question")]) + ) + result = router._trim_messages_for_context(msgs, n_ctx=512) + system_msgs = [m for m in result if m["role"] == "system"] + assert len(system_msgs) == 1 + assert system_msgs[0]["content"] == "you are helpful" + + def test_last_user_message_always_kept(self): + msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")]) + result = router._trim_messages_for_context(msgs, n_ctx=256) + assert result[-1]["content"] == "very important last question" + + def test_oldest_dropped_first(self): + msgs = _msgs([ + ("user", "oldest msg"), + ("assistant", "oldest reply"), + ("user", "newer msg"), + ("assistant", "newer reply"), + ("user", "newest"), + ]) + # Use very small target to force trimming + result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10) + contents = [m["content"] for m in result] + # "oldest msg" should be dropped before "newest" + if "oldest msg" in contents: + assert "newest" in contents + else: + assert "newest" in contents + + def test_result_starts_with_user(self): + msgs = _msgs([ + ("assistant", "leftover assistant"), + ("user", "question"), + ]) + result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20) + if result: + assert result[0]["role"] == "user" + + def test_target_tokens_overrides_safety_margin(self): + msgs = _msgs([("user", "a " * 200)]) + result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10) + result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000) + # Both should return at least the last message + assert len(result_small) >= 1 + assert len(result_large) >= 1 + + +class TestCalibratedTrimTarget: + def test_returns_positive_int(self): + msgs = [{"role": "user", "content": "hello " * 100}] + result = router._calibrated_trim_target(msgs, n_ctx=4096, actual_tokens=3000) + assert isinstance(result, int) + assert result >= 1 + + def test_over_limit_reduces_target(self): + 
msgs = [{"role": "user", "content": "a " * 500}] + # actual_tokens > n_ctx means we need to shed more + target = router._calibrated_trim_target(msgs, n_ctx=2048, actual_tokens=2500) + assert target < router._count_message_tokens(msgs) + + def test_well_within_limit_returns_current(self): + msgs = [{"role": "user", "content": "hi"}] + # actual_tokens << n_ctx means nothing to shed + target = router._calibrated_trim_target(msgs, n_ctx=16384, actual_tokens=50) + # Should return cur_tiktoken since to_shed == 0 + assert target == max(1, router._count_message_tokens(msgs)) + + def test_minimum_is_one(self): + # Even if we need to shed everything, result is at least 1 + msgs = [{"role": "user", "content": "hello"}] + target = router._calibrated_trim_target(msgs, n_ctx=100, actual_tokens=99999) + assert target >= 1 diff --git a/test/test_unit_helpers.py b/test/test_unit_helpers.py new file mode 100644 index 0000000..d38eb37 --- /dev/null +++ b/test/test_unit_helpers.py @@ -0,0 +1,279 @@ +"""Unit tests for pure helper functions in router.py (no network, no app).""" +import time +import asyncio +from unittest.mock import MagicMock, patch + +import aiohttp +import pytest + +import router + + +class TestMaskSecrets: + def test_masks_openai_key(self): + text = "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ" + result = router._mask_secrets(text) + assert "sk-***redacted***" in result + assert "sk-abcd1234" not in result + + def test_masks_api_key_assignment(self): + result = router._mask_secrets("api_key: supersecretvalue123") + assert "supersecretvalue123" not in result + assert "***redacted***" in result + + def test_masks_api_key_with_colon(self): + result = router._mask_secrets("api-key: mykey") + assert "mykey" not in result + + def test_empty_string_returns_empty(self): + assert router._mask_secrets("") == "" + + def test_none_returns_none(self): + assert router._mask_secrets(None) is None + + def test_no_secrets_unchanged(self): + text = "this is a normal log 
line" + assert router._mask_secrets(text) == text + + +class TestIsFresh: + def test_fresh_within_ttl(self): + cached_at = time.time() - 10 + assert router._is_fresh(cached_at, 300) is True + + def test_expired_beyond_ttl(self): + cached_at = time.time() - 400 + assert router._is_fresh(cached_at, 300) is False + + def test_exactly_at_boundary(self): + cached_at = time.time() - 300 + # May be True or False depending on timing, just verify it runs + result = router._is_fresh(cached_at, 300) + assert isinstance(result, bool) + + def test_just_cached(self): + assert router._is_fresh(time.time(), 1) is True + + +class TestNormalizeLlamaModelName: + def test_strips_hf_prefix(self): + assert router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF") == "gpt-oss-20b-GGUF" + + def test_strips_quant_suffix(self): + assert router._normalize_llama_model_name("model:Q8_0") == "model" + + def test_strips_both(self): + result = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16") + assert result == "gpt-oss-20b-GGUF" + + def test_no_prefix_no_suffix(self): + assert router._normalize_llama_model_name("plain-model") == "plain-model" + + def test_multiple_slashes(self): + result = router._normalize_llama_model_name("org/user/model-name:Q4_K_M") + assert result == "model-name" + + +class TestExtractLlamaQuant: + def test_extracts_quant(self): + assert router._extract_llama_quant("unsloth/model:Q8_0") == "Q8_0" + + def test_no_quant_returns_empty(self): + assert router._extract_llama_quant("plain-model") == "" + + def test_f16(self): + assert router._extract_llama_quant("model:F16") == "F16" + + def test_q4_k_m(self): + assert router._extract_llama_quant("model:Q4_K_M") == "Q4_K_M" + + +class TestIsUnixSocketEndpoint: + def test_sock_endpoint_detected(self): + assert router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1") is True + + def test_regular_http_not_sock(self): + assert router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1") is False + + def 
test_ollama_not_sock(self): + assert router._is_unix_socket_endpoint("http://localhost:11434") is False + + def test_dot_sock_in_host_detected(self): + assert router._is_unix_socket_endpoint("http://llama.sock/v1") is True + + +class TestGetSocketPath: + def test_returns_run_user_path(self): + import os + path = router._get_socket_path("http://192.168.0.52.sock/v1") + uid = os.getuid() + assert path == f"/run/user/{uid}/192.168.0.52.sock" + + +class TestIsBase64: + def test_valid_base64(self): + import base64 + data = base64.b64encode(b"hello world").decode() + assert router.is_base64(data) is True + + def test_invalid_base64(self): + assert router.is_base64("not-base64!@#$") is False + + def test_empty_string(self): + # Empty string is valid base64 (decodes to empty bytes) + assert router.is_base64("") is True + + def test_non_string(self): + # Non-strings fall through without returning True (returns None) + assert not router.is_base64(12345) + + +class TestIsLlamaModelLoaded: + def test_status_dict_loaded(self): + assert router._is_llama_model_loaded({"id": "m", "status": {"value": "loaded"}}) is True + + def test_status_dict_unloaded(self): + assert router._is_llama_model_loaded({"id": "m", "status": {"value": "unloaded"}}) is False + + def test_status_string_loaded(self): + assert router._is_llama_model_loaded({"id": "m", "status": "loaded"}) is True + + def test_status_string_unloaded(self): + assert router._is_llama_model_loaded({"id": "m", "status": "unloaded"}) is False + + def test_no_status_field_always_loaded(self): + # No status field → always available (single-model server) + assert router._is_llama_model_loaded({"id": "m"}) is True + + def test_status_none_always_loaded(self): + assert router._is_llama_model_loaded({"id": "m", "status": None}) is True + + +class TestEp2Base: + def test_adds_v1_to_ollama(self): + assert router.ep2base("http://localhost:11434") == "http://localhost:11434/v1" + + def test_keeps_v1_if_present(self): + assert 
router.ep2base("http://host/v1") == "http://host/v1" + + def test_llama_server_endpoint_unchanged(self): + ep = "http://192.168.0.50:8889/v1" + assert router.ep2base(ep) == ep + + +class TestDedupeOnKeys: + def test_removes_duplicate_by_single_key(self): + items = [{"name": "a", "x": 1}, {"name": "b", "x": 2}, {"name": "a", "x": 3}] + result = router.dedupe_on_keys(items, ["name"]) + assert len(result) == 2 + assert result[0]["name"] == "a" + assert result[1]["name"] == "b" + + def test_removes_duplicate_by_two_keys(self): + items = [ + {"digest": "abc", "name": "m1"}, + {"digest": "abc", "name": "m1"}, + {"digest": "def", "name": "m2"}, + ] + result = router.dedupe_on_keys(items, ["digest", "name"]) + assert len(result) == 2 + + def test_empty_list(self): + assert router.dedupe_on_keys([], ["name"]) == [] + + def test_no_duplicates(self): + items = [{"name": "a"}, {"name": "b"}, {"name": "c"}] + assert len(router.dedupe_on_keys(items, ["name"])) == 3 + + +class TestFormatConnectionIssue: + def test_connector_error_message(self): + err = aiohttp.ClientConnectorError( + connection_key=MagicMock(host="localhost", port=11434), + os_error=OSError(111, "Connection refused"), + ) + msg = router._format_connection_issue("http://localhost:11434", err) + assert "localhost" in msg + assert "Connection refused" in msg or "111" in msg + + def test_timeout_error_message(self): + msg = router._format_connection_issue("http://host:1234", asyncio.TimeoutError()) + assert "Timed out" in msg + assert "host:1234" in msg + + def test_generic_error(self): + msg = router._format_connection_issue("http://host:1234", ValueError("boom")) + assert "host:1234" in msg + assert "boom" in msg + + +class TestIsExtOpenaiEndpoint: + def test_openai_com_is_ext(self): + cfg = MagicMock() + cfg.endpoints = [] + cfg.llama_server_endpoints = [] + with patch.object(router, "config", cfg): + assert router.is_ext_openai_endpoint("https://api.openai.com/v1") is True + + def 
test_ollama_default_port_not_ext(self): + cfg = MagicMock() + cfg.endpoints = ["http://host:11434"] + cfg.llama_server_endpoints = [] + with patch.object(router, "config", cfg): + assert router.is_ext_openai_endpoint("http://host:11434") is False + + def test_llama_server_not_ext(self): + cfg = MagicMock() + cfg.endpoints = [] + cfg.llama_server_endpoints = ["http://host:8080/v1"] + with patch.object(router, "config", cfg): + assert router.is_ext_openai_endpoint("http://host:8080/v1") is False + + def test_no_v1_not_ext(self): + cfg = MagicMock() + cfg.endpoints = ["http://host:11434"] + cfg.llama_server_endpoints = [] + with patch.object(router, "config", cfg): + assert router.is_ext_openai_endpoint("http://host:11434") is False + + +class TestIsOpenaiCompatible: + def test_v1_endpoint_compatible(self): + cfg = MagicMock() + cfg.llama_server_endpoints = [] + with patch.object(router, "config", cfg): + assert router.is_openai_compatible("http://host/v1") is True + + def test_ollama_not_compatible(self): + cfg = MagicMock() + cfg.llama_server_endpoints = [] + with patch.object(router, "config", cfg): + assert router.is_openai_compatible("http://localhost:11434") is False + + def test_llama_server_in_list_compatible(self): + cfg = MagicMock() + cfg.llama_server_endpoints = ["http://host:8080"] + with patch.object(router, "config", cfg): + assert router.is_openai_compatible("http://host:8080") is True + + +class TestGetTrackingModel: + def test_ollama_adds_latest(self): + cfg = MagicMock() + cfg.llama_server_endpoints = [] + with patch.object(router, "config", cfg): + assert router.get_tracking_model("http://ollama:11434", "llama3.2") == "llama3.2:latest" + + def test_ollama_keeps_existing_tag(self): + cfg = MagicMock() + cfg.llama_server_endpoints = [] + with patch.object(router, "config", cfg): + assert router.get_tracking_model("http://ollama:11434", "llama3.2:7b") == "llama3.2:7b" + + def test_llama_server_normalizes(self): + ep = "http://host:8080/v1" + cfg = 
MagicMock() + cfg.llama_server_endpoints = [ep] + with patch.object(router, "config", cfg): + result = router.get_tracking_model(ep, "unsloth/model:Q8_0") + assert result == "model" diff --git a/test/test_unit_rechunk.py b/test/test_unit_rechunk.py new file mode 100644 index 0000000..e0d01c9 --- /dev/null +++ b/test/test_unit_rechunk.py @@ -0,0 +1,173 @@ +"""Unit tests for router.rechunk — OpenAI ↔ Ollama chunk shape conversion.""" +import time +from types import SimpleNamespace + +import ollama + +import router + + +def _ns(**kw): + return SimpleNamespace(**kw) + + +def _stream_chunk(content="hi", role="assistant", finish_reason=None, + usage=None, model="m"): + """Build a SimpleNamespace mimicking a streaming OpenAI chunk.""" + delta = _ns(content=content, role=role, reasoning=None, reasoning_content=None, + tool_calls=None) + choice = _ns(delta=delta, finish_reason=finish_reason, logprobs=None) + return _ns(model=model, choices=[choice], usage=usage) + + +def _nonstream_chunk(content="hi", role="assistant", finish_reason="stop", + usage=None, model="m", tool_calls=None): + """Build a SimpleNamespace mimicking a non-streaming OpenAI ChatCompletion.""" + message = _ns(content=content, role=role, reasoning=None, reasoning_content=None, + tool_calls=tool_calls) + choice = _ns(message=message, finish_reason=finish_reason, logprobs=None) + return _ns(model=model, choices=[choice], usage=usage) + + +# ────────────────────────────────────────────────────────────────────────────── +# openai_chat_completion2ollama +# ────────────────────────────────────────────────────────────────────────────── + +class TestChatCompletionToOllama: + def test_streaming_content_chunk(self): + chunk = _stream_chunk(content="hello", finish_reason=None, usage=None) + out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) + assert isinstance(out, ollama.ChatResponse) + assert out.message.role == "assistant" + assert out.message.content == "hello" + assert out.done 
is False # usage is None → not done yet + assert out.model == "m" + + def test_streaming_empty_content_defaults(self): + # Some chunks have content=None — should coerce to empty string + chunk = _stream_chunk(content=None, role=None) + out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) + assert out.message.role == "assistant" # role defaulted + assert out.message.content == "" + + def test_final_usage_only_chunk_marks_done(self): + usage = _ns(prompt_tokens=10, completion_tokens=5, total_tokens=15) + chunk = _ns(model="m", choices=[], usage=usage) + out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) + assert out.done is True + assert out.done_reason == "stop" + assert out.prompt_eval_count == 10 + assert out.eval_count == 5 + assert out.message.content == "" + + def test_nonstreaming_with_content(self): + usage = _ns(prompt_tokens=2, completion_tokens=3, total_tokens=5) + chunk = _nonstream_chunk(content="response text", finish_reason="stop", usage=usage) + out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter()) + assert out.done is True + assert out.message.content == "response text" + assert out.prompt_eval_count == 2 + assert out.eval_count == 3 + + def test_nonstreaming_tool_calls_converted(self): + """Tool calls with JSON string arguments are parsed into dicts.""" + tc = _ns(function=_ns(name="get_weather", arguments='{"city": "Paris"}')) + usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2) + chunk = _nonstream_chunk( + content="", finish_reason="tool_calls", usage=usage, tool_calls=[tc] + ) + out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter()) + assert out.message.tool_calls is not None + assert len(out.message.tool_calls) == 1 + first = out.message.tool_calls[0] + assert first.function.name == "get_weather" + assert first.function.arguments == {"city": "Paris"} + + def 
test_nonstreaming_tool_calls_with_invalid_json_fall_back_to_empty(self): + tc = _ns(function=_ns(name="f", arguments="not-json")) + usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2) + chunk = _nonstream_chunk(content="", usage=usage, tool_calls=[tc]) + out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter()) + assert out.message.tool_calls[0].function.arguments == {} + + def test_streaming_tool_calls_in_delta_are_skipped(self): + """Streaming mode must not assemble tool calls (caller handles it).""" + chunk = _stream_chunk(content="x", finish_reason=None) + # Even if a chunk somehow carried tool_calls in the delta, streaming + # mode should ignore them. + out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter()) + assert out.message.tool_calls is None + + +# ────────────────────────────────────────────────────────────────────────────── +# openai_completion2ollama +# ────────────────────────────────────────────────────────────────────────────── + +class TestCompletionToOllama: + def test_streaming_text_chunk(self): + choice = _ns(text="word", finish_reason=None, reasoning=None) + chunk = _ns(model="m", choices=[choice], usage=None) + out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter()) + assert isinstance(out, ollama.GenerateResponse) + assert out.response == "word" + assert out.done is False + + def test_final_chunk_with_usage(self): + usage = _ns(prompt_tokens=4, completion_tokens=6, total_tokens=10) + choice = _ns(text="end", finish_reason="stop", reasoning=None) + chunk = _ns(model="m", choices=[choice], usage=usage) + out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter()) + assert out.done is True + assert out.prompt_eval_count == 4 + assert out.eval_count == 6 + + +# ────────────────────────────────────────────────────────────────────────────── +# embeddings / embed +# 
────────────────────────────────────────────────────────────────────────────── + +class TestEmbeddingConversions: + def test_openai_embeddings2ollama(self): + chunk = _ns(data=[_ns(embedding=[0.1, 0.2, 0.3])]) + out = router.rechunk.openai_embeddings2ollama(chunk) + assert isinstance(out, ollama.EmbeddingsResponse) + assert list(out.embedding) == [0.1, 0.2, 0.3] + + def test_openai_embed2ollama(self): + chunk = _ns(data=[_ns(embedding=[0.5, 0.6])]) + out = router.rechunk.openai_embed2ollama(chunk, "my-embed-model") + assert isinstance(out, ollama.EmbedResponse) + assert out.model == "my-embed-model" + assert list(out.embeddings[0]) == [0.5, 0.6] + + +# ────────────────────────────────────────────────────────────────────────────── +# extract_usage_from_llama_timings +# ────────────────────────────────────────────────────────────────────────────── + +class TestExtractUsageFromLlamaTimings: + def test_none_when_no_timings_attr(self): + obj = _ns() + assert router.rechunk.extract_usage_from_llama_timings(obj) is None + + def test_prompt_plus_cache_sums(self): + obj = _ns(timings={"prompt_n": 1, "cache_n": 236, "predicted_n": 35}) + prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj) + assert prompt == 237 + assert completion == 35 + + def test_missing_keys_default_to_zero(self): + obj = _ns(timings={"predicted_n": 12}) + prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj) + assert prompt == 0 + assert completion == 12 + + def test_null_values_treated_as_zero(self): + obj = _ns(timings={"prompt_n": None, "cache_n": None, "predicted_n": None}) + prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj) + assert prompt == 0 + assert completion == 0 + + def test_non_dict_timings_returns_none(self): + obj = _ns(timings="not-a-dict") + assert router.rechunk.extract_usage_from_llama_timings(obj) is None diff --git a/test/test_unit_transforms.py b/test/test_unit_transforms.py new file mode 100644 index 
0000000..51160a0 --- /dev/null +++ b/test/test_unit_transforms.py @@ -0,0 +1,200 @@ +"""Unit tests for message transformation functions.""" +from unittest.mock import MagicMock + +import pytest + +import router + + +class TestStripAssistantPrefill: + def test_removes_trailing_assistant(self): + msgs = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "prefill"}, + ] + result = router._strip_assistant_prefill(msgs) + assert len(result) == 1 + assert result[0]["role"] == "user" + + def test_keeps_non_trailing_assistant(self): + msgs = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "response"}, + {"role": "user", "content": "follow-up"}, + ] + result = router._strip_assistant_prefill(msgs) + assert len(result) == 3 + + def test_empty_list_unchanged(self): + assert router._strip_assistant_prefill([]) == [] + + def test_single_user_message_unchanged(self): + msgs = [{"role": "user", "content": "hi"}] + assert router._strip_assistant_prefill(msgs) == msgs + + +class TestTransformToolCallsToOpenAI: + def test_adds_type_function(self): + msgs = [{"role": "assistant", "tool_calls": [ + {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}} + ]}] + result = router.transform_tool_calls_to_openai(msgs) + tc = result[0]["tool_calls"][0] + assert tc["type"] == "function" + + def test_adds_id_when_missing(self): + msgs = [{"role": "assistant", "tool_calls": [ + {"function": {"name": "fn", "arguments": {}}} + ]}] + result = router.transform_tool_calls_to_openai(msgs) + assert "id" in result[0]["tool_calls"][0] + + def test_converts_dict_arguments_to_string(self): + msgs = [{"role": "assistant", "tool_calls": [ + {"function": {"name": "fn", "arguments": {"key": "val"}}} + ]}] + result = router.transform_tool_calls_to_openai(msgs) + args = result[0]["tool_calls"][0]["function"]["arguments"] + assert isinstance(args, str) + import orjson + parsed = orjson.loads(args) + assert parsed == {"key": "val"} + + def 
test_keeps_string_arguments_unchanged(self): + msgs = [{"role": "assistant", "tool_calls": [ + {"function": {"name": "fn", "arguments": '{"key": "val"}'}} + ]}] + result = router.transform_tool_calls_to_openai(msgs) + args = result[0]["tool_calls"][0]["function"]["arguments"] + assert args == '{"key": "val"}' + + def test_links_tool_call_id_to_tool_response(self): + msgs = [ + {"role": "assistant", "tool_calls": [ + {"function": {"name": "get_weather", "arguments": {}}} + ]}, + {"role": "tool", "name": "get_weather", "content": "sunny"}, + ] + result = router.transform_tool_calls_to_openai(msgs) + tc_id = result[0]["tool_calls"][0]["id"] + assert result[1].get("tool_call_id") == tc_id + + def test_non_tool_messages_unchanged(self): + msgs = [{"role": "user", "content": "hello"}] + result = router.transform_tool_calls_to_openai(msgs) + assert result == msgs + + +class TestStripImagesFromMessages: + def test_removes_image_url_parts(self): + msgs = [{"role": "user", "content": [ + {"type": "text", "text": "what is this?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ]}] + result = router._strip_images_from_messages(msgs) + content = result[0]["content"] + assert content == "what is this?" 
+ + def test_keeps_text_only_messages(self): + msgs = [{"role": "user", "content": "plain text"}] + result = router._strip_images_from_messages(msgs) + assert result[0]["content"] == "plain text" + + def test_multiple_text_parts_kept_as_list(self): + msgs = [{"role": "user", "content": [ + {"type": "text", "text": "part one"}, + {"type": "text", "text": "part two"}, + {"type": "image_url", "image_url": {"url": "data:..."}}, + ]}] + result = router._strip_images_from_messages(msgs) + content = result[0]["content"] + assert isinstance(content, list) + assert len(content) == 2 + + def test_all_images_removed_empty_list(self): + msgs = [{"role": "user", "content": [ + {"type": "image_url", "image_url": {"url": "data:..."}}, + ]}] + result = router._strip_images_from_messages(msgs) + # Image-only content becomes empty list + content = result[0]["content"] + assert content == [] + + +class TestAccumulateOpenAITcDelta: + def _make_chunk(self, index, name=None, args_fragment="", tc_id=None): + delta = MagicMock() + tc = MagicMock() + tc.index = index + tc.id = tc_id + tc.function = MagicMock() + tc.function.name = name + tc.function.arguments = args_fragment + delta.tool_calls = [tc] + chunk = MagicMock() + chunk.choices = [MagicMock(delta=delta)] + return chunk + + def test_first_delta_creates_entry(self): + acc = {} + chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"') + router._accumulate_openai_tc_delta(chunk, acc) + assert 0 in acc + assert acc[0]["name"] == "my_fn" + assert acc[0]["arguments"] == '{"k"' + + def test_subsequent_deltas_concatenate_args(self): + acc = {} + router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc) + router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc) + assert acc[0]["arguments"] == '{"k": "v"}' + + def test_multiple_tool_calls_tracked_separately(self): + acc = {} + c1 = self._make_chunk(0, name="fn1", args_fragment="{}") + c2 = self._make_chunk(1, 
name="fn2", args_fragment="{}") + chunk = MagicMock() + tc1 = MagicMock() + tc1.index = 0 + tc1.id = "id1" + tc1.function = MagicMock(name="fn1", arguments="{}") + tc2 = MagicMock() + tc2.index = 1 + tc2.id = "id2" + tc2.function = MagicMock(name="fn2", arguments="{}") + chunk.choices = [MagicMock(delta=MagicMock(tool_calls=[tc1, tc2]))] + router._accumulate_openai_tc_delta(chunk, acc) + assert 0 in acc and 1 in acc + + def test_no_choices_is_noop(self): + acc = {} + chunk = MagicMock(choices=[]) + router._accumulate_openai_tc_delta(chunk, acc) + assert acc == {} + + +class TestBuildOllamaToolCalls: + def test_builds_from_accumulator(self): + acc = {0: {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'}} + result = router._build_ollama_tool_calls(acc) + assert result is not None + assert len(result) == 1 + assert result[0].function.name == "get_weather" + assert result[0].function.arguments == {"city": "Berlin"} + + def test_invalid_json_args_becomes_empty_dict(self): + acc = {0: {"id": "c1", "name": "fn", "arguments": "not-json"}} + result = router._build_ollama_tool_calls(acc) + assert result[0].function.arguments == {} + + def test_empty_accumulator_returns_none(self): + assert router._build_ollama_tool_calls({}) is None + + def test_preserves_order_by_index(self): + acc = { + 1: {"id": "c2", "name": "fn2", "arguments": "{}"}, + 0: {"id": "c1", "name": "fn1", "arguments": "{}"}, + } + result = router._build_ollama_tool_calls(acc) + assert result[0].function.name == "fn1" + assert result[1].function.name == "fn2" From 64e38978a94206ef8403ba5642f329b3f99240f7 Mon Sep 17 00:00:00 2001 From: Renovate Bot Date: Sat, 16 May 2026 07:46:35 +0000 Subject: [PATCH 15/22] chore(deps): update dependency openai to v2 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c187485..115fe04 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ idna==3.15 
jiter==0.14.0 multidict==6.7.1 ollama==0.6.2 -openai==1.109.1 +openai==2.37.0 orjson>=3.11.5 numpy>=1.26 pillow==12.2.0 From 27bc57d4a4aaeabc3ce47862b2e7df7176caf44f Mon Sep 17 00:00:00 2001 From: Renovate Bot Date: Sun, 17 May 2026 00:58:02 +0000 Subject: [PATCH 16/22] chore(deps): update dependency click to v8.4.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 115fe04..5920529 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ anyio==4.13.0 async-timeout==5.0.1 attrs==26.1.0 certifi==2026.4.22 -click==8.3.3 +click==8.4.0 distro==1.9.0 exceptiongroup==1.3.1 fastapi==0.136.1 From 0b64a84e96bb3f97988811ab43a3dca78f8db8b4 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Sun, 17 May 2026 10:53:33 +0200 Subject: [PATCH 17/22] fix: replace hardcoded tokendb path --- test/config_test.yaml | 2 -- test/conftest.py | 9 ++++++++- test/test_cache.py | 6 +++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/config_test.yaml b/test/config_test.yaml index f05cce7..30f2fa3 100644 --- a/test/config_test.yaml +++ b/test/config_test.yaml @@ -10,6 +10,4 @@ api_keys: "http://192.168.0.51:12434": "ollama" "http://192.168.0.51:12434/v1": "llama" -db_path: "/tmp/nomyo_test_tokens.db" - cache_enabled: false diff --git a/test/conftest.py b/test/conftest.py index c95fa2d..c5142da 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -16,6 +16,7 @@ import asyncio import os import ssl import sys +import tempfile from pathlib import Path from unittest.mock import MagicMock, patch @@ -24,8 +25,14 @@ import httpx import pytest _TEST_DIR = Path(__file__).parent -# Must be set before importing router so the module-level Config.from_yaml uses test config +# Must be set before importing router so module-level Config.from_yaml + Config field +# defaults pick these up. 
db_path is intentionally absent from config_test.yaml so the +# env-var default wins — keeps tests portable across CI runners (Linux/macOS/Windows). os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml")) +os.environ.setdefault( + "NOMYO_ROUTER_DB_PATH", + str(Path(tempfile.gettempdir()) / "nomyo_router_test_tokens.db"), +) sys.path.insert(0, str(_TEST_DIR.parent)) diff --git a/test/test_cache.py b/test/test_cache.py index cd37688..f2ce1a9 100644 --- a/test/test_cache.py +++ b/test/test_cache.py @@ -1,4 +1,6 @@ """Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed).""" +import tempfile +from pathlib import Path from types import SimpleNamespace import orjson @@ -13,6 +15,8 @@ from cache import ( openai_nonstream_to_sse, ) +_CACHE_DB_PATH = str(Path(tempfile.gettempdir()) / "nomyo_test_cache.db") + def _exact_cfg(backend: str = "memory") -> SimpleNamespace: """Config for exact-match mode — similarity=1.0 avoids embedding deps.""" @@ -22,7 +26,7 @@ def _exact_cfg(backend: str = "memory") -> SimpleNamespace: cache_similarity=1.0, cache_history_weight=0.3, cache_ttl=300, - cache_db_path="/tmp/test_cache.db", + cache_db_path=_CACHE_DB_PATH, cache_redis_url="redis://localhost:6379", ) From db6aa739036bc774a1071ba42dc630c64cfff790 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Mon, 18 May 2026 13:45:06 +0200 Subject: [PATCH 18/22] fix: - _fetch_loaded_models_internal now writes _loaded_error_cache[endpoint] = time.time() on /api/ps or /v1/models failure, and clears the entry on success - choose_endpoint now filters out candidates with a fresh (<300s) loaded-models error. 
- /health now probes both /api/version and /api/ps for Ollama endpoints - dashboard adaption relates to #83 --- router.py | 195 ++++++++++++++++++++++++----------- static/index.html | 60 ++++++----- test/test_choose_endpoint.py | 56 +++++++++- test/test_fetch.py | 30 ++++++ 4 files changed, 251 insertions(+), 90 deletions(-) diff --git a/router.py b/router.py index 08225fb..c465ebc 100644 --- a/router.py +++ b/router.py @@ -1000,7 +1000,7 @@ class fetch: async with client.get(f"{endpoint}/models") as resp: await _ensure_success(resp) data = await resp.json() - + # Filter for loaded models only items = data.get("data", []) models = { @@ -1012,11 +1012,19 @@ class fetch: # Update cache with lock protection async with _loaded_models_cache_lock: _loaded_models_cache[endpoint] = (models, time.time()) + # Probe succeeded — clear any stale error so the endpoint + # becomes routable again. + async with _loaded_error_cache_lock: + _loaded_error_cache.pop(endpoint, None) return models except Exception as e: # If anything goes wrong we simply assume the endpoint has no models message = _format_connection_issue(f"{endpoint}/models", e) print(f"[fetch.loaded_models] {message}") + # Record the failure so `choose_endpoint` can avoid routing + # to an unhealthy backend and repeated probes short-circuit. 
+ async with _loaded_error_cache_lock: + _loaded_error_cache[endpoint] = time.time() return set() else: # Original Ollama /api/ps logic @@ -1031,11 +1039,15 @@ class fetch: # Update cache with lock protection async with _loaded_models_cache_lock: _loaded_models_cache[endpoint] = (models, time.time()) + async with _loaded_error_cache_lock: + _loaded_error_cache.pop(endpoint, None) return models except Exception as e: # If anything goes wrong we simply assume the endpoint has no models message = _format_connection_issue(f"{endpoint}/api/ps", e) print(f"[fetch.loaded_models] {message}") + async with _loaded_error_cache_lock: + _loaded_error_cache[endpoint] = time.time() return set() async def _refresh_loaded_models(endpoint: str) -> None: @@ -1853,6 +1865,28 @@ async def choose_endpoint(model: str, reserve: bool = True, load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints] loaded_sets = await asyncio.gather(*load_tasks) + # 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing + # recently. Without this filter, an endpoint where `/api/ps` returns 5xx + # would appear with an empty loaded set but pass through to the + # free-slot fallback (step 4) — sending completion calls to an + # unhealthy backend. See issue #83. + async with _loaded_error_cache_lock: + unhealthy = { + ep for ep, ts in _loaded_error_cache.items() + if _is_fresh(ts, 300) + } + if unhealthy: + filtered = [ + (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets) + if ep not in unhealthy + ] + if filtered: + candidate_endpoints = [ep for ep, _ in filtered] + loaded_sets = [models for _, models in filtered] + # If *every* candidate is unhealthy we still fall through with the + # original list — refusing to route is worse than retrying a + # possibly-recovered backend. + # Look up a possible affinity hint *before* taking usage_lock. The two # locks are never held together to avoid lock-ordering issues. 
affine_ep: Optional[str] = None @@ -3154,44 +3188,103 @@ async def usage_proxy(request: Request): "token_usage_counts": token_usage_counts} # ------------------------------------------------------------- -# 20. Proxy config route – for monitoring and frontent usage +# 20. Endpoint health probes (shared by /api/config and /health) +# ------------------------------------------------------------- +async def _raw_probe( + ep: str, + route: str, + api_key: Optional[str] = None, + timeout: Optional[float] = None, +) -> tuple[bool, object]: + """Direct HTTP probe that distinguishes success from failure + (unlike `fetch.endpoint_details`, which returns [] on either). + Returns `(ok, payload_or_error_message)`. + """ + headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")} + if api_key is not None: + headers["Authorization"] = "Bearer " + api_key + url = f"{ep.rstrip('/')}/{route.lstrip('/')}" + req_kwargs = {} + if timeout is not None: + req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout) + try: + client: aiohttp.ClientSession = get_session(ep) + async with client.get(url, headers=headers, **req_kwargs) as resp: + await _ensure_success(resp) + data = await resp.json() + return True, data + except Exception as exc: + return False, _format_connection_issue(url, exc) + + +async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict: + """Probe an endpoint and return `{status, version?, detail?}`. + + Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so + that a daemon which is reachable but has a broken model-introspection + path (issue #83) is reported as `error` rather than `ok`. + OpenAI-compatible endpoints use a single `/models` probe. 
+ """ + if is_openai_compatible(ep): + ok, payload = await _raw_probe( + ep, "/models", config.api_keys.get(ep), timeout=timeout, + ) + if ok: + return {"status": "ok", "version": "latest"} + return {"status": "error", "detail": str(payload)} + + (version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather( + _raw_probe(ep, "/api/version", timeout=timeout), + _raw_probe(ep, "/api/ps", timeout=timeout), + ) + + version_value = ( + version_payload.get("version") + if version_ok and isinstance(version_payload, dict) + else None + ) + + if version_ok and ps_ok: + return {"status": "ok", "version": version_value} + if not version_ok and not ps_ok: + return {"status": "error", "detail": str(version_payload)} + # Partial failure — daemon reachable but one probe failed. Report + # as "error" so callers can surface the issue; include `version` so + # the operator knows the daemon itself is alive. + if not ps_ok: + return { + "status": "error", + "version": version_value, + "detail": f"/api/ps: {ps_payload}", + } + return { + "status": "error", + "detail": f"/api/version: {version_payload}", + } + + +# ------------------------------------------------------------- +# 20b. Proxy config route – for monitoring and frontend usage # ------------------------------------------------------------- @app.get("/api/config") async def config_proxy(request: Request): """ Return a simple JSON object that contains the configured - Ollama endpoints and llama_server_endpoints. The front‑end uses this to display - which endpoints are being proxied. + Ollama endpoints and llama_server_endpoints. The front‑end uses this + to display which endpoints are being proxied and their health. + Status is "error" when either liveness (/api/version) or routing + health (/api/ps) fails — see issue #83. 
""" - async def check_endpoint(url: str): - client: aiohttp.ClientSession = get_session(url) - headers = None - if "/v1" in url: - headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")} - target_url = f"{url}/models" - else: - target_url = f"{url}/api/version" + async def check(url: str) -> dict: + return {"url": url, **(await _endpoint_health(url, timeout=5))} - try: - async with client.get(target_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp: - await _ensure_success(resp) - data = await resp.json() - if "/v1" in url: - return {"url": url, "status": "ok", "version": "latest"} - else: - return {"url": url, "status": "ok", "version": data.get("version")} - except Exception as e: - detail = _format_connection_issue(target_url, e) - return {"url": url, "status": "error", "detail": detail} - - # Check Ollama endpoints - ollama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints]) - - # Check llama-server endpoints + ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints]) llama_results = [] if config.llama_server_endpoints: - llama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.llama_server_endpoints]) - + llama_results = await asyncio.gather( + *[check(ep) for ep in config.llama_server_endpoints] + ) + return { "endpoints": ollama_results, "llama_server_endpoints": llama_results, @@ -4003,44 +4096,30 @@ async def health_proxy(request: Request): """ Health‑check endpoint for monitoring the proxy. - * Queries each configured endpoint for its `/api/version` response. + * Queries each configured endpoint for both liveness and routing health: + Ollama endpoints are probed at `/api/version` AND `/api/ps`, + OpenAI-compatible endpoints at `/models`. * Returns a JSON object containing: - - `status`: "ok" if every endpoint replied, otherwise "error". + - `status`: "ok" if every endpoint replied to every probe, otherwise "error". 
- `endpoints`: a mapping of endpoint URL → `{status, version|detail}`. * The HTTP status code is 200 when everything is healthy, 503 otherwise. """ # Run all health checks in parallel. - # Ollama endpoints expose /api/version; OpenAI-compatible endpoints (vLLM, - # llama-server, external) expose /models. Using /api/version against an - # OpenAI-compatible endpoint yields a 404 and noisy log output. + # Ollama endpoints expose /api/version (liveness) and /api/ps (routing + # health — required by `choose_endpoint`). OpenAI-compatible endpoints + # (vLLM, llama-server, external) expose /models, which serves both + # purposes. Probing /api/version alone would miss the case where the + # Ollama process is up but /api/ps is failing — see issue #83. all_endpoints = list(config.endpoints) llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] all_endpoints += llama_eps_extra - tasks = [] - for ep in all_endpoints: - if is_openai_compatible(ep): - tasks.append(fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True)) - else: - tasks.append(fetch.endpoint_details(ep, "/api/version", "version", skip_error_cache=True)) + probe_results = await asyncio.gather( + *(_endpoint_health(ep) for ep in all_endpoints), + ) - results = await asyncio.gather(*tasks, return_exceptions=True) - - health_summary = {} - overall_ok = True - - for ep, result in zip(all_endpoints, results): - if isinstance(result, Exception): - # Endpoint did not respond / returned an error - health_summary[ep] = {"status": "error", "detail": str(result)} - overall_ok = False - else: - # Successful response – report the reported version (Ollama) or - # indicate the endpoint is reachable (OpenAI-compatible). 
- if is_openai_compatible(ep): - health_summary[ep] = {"status": "ok"} - else: - health_summary[ep] = {"status": "ok", "version": result} + health_summary = dict(zip(all_endpoints, probe_results)) + overall_ok = all(entry.get("status") == "ok" for entry in probe_results) response_payload = { "status": "ok" if overall_ok else "error", diff --git a/static/index.html b/static/index.html index b29f22b..8c0b16c 100644 --- a/static/index.html +++ b/static/index.html @@ -192,6 +192,10 @@ color: #8b0000; font-weight: bold; } + .status-error[title] { + cursor: help; + text-decoration: underline dotted; + } .copy-link, .delete-link, .show-link, @@ -736,6 +740,16 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { return await resp.json(); } + function escapeHtml(value) { + if (value === null || value === undefined) return ""; + return String(value) + .replace(/&/g, "&amp;") + .replace(/</g, "&lt;") + .replace(/>/g, "&gt;") + .replace(/"/g, "&quot;") + .replace(/'/g, "&#39;"); + } + function toggleDarkMode() { document.documentElement.classList.toggle("dark-mode"); } @@ -752,40 +766,24 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) { // Build HTML for both endpoints and llama_server_endpoints let html = ""; - // Add Ollama endpoints - html += data.endpoints - .map((e) => { - const statusClass = - e.status === "ok" - ? "status-ok" - : "status-error"; - const version = e.version || "N/A"; - return ` + const renderRow = (e) => { + const statusClass = + e.status === "ok" ? "status-ok" : "status-error"; + const version = e.version || "N/A"; + const titleAttr = e.detail + ? 
` title="${escapeHtml(e.detail)}"` + : ""; + return ` - ${e.url} - ${e.status} - ${version} + ${escapeHtml(e.url)} + ${escapeHtml(e.status)} + ${escapeHtml(version)} `; - }) - .join(""); - - // Add llama-server endpoints + }; + + html += data.endpoints.map(renderRow).join(""); if (data.llama_server_endpoints && data.llama_server_endpoints.length > 0) { - html += data.llama_server_endpoints - .map((e) => { - const statusClass = - e.status === "ok" - ? "status-ok" - : "status-error"; - const version = e.version || "N/A"; - return ` - - ${e.url} - ${e.status} - ${version} - `; - }) - .join(""); + html += data.llama_server_endpoints.map(renderRow).join(""); } body.innerHTML = html; diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py index b94fdf1..ece609a 100644 --- a/test/test_choose_endpoint.py +++ b/test/test_choose_endpoint.py @@ -1,4 +1,5 @@ """Tests for choose_endpoint routing logic with mocked fetch calls.""" +import time from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -25,10 +26,12 @@ def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, prior @pytest.fixture(autouse=True) def reset_usage(): - """Clear usage_counts between tests to prevent bleed.""" + """Clear usage_counts and error caches between tests to prevent bleed.""" router.usage_counts.clear() + router._loaded_error_cache.clear() yield router.usage_counts.clear() + router._loaded_error_cache.clear() class TestChooseEndpointBasic: @@ -102,6 +105,57 @@ class TestChooseEndpointBasic: # Least-busy is EP2 assert ep == EP2 + async def test_excludes_endpoint_with_recent_loaded_error(self): + # Regression: issue #83 — when /api/ps fails for EP1 but EP1 + # still advertises the model via /api/tags, routing must not + # fall back to EP1 just because it has a free slot. 
+ cfg = _make_cfg([EP1, EP2]) + + async def available(ep, *_): + return {"llama3.2:latest"} + + # EP1's /api/ps probe failed recently; EP2 is fine but the model + # is not loaded there. Without the health filter, EP1 would be + # picked by the free-slot fallback (step 4 in choose_endpoint). + router._loaded_error_cache[EP1] = time.time() + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", side_effect=available), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), + ): + ep, _ = await router.choose_endpoint("llama3.2:latest") + assert ep == EP2 + + async def test_stale_loaded_error_does_not_exclude(self): + # Errors older than the 300s window must not keep an endpoint + # excluded forever. + cfg = _make_cfg([EP1]) + router._loaded_error_cache[EP1] = time.time() - 301 + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"m:latest"})), + ): + ep, _ = await router.choose_endpoint("m:latest") + assert ep == EP1 + + async def test_all_unhealthy_still_routes(self): + # If every candidate has a fresh loaded-error we still try one + # (it may have recovered between the cache write and now) rather + # than refusing to route. 
+ cfg = _make_cfg([EP1]) + router._loaded_error_cache[EP1] = time.time() + + with ( + patch.object(router, "config", cfg), + patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})), + patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())), + ): + ep, _ = await router.choose_endpoint("m:latest") + assert ep == EP1 + async def test_reserve_increments_usage(self): cfg = _make_cfg([EP1]) with ( diff --git a/test/test_fetch.py b/test/test_fetch.py index 9b542a1..6f2ed50 100644 --- a/test/test_fetch.py +++ b/test/test_fetch.py @@ -178,3 +178,33 @@ class TestFetchLoadedModels: first = await router.fetch.loaded_models(MOCK_OLLAMA_EP) second = await router.fetch.loaded_models(MOCK_OLLAMA_EP) assert first == second + + async def test_records_error_in_loaded_error_cache_on_failure(self): + # Regression: issue #83 — /api/ps failures must be recorded so + # `choose_endpoint` can exclude unhealthy backends from routing. + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={}) + await router.fetch.loaded_models(MOCK_OLLAMA_EP) + assert MOCK_OLLAMA_EP in router._loaded_error_cache + + async def test_records_error_for_llama_server_on_failure(self): + cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP]) + with patch.object(router, "config", cfg), aioresponses() as m: + m.get(f"{MOCK_LLAMA_EP}/models", status=502, payload={}) + await router.fetch.loaded_models(MOCK_LLAMA_EP) + assert MOCK_LLAMA_EP in router._loaded_error_cache + + async def test_clears_error_cache_on_subsequent_success(self): + cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) + # Pre-seed an old error so loaded_models() falls through to the + # network probe instead of short-circuiting on the error cache. 
+ async with router._loaded_error_cache_lock: + router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301 + with patch.object(router, "config", cfg), aioresponses() as m: + m.get( + f"{MOCK_OLLAMA_EP}/api/ps", + payload={"models": [{"name": "qwen:7b"}]}, + ) + await router.fetch.loaded_models(MOCK_OLLAMA_EP) + assert MOCK_OLLAMA_EP not in router._loaded_error_cache From 59b59386ac0e61503552a27784bfd1793cba05f6 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Mon, 18 May 2026 15:59:52 +0200 Subject: [PATCH 19/22] fix: futureproof docker builds related to #84 --- .../workflows/docker-publish-semantic.yml | 32 +++++++++++++++---- .forgejo/workflows/docker-publish.yml | 30 ++++++++++++++--- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/.forgejo/workflows/docker-publish-semantic.yml b/.forgejo/workflows/docker-publish-semantic.yml index 2fa59d5..d4e1213 100644 --- a/.forgejo/workflows/docker-publish-semantic.yml +++ b/.forgejo/workflows/docker-publish-semantic.yml @@ -76,19 +76,30 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.REGISTRY_TOKEN }} - - name: Build and push platform image + - name: Build and push image by digest id: build uses: https://github.com/docker/build-push-action@v6 with: context: . 
platforms: ${{ matrix.platform }} - push: true provenance: false build-args: | SEMANTIC_CACHE=true - tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }} - cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }} - cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }},mode=max + outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true + + - name: Export digest + run: | + mkdir -p /tmp/digests + digest="${{ steps.build.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + uses: https://github.com/actions/upload-artifact@v4 + with: + name: digests-semantic-${{ matrix.arch }} + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 merge: runs-on: docker-amd64 @@ -117,6 +128,13 @@ jobs: cat /tmp/dockerd.log exit 1 + - name: Download digests + uses: https://github.com/actions/download-artifact@v4 + with: + path: /tmp/digests + pattern: digests-semantic-* + merge-multiple: true + - name: Set up Docker Buildx uses: https://github.com/docker/setup-buildx-action@v3 @@ -141,9 +159,9 @@ jobs: type=sha,prefix=sha-,suffix=-semantic - name: Create and push multi-arch manifest + working-directory: /tmp/digests run: | docker buildx imagetools create \ $(jq -cr '.tags | map("-t " + .) 
| join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64 \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64 + $(printf '${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@sha256:%s ' *) diff --git a/.forgejo/workflows/docker-publish.yml b/.forgejo/workflows/docker-publish.yml index 27cd879..3b69030 100644 --- a/.forgejo/workflows/docker-publish.yml +++ b/.forgejo/workflows/docker-publish.yml @@ -69,15 +69,28 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.REGISTRY_TOKEN }} - - name: Build and push platform image + - name: Build and push image by digest id: build uses: https://github.com/docker/build-push-action@v6 with: context: . platforms: ${{ matrix.platform }} - push: true provenance: false - tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }} + outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true + + - name: Export digest + run: | + mkdir -p /tmp/digests + digest="${{ steps.build.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + uses: https://github.com/actions/upload-artifact@v4 + with: + name: digests-${{ matrix.arch }} + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 merge: runs-on: docker-amd64 @@ -106,6 +119,13 @@ jobs: cat /tmp/dockerd.log exit 1 + - name: Download digests + uses: https://github.com/actions/download-artifact@v4 + with: + path: /tmp/digests + pattern: digests-* + merge-multiple: true + - name: Set up Docker Buildx uses: https://github.com/docker/setup-buildx-action@v3 @@ -130,9 +150,9 @@ jobs: type=sha,prefix=sha- - name: Create and push multi-arch manifest + working-directory: /tmp/digests run: | docker buildx imagetools create \ $(jq -cr '.tags | map("-t " + .) 
| join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64 \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64 + $(printf '${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@sha256:%s ' *) From ea8cda73d9f6894210828b89a95b03c2c5e57452 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Mon, 18 May 2026 16:31:09 +0200 Subject: [PATCH 20/22] fix: use unique-per-run platform tags. --- .../workflows/docker-publish-semantic.yml | 30 ++++--------------- .forgejo/workflows/docker-publish.yml | 30 ++++--------------- 2 files changed, 10 insertions(+), 50 deletions(-) diff --git a/.forgejo/workflows/docker-publish-semantic.yml b/.forgejo/workflows/docker-publish-semantic.yml index d4e1213..163f1a1 100644 --- a/.forgejo/workflows/docker-publish-semantic.yml +++ b/.forgejo/workflows/docker-publish-semantic.yml @@ -76,30 +76,17 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.REGISTRY_TOKEN }} - - name: Build and push image by digest + - name: Build and push platform image id: build uses: https://github.com/docker/build-push-action@v6 with: context: . 
platforms: ${{ matrix.platform }} + push: true provenance: false build-args: | SEMANTIC_CACHE=true - outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true - - - name: Export digest - run: | - mkdir -p /tmp/digests - digest="${{ steps.build.outputs.digest }}" - touch "/tmp/digests/${digest#sha256:}" - - - name: Upload digest - uses: https://github.com/actions/upload-artifact@v4 - with: - name: digests-semantic-${{ matrix.arch }} - path: /tmp/digests/* - if-no-files-found: error - retention-days: 1 + tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}-${{ github.run_id }} merge: runs-on: docker-amd64 @@ -128,13 +115,6 @@ jobs: cat /tmp/dockerd.log exit 1 - - name: Download digests - uses: https://github.com/actions/download-artifact@v4 - with: - path: /tmp/digests - pattern: digests-semantic-* - merge-multiple: true - - name: Set up Docker Buildx uses: https://github.com/docker/setup-buildx-action@v3 @@ -159,9 +139,9 @@ jobs: type=sha,prefix=sha-,suffix=-semantic - name: Create and push multi-arch manifest - working-directory: /tmp/digests run: | docker buildx imagetools create \ $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - $(printf '${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@sha256:%s ' *) + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64-${{ github.run_id }} \ + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64-${{ github.run_id }} diff --git a/.forgejo/workflows/docker-publish.yml b/.forgejo/workflows/docker-publish.yml index 3b69030..09e145c 100644 --- a/.forgejo/workflows/docker-publish.yml +++ b/.forgejo/workflows/docker-publish.yml @@ -69,28 +69,15 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.REGISTRY_TOKEN }} - - name: Build and push image by digest + - name: Build and push platform image id: build uses: https://github.com/docker/build-push-action@v6 with: context: . 
platforms: ${{ matrix.platform }} + push: true provenance: false - outputs: type=image,name=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true - - - name: Export digest - run: | - mkdir -p /tmp/digests - digest="${{ steps.build.outputs.digest }}" - touch "/tmp/digests/${digest#sha256:}" - - - name: Upload digest - uses: https://github.com/actions/upload-artifact@v4 - with: - name: digests-${{ matrix.arch }} - path: /tmp/digests/* - if-no-files-found: error - retention-days: 1 + tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}-${{ github.run_id }} merge: runs-on: docker-amd64 @@ -119,13 +106,6 @@ jobs: cat /tmp/dockerd.log exit 1 - - name: Download digests - uses: https://github.com/actions/download-artifact@v4 - with: - path: /tmp/digests - pattern: digests-* - merge-multiple: true - - name: Set up Docker Buildx uses: https://github.com/docker/setup-buildx-action@v3 @@ -150,9 +130,9 @@ jobs: type=sha,prefix=sha- - name: Create and push multi-arch manifest - working-directory: /tmp/digests run: | docker buildx imagetools create \ $(jq -cr '.tags | map("-t " + .) 
| join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - $(printf '${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@sha256:%s ' *) + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64-${{ github.run_id }} \ + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64-${{ github.run_id }} From 539d5f98a21cc01f7f44eeef5b48a75770eb7199 Mon Sep 17 00:00:00 2001 From: alpha nerd Date: Mon, 18 May 2026 17:03:04 +0200 Subject: [PATCH 21/22] doc: update on /health and /api/config endpoints --- doc/architecture.md | 2 ++ doc/monitoring.md | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/doc/architecture.md b/doc/architecture.md index f725573..c2408d8 100644 --- a/doc/architecture.md +++ b/doc/architecture.md @@ -206,6 +206,8 @@ The `/health` endpoint provides comprehensive health status: } ``` +For Ollama endpoints the probe is a parallel check of `/api/version` (liveness) and `/api/ps` (the route used by `choose_endpoint` when selecting a backend for a request). Reporting `ok` only when both succeed prevents the router from advertising an endpoint as healthy while completion calls dead-end on `/api/ps`. The same dual probe backs `/api/config`, which the dashboard uses to render endpoint health. + ## Database Schema The router uses SQLite for persistent storage: diff --git a/doc/monitoring.md b/doc/monitoring.md index ab75d25..9ce25ec 100644 --- a/doc/monitoring.md +++ b/doc/monitoring.md @@ -29,6 +29,10 @@ Response: - `200`: All endpoints healthy - `503`: One or more endpoints unhealthy +**Probe scope per endpoint**: +- **Ollama endpoints** are probed at both `/api/version` (liveness) and `/api/ps` (model-introspection used by the router). If either fails the endpoint is reported as `error`; the response still includes `version` when the daemon is reachable so operators can tell a partial failure from a full outage. The `detail` field names the failing probe, e.g. `"/api/ps: 502 …"`. +- **OpenAI-compatible / llama-server endpoints** are probed at `/models`. 
+ ### Current Usage ```bash @@ -133,6 +137,8 @@ Response: } ``` +Uses the same dual-probe logic as `/health` (Ollama: `/api/version` + `/api/ps`; OpenAI-compatible: `/models`). An endpoint will report `error` whenever either probe fails. The dashboard renders the `detail` field as a tooltip on the status cell. + ### Cache Statistics ```bash From 9bba10d7f4ec8f0e73ec50efad7e1dbe40ac67ac Mon Sep 17 00:00:00 2001 From: Renovate Bot Date: Tue, 19 May 2026 10:10:54 +0000 Subject: [PATCH 22/22] chore(deps): update dependency jiter to v0.15.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5920529..15512a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ h11==0.16.0 httpcore==1.0.9 httpx==0.28.1 idna==3.15 -jiter==0.14.0 +jiter==0.15.0 multidict==6.7.1 ollama==0.6.2 openai==2.37.0