diff --git a/.forgejo/workflows/docker-publish-semantic.yml b/.forgejo/workflows/docker-publish-semantic.yml
index 163f1a1..2fa59d5 100644
--- a/.forgejo/workflows/docker-publish-semantic.yml
+++ b/.forgejo/workflows/docker-publish-semantic.yml
@@ -86,7 +86,9 @@ jobs:
provenance: false
build-args: |
SEMANTIC_CACHE=true
- tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}-${{ github.run_id }}
+ tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}
+ cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }}
+ cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }},mode=max
merge:
runs-on: docker-amd64
@@ -142,6 +144,6 @@ jobs:
run: |
docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64-${{ github.run_id }} \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64-${{ github.run_id }}
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64 \
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64
diff --git a/.forgejo/workflows/docker-publish.yml b/.forgejo/workflows/docker-publish.yml
index 09e145c..27cd879 100644
--- a/.forgejo/workflows/docker-publish.yml
+++ b/.forgejo/workflows/docker-publish.yml
@@ -77,7 +77,7 @@ jobs:
platforms: ${{ matrix.platform }}
push: true
provenance: false
- tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}-${{ github.run_id }}
+ tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}
merge:
runs-on: docker-amd64
@@ -133,6 +133,6 @@ jobs:
run: |
docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64-${{ github.run_id }} \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64-${{ github.run_id }}
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64 \
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64
diff --git a/.forgejo/workflows/pr-tests.yml b/.forgejo/workflows/pr-tests.yml
deleted file mode 100644
index aa96b84..0000000
--- a/.forgejo/workflows/pr-tests.yml
+++ /dev/null
@@ -1,39 +0,0 @@
-name: PR Tests
-on: [pull_request]
-jobs:
- test:
- runs-on: docker-arm64
- container:
- image: python:3.12-slim
- env:
- CMAKE_BUILD_PARALLEL_LEVEL: "4"
- steps:
- - name: Install system deps
- run: |
- apt-get update
- apt-get install -y --no-install-recommends \
- git ca-certificates \
- build-essential pkg-config
- rm -rf /var/lib/apt/lists/*
- - name: Checkout
- run: |
- git config --global --add safe.directory "$PWD"
- git clone --depth=1 \
- "https://oauth2:${{ github.token }}@bitfreedom.net/code/${{ github.repository }}.git" .
- git fetch --depth=1 origin "+${{ github.event.pull_request.head.sha }}:pr"
- git checkout pr
- - name: Fetch action source
- run: |
- git clone --depth=1 --branch master \
- "https://oauth2:${{ github.token }}@bitfreedom.net/code/nomyo-ai/actions.git" \
- ./.run-tests
- - uses: ./.run-tests/run-tests
- with:
- setup: |
- python -m pip install --upgrade pip
- pip install -r requirements.txt
- pip install -r test/requirements_test.txt
- command: pytest test/ -m "not integration" --cov=router --cov=cache --cov=db --cov=enhance --cov-fail-under=45 --cov-report=term-missing --cov-report=xml --junitxml=report.xml
- artifacts-path: |
- report.xml
- coverage.xml
diff --git a/.gitignore b/.gitignore
index 7cd8431..cfce37c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -66,4 +66,7 @@ config.yaml
# SQLite
*.db*
-*settings.json
\ No newline at end of file
+*settings.json
+
+# Test suite (local only, not committed yet)
+test/
\ No newline at end of file
diff --git a/.nyx/triage.json b/.nyx/triage.json
deleted file mode 100644
index 2c4dc31..0000000
--- a/.nyx/triage.json
+++ /dev/null
@@ -1,24 +0,0 @@
-{
- "version": 1,
- "decisions": [],
- "suppression_rules": [
- {
- "by": "rule",
- "value": "py.auth.token_override_without_validation",
- "state": "suppressed",
- "note": "false_positive: token validation handled upstream by middleware"
- },
- {
- "by": "rule",
- "value": "state-resource-leak",
- "state": "suppressed",
- "note": "false_positive: resource lifecycle managed externally"
- },
- {
- "by": "rule",
- "value": "py.crypto.sha1",
- "state": "suppressed",
- "note": "accepted_risk: used for non-security checksum only"
- }
- ]
-}
\ No newline at end of file
diff --git a/config.yaml b/config.yaml
index 2107a3c..76fbbe1 100644
--- a/config.yaml
+++ b/config.yaml
@@ -26,26 +26,6 @@ max_concurrent_connections: 2
# When false (default), equally-idle endpoints are chosen at random.
# priority_routing: true
-# Conversation affinity (optional, default: false).
-# Pins a conversation to the endpoint that served its first turn so the
-# llama.cpp / Ollama prompt cache (KV cache) stays warm — first turn pays
-# the cold prefill, every follow-up turn reuses the same prefix.
-#
-# Fingerprint = sha1(model + leading system messages + first user turn).
-# Same chat → same fingerprint on every follow-up turn → same pin, TTL
-# refreshed on each reuse. Soft preference: if the pinned endpoint no
-# longer has the model loaded or has no free slot, the standard algorithm
-# takes over (no failure, just a cache miss).
-#
-# Heads-up: most chat UIs (Open WebUI, LibreChat, …) fire side requests for
-# title / tag / follow-up generation. Those have their own first turn and
-# therefore their own pin, so a single visible "chat" may show several dots
-# in the dashboard's Affinity column. That is correct — each pin matches a
-# real warm KV prefix on the backend. See doc/configuration.md for details.
-conversation_affinity: true
-conversation_affinity_ttl: 300 # seconds of inactivity before a pin expires;
- # bumped on every reuse. Matches Ollama's default keep_alive.
-
# Optional router-level API key that gates router/API/web UI access (leave empty to disable)
nomyo-router-api-key: ""
diff --git a/doc/architecture.md b/doc/architecture.md
index c2408d8..f725573 100644
--- a/doc/architecture.md
+++ b/doc/architecture.md
@@ -206,8 +206,6 @@ The `/health` endpoint provides comprehensive health status:
}
```
-For Ollama endpoints the probe is a parallel check of `/api/version` (liveness) and `/api/ps` (the route used by `choose_endpoint` when selecting a backend for a request). Reporting `ok` only when both succeed prevents the router from advertising an endpoint as healthy while completion calls dead-end on `/api/ps`. The same dual probe backs `/api/config`, which the dashboard uses to render endpoint health.
-
## Database Schema
The router uses SQLite for persistent storage:
diff --git a/doc/configuration.md b/doc/configuration.md
index 1addd66..7d9986a 100644
--- a/doc/configuration.md
+++ b/doc/configuration.md
@@ -166,91 +166,6 @@ With this config the primary handles up to 4 concurrent requests before the seco
---
-### `conversation_affinity`
-
-**Type**: `bool` (optional)
-
-**Default**: `false`
-
-**Companion setting**: [`conversation_affinity_ttl`](#conversation_affinity_ttl)
-
-**Description**: When enabled, the router prefers to send follow-up requests of the same conversation back to the endpoint that already served the first turn. This keeps the backend's prompt cache (the llama.cpp / Ollama **KV cache**) warm: the first user turn pays the cold prefill cost, every later turn reuses the same prefix and only generates new tokens. It is a **soft preference** — when the previously-chosen endpoint is no longer eligible (model unloaded, no free slot), the router falls back to the standard selection algorithm (`priority_routing` or random).
-
-#### How a conversation is identified
-
-The router does **not** track session IDs or auth tokens. It computes a stable fingerprint per request from:
-
-```
-SHA1( model
- + every leading message with role="system"
- + the first message with role="user" )
-```
-
-Anything after the first user turn is ignored — those later messages extend the same KV prefix, so they don't change the cache identity.
-
-**What this means in practice**
-
-| You send… | Fingerprint behaves like… |
-|---|---|
-| Turn 2 of the same chat (history grows but first system+user are unchanged) | **Same** as turn 1 → pin is reused and TTL refreshed |
-| Turn 1 of a fresh chat | **New** fingerprint → new pin |
-| Same first user prompt but a different model | **New** fingerprint (model is part of the hash) |
-| Same chat but the client mutates the system prompt between turns (e.g. injects a fresh timestamp) | **New** fingerprint — the affinity will not stick |
-
-#### TTL and refresh
-
-Every time `choose_endpoint` returns a pinned endpoint, the entry's expiry is bumped to `now + conversation_affinity_ttl`. An idle conversation drops out of the map once that window elapses without traffic. Default 300 s matches Ollama's default `keep_alive` — once the backend has unloaded the model, the KV cache is gone too, so a stale pin would be pointless anyway.
-
-#### Why the dashboard may show more than one dot per visible conversation
-
-The fingerprint is computed per **HTTP request**, not per chat-window. Most chat UIs (Open WebUI in particular) fire several **auxiliary** requests alongside the real conversation:
-
-- *Title generation* — synthetic system prompt + the user message as content
-- *Follow-up question suggestion* — synthetic system prompt + the conversation as content
-- *Tag generation*, *memory extraction*, *retrieval query rewriting*, etc.
-
-Each of those has its own `(system + first user turn)` and therefore its own fingerprint and its own pin in [the affinity dot matrix](monitoring.md#affinity-stats-conversation-affinity). They all *correctly* refer to a real warm KV-cache prefix on the backend, so the routing they drive is right — they just don't visually map 1:1 to a user-perceived "conversation."
-
-#### Example
-
-```yaml
-endpoints:
- - http://gpu-primary:11434
- - http://gpu-secondary:11434
-
-conversation_affinity: true
-conversation_affinity_ttl: 300
-```
-
-With this configuration, a chat that starts on `gpu-primary` will keep returning to `gpu-primary` for follow-up turns as long as the model is still loaded there and a slot is free, even if `gpu-secondary` happens to be more idle at that moment. Cold-prefill cost is paid once instead of once per turn.
-
-#### When to enable
-
-- ✅ Interactive chat workloads with long histories — the prefill savings on every follow-up turn are substantial.
-- ✅ Multi-endpoint deployments where models are loaded on more than one node.
-- ❌ Pure one-shot / single-turn workloads (no KV-cache to keep warm).
-- ❌ When you specifically want strict load-balancing parity — affinity intentionally biases against perfect balance.
-
----
-
-### `conversation_affinity_ttl`
-
-**Type**: `int` (seconds, optional)
-
-**Default**: `300`
-
-**Description**: How long a conversation stays pinned to its endpoint after the last request that touched it. Refreshed on every reuse — so an actively-used conversation keeps its pin indefinitely; an abandoned one expires after `conversation_affinity_ttl` seconds of silence.
-
-**Recommendation**: leave this aligned with the backend's `keep_alive` window. If the model is unloaded by the backend, the KV cache is gone and there is no benefit to keeping the pin.
-
-**Example**:
-```yaml
-conversation_affinity: true
-conversation_affinity_ttl: 600 # half an hour of inactivity before un-pinning
-```
-
----
-
### `router_api_key`
**Type**: `str` (optional)
diff --git a/doc/monitoring.md b/doc/monitoring.md
index 9ce25ec..b5bcbff 100644
--- a/doc/monitoring.md
+++ b/doc/monitoring.md
@@ -29,10 +29,6 @@ Response:
- `200`: All endpoints healthy
- `503`: One or more endpoints unhealthy
-**Probe scope per endpoint**:
-- **Ollama endpoints** are probed at both `/api/version` (liveness) and `/api/ps` (model-introspection used by the router). If either fails the endpoint is reported as `error`; the response still includes `version` when the daemon is reachable so operators can tell a partial failure from a full outage. The `detail` field names the failing probe, e.g. `"/api/ps: 502 …"`.
-- **OpenAI-compatible / llama-server endpoints** are probed at `/models`.
-
### Current Usage
```bash
@@ -137,8 +133,6 @@ Response:
}
```
-Uses the same dual-probe logic as `/health` (Ollama: `/api/version` + `/api/ps`; OpenAI-compatible: `/models`). An endpoint will report `error` whenever either probe fails. The dashboard renders the `detail` field as a tooltip on the status cell.
-
### Cache Statistics
```bash
@@ -172,39 +166,6 @@ curl -X POST http://localhost:12434/api/cache/invalidate
Clears all cached entries and resets hit/miss counters.
-### Affinity Stats (Conversation Affinity)
-
-```bash
-curl http://localhost:12434/api/affinity_stats
-```
-
-Response when [`conversation_affinity`](configuration.md#conversation_affinity) is enabled:
-
-```json
-{
- "enabled": true,
- "ttl": 300,
- "entries": [
- { "endpoint": "http://gpu-primary:11434", "model": "llama3.2:latest", "remaining": 287.4 },
- { "endpoint": "http://gpu-primary:11434", "model": "llama3.2:latest", "remaining": 113.0 },
- { "endpoint": "http://gpu-secondary:11434", "model": "qwen2.5-coder:7b", "remaining": 44.8 }
- ]
-}
-```
-
-Response when the feature is disabled:
-```json
-{ "enabled": false, "ttl": 300, "entries": [] }
-```
-
-- One element per **live pinned conversation** (no fingerprints or content — just the endpoint/model the pin points to and how many seconds it has left before expiry).
-- Aggregation by `(endpoint, model)` is left to the consumer: the dashboard does this client-side.
-- The endpoint is gated by the same `nomyo-router-api-key` middleware as the rest of `/api/*`.
-
-The dashboard's **Running Models (PS) → Affinity** column is rendered from this data. The column auto-hides when `enabled: false`. Each row shows one dot per live pin against that `(endpoint, model)` pair; dot opacity = `remaining / ttl` (floor 0.15), so freshly-routed pins are solid and pins close to expiry fade out. A `+N` overflow badge appears once a single (endpoint, model) holds more than 12 active pins; an em-dash (`—`) marks an `(endpoint, model)` with no live pins.
-
-> Multiple dots for what looks like "one chat window" is normal — most chat UIs (Open WebUI, LibreChat, …) fire auxiliary requests (title generation, follow-up suggestions, tag extraction) that have their own first-turn fingerprint and therefore their own pin. See [Conversation Affinity → Why the dashboard may show more than one dot per visible conversation](configuration.md#conversation_affinity) for the details.
-
### Real-time Usage Stream
```bash
diff --git a/requirements.txt b/requirements.txt
index 15512a2..159c062 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ anyio==4.13.0
async-timeout==5.0.1
attrs==26.1.0
certifi==2026.4.22
-click==8.4.0
+click==8.3.3
distro==1.9.0
exceptiongroup==1.3.1
fastapi==0.136.1
@@ -15,11 +15,11 @@ frozenlist==1.8.0
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
-idna==3.15
-jiter==0.15.0
+idna==3.14
+jiter==0.14.0
multidict==6.7.1
ollama==0.6.2
-openai==2.37.0
+openai==1.109.1
orjson>=3.11.5
numpy>=1.26
pillow==12.2.0
@@ -32,11 +32,11 @@ PyYAML==6.0.3
sniffio==1.3.1
starlette==0.52.1
truststore==0.10.4
-tiktoken==0.13.0
+tiktoken==0.12.0
tqdm==4.67.3
typing-inspection==0.4.2
typing_extensions==4.15.0
-uvicorn==0.47.0
+uvicorn==0.46.0
uvloop
yarl==1.23.0
aiosqlite
diff --git a/router.py b/router.py
index c465ebc..603387e 100644
--- a/router.py
+++ b/router.py
@@ -2,11 +2,11 @@
title: NOMYO Router - an (O)llama and OpenAI API v1 Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
-version: 0.9
+version: 0.7
license: AGPL
"""
# -------------------------------------------------------------
-import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx, hashlib
+import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx
try:
import truststore; truststore.inject_into_ssl()
except ImportError:
@@ -223,15 +223,6 @@ class Config(BaseSettings):
# When True, config order = priority; routes by utilization ratio + config index (WRR)
priority_routing: bool = Field(default=False)
- # Conversation affinity: route the same conversation back to the endpoint that
- # previously served it, to keep the llama.cpp / Ollama prompt cache (KV cache) warm.
- # Soft preference — falls back to the standard algorithm when the affine endpoint
- # is saturated or no longer has the model loaded.
- conversation_affinity: bool = Field(default=False)
- # TTL (seconds) for affinity entries. Defaults to Ollama's default keep_alive (5 min):
- # if the backend has already evicted the model, the KV cache is cold anyway.
- conversation_affinity_ttl: int = Field(default=300)
-
api_keys: Dict[str, str] = Field(default_factory=dict)
# Optional router-level API key used to gate access to this service and dashboard
router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY")
@@ -256,8 +247,9 @@ class Config(BaseSettings):
cache_history_weight: float = Field(default=0.3)
class Config:
- # YAML loading is handled manually via Config.from_yaml(); env vars use this prefix.
+ # Load from `config.yaml` first, then from env variables
env_prefix = "NOMYO_ROUTER_"
+ yaml_file = Path("config.yaml") # relative to cwd
@classmethod
def _expand_env_refs(cls, obj):
@@ -444,47 +436,6 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(
usage_lock = asyncio.Lock() # protects access to usage_counts
token_usage_lock = asyncio.Lock()
-# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic).
-# Keeps the same conversation pinned to the endpoint that already has its
-# KV-cache prefix warm. Model is stored so the dashboard can aggregate live
-# entries per (endpoint, model) without recomputing fingerprints.
-# Never held together with usage_lock.
-_affinity_map: Dict[str, tuple[str, str, float]] = {}
-_affinity_lock = asyncio.Lock()
-_AFFINITY_MAX_ENTRIES = 10000
-
-
-def _conversation_fingerprint(model: str, messages: Optional[list],
- prompt: Optional[str]) -> Optional[str]:
- """
- Stable hash over (model, first system + first user turn). That prefix
- determines whether the backend's prompt cache is reusable; later turns
- don't influence the routing decision because they extend the same prefix.
- Returns None when there is no usable prefix.
- """
- parts: list[str] = [model or "_"]
- if messages:
- for m in messages:
- role = m.get("role") if isinstance(m, dict) else None
- if role not in ("system", "user"):
- continue
- content = m.get("content")
- if isinstance(content, list): # OpenAI multimodal parts
- content = "".join(
- p.get("text", "") for p in content
- if isinstance(p, dict) and p.get("type") == "text"
- )
- if not isinstance(content, str):
- continue
- parts.append(f"{role}:{content}")
- if role == "user":
- break
- elif prompt:
- parts.append(f"user:{prompt}")
- else:
- return None
- return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest()
-
# Database instance
db: "TokenDatabase" = None
@@ -1000,7 +951,7 @@ class fetch:
async with client.get(f"{endpoint}/models") as resp:
await _ensure_success(resp)
data = await resp.json()
-
+
# Filter for loaded models only
items = data.get("data", [])
models = {
@@ -1012,19 +963,11 @@ class fetch:
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
- # Probe succeeded — clear any stale error so the endpoint
- # becomes routable again.
- async with _loaded_error_cache_lock:
- _loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/models", e)
print(f"[fetch.loaded_models] {message}")
- # Record the failure so `choose_endpoint` can avoid routing
- # to an unhealthy backend and repeated probes short-circuit.
- async with _loaded_error_cache_lock:
- _loaded_error_cache[endpoint] = time.time()
return set()
else:
# Original Ollama /api/ps logic
@@ -1039,15 +982,11 @@ class fetch:
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
- async with _loaded_error_cache_lock:
- _loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/api/ps", e)
print(f"[fetch.loaded_models] {message}")
- async with _loaded_error_cache_lock:
- _loaded_error_cache[endpoint] = time.time()
return set()
async def _refresh_loaded_models(endpoint: str) -> None:
@@ -1430,30 +1369,30 @@ def resize_image_if_needed(image_data):
pass
# Decode the base64 image data
image_bytes = base64.b64decode(image_data)
- with Image.open(io.BytesIO(image_bytes)) as image:
- if image.mode not in ("RGB", "L"):
- image = image.convert("RGB")
+ image = Image.open(io.BytesIO(image_bytes))
+ if image.mode not in ("RGB", "L"):
+ image = image.convert("RGB")
- # Get current size
- width, height = image.size
+ # Get current size
+ width, height = image.size
- # Calculate the new dimensions while maintaining aspect ratio
- if width > 512 or height > 512:
- aspect_ratio = width / height
- if aspect_ratio > 1: # Width is larger
- new_width = 512
- new_height = int(512 / aspect_ratio)
- else: # Height is larger
- new_height = 512
- new_width = int(512 * aspect_ratio)
+ # Calculate the new dimensions while maintaining aspect ratio
+ if width > 512 or height > 512:
+ aspect_ratio = width / height
+ if aspect_ratio > 1: # Width is larger
+ new_width = 512
+ new_height = int(512 / aspect_ratio)
+ else: # Height is larger
+ new_height = 512
+ new_width = int(512 * aspect_ratio)
- image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
- # Encode the resized image back to base64
- buffered = io.BytesIO()
- image.save(buffered, format="PNG")
- resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
- return resized_image_data
+ # Encode the resized image back to base64
+ buffered = io.BytesIO()
+ image.save(buffered, format="PNG")
+ resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return resized_image_data
except Exception as e:
print(f"Error processing image: {e}")
@@ -1799,8 +1738,7 @@ def get_max_connections(ep: str) -> int:
"max_concurrent_connections", config.max_concurrent_connections
)
-async def choose_endpoint(model: str, reserve: bool = True,
- affinity_key: Optional[str] = None) -> tuple[str, str]:
+async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpoint‑model pair **and**
@@ -1810,14 +1748,10 @@ async def choose_endpoint(model: str, reserve: bool = True,
1️⃣ Query every endpoint for its advertised models (`/api/tags`).
2️⃣ Build a list of endpoints that contain the requested model.
- 2️⃣.5 If conversation affinity is enabled and the caller passes
- ``affinity_key``, prefer the endpoint that previously served the
- same conversation — but only when it still has the model loaded
- and a free slot. Otherwise fall through to the standard logic.
3️⃣ For those endpoints, find those that have the model loaded
(`/api/ps`) *and* still have a free slot.
4️⃣ If none are both loaded and free, fall back to any endpoint
- from the filtered list that simply has a free slot and randomly
+ from the filtered list that simply has a free slot and randomly
select one.
5️⃣ If all are saturated, pick any endpoint from the filtered list
(the request will queue on that endpoint).
@@ -1865,41 +1799,6 @@ async def choose_endpoint(model: str, reserve: bool = True,
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
- # 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing
- # recently. Without this filter, an endpoint where `/api/ps` returns 5xx
- # would appear with an empty loaded set but pass through to the
- # free-slot fallback (step 4) — sending completion calls to an
- # unhealthy backend. See issue #83.
- async with _loaded_error_cache_lock:
- unhealthy = {
- ep for ep, ts in _loaded_error_cache.items()
- if _is_fresh(ts, 300)
- }
- if unhealthy:
- filtered = [
- (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
- if ep not in unhealthy
- ]
- if filtered:
- candidate_endpoints = [ep for ep, _ in filtered]
- loaded_sets = [models for _, models in filtered]
- # If *every* candidate is unhealthy we still fall through with the
- # original list — refusing to route is worse than retrying a
- # possibly-recovered backend.
-
- # Look up a possible affinity hint *before* taking usage_lock. The two
- # locks are never held together to avoid lock-ordering issues.
- affine_ep: Optional[str] = None
- if config.conversation_affinity and affinity_key:
- async with _affinity_lock:
- entry = _affinity_map.get(affinity_key)
- if entry is not None:
- ep, _stored_model, expires_at = entry
- if expires_at < time.monotonic():
- _affinity_map.pop(affinity_key, None)
- else:
- affine_ep = ep
-
# Protect all reads/writes of usage_counts with the lock so that selection
# and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock:
@@ -1915,75 +1814,59 @@ async def choose_endpoint(model: str, reserve: bool = True,
# Priority map: position in all_endpoints list (lower = higher priority)
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
- selected: Optional[str] = None
+ # 3️⃣ Endpoints that have the model loaded *and* a free slot
+ loaded_and_free = [
+ ep for ep, models in zip(candidate_endpoints, loaded_sets)
+ if model in models and tracking_usage(ep) < get_max_connections(ep)
+ ]
- # 2️⃣.5 Conversation affinity preference — only honour the hint when
- # the affine endpoint still advertises the model loaded *and* has a
- # free slot. Otherwise fall back to the standard algorithm.
- if affine_ep:
- ep_loaded = {
- ep: set(models)
- for ep, models in zip(candidate_endpoints, loaded_sets)
- }
- if (affine_ep in candidate_endpoints
- and model in ep_loaded.get(affine_ep, set())
- and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
- selected = affine_ep
-
- if selected is None:
- # 3️⃣ Endpoints that have the model loaded *and* a free slot
- loaded_and_free = [
- ep for ep, models in zip(candidate_endpoints, loaded_sets)
- if model in models and tracking_usage(ep) < get_max_connections(ep)
+ if loaded_and_free:
+ if config.priority_routing:
+ # WRR: sort by config order first (stable), then by utilization ratio.
+ # Stable sort preserves priority for equal-ratio endpoints.
+ loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
+ loaded_and_free.sort(key=utilization_ratio)
+ selected = loaded_and_free[0]
+ else:
+ # Sort ascending for load balancing — all endpoints here already have the
+ # model loaded, so there is no model-switching cost to optimise for.
+ loaded_and_free.sort(key=tracking_usage)
+ # When all candidates are equally idle, randomise to avoid always picking
+ # the first entry in a stable sort.
+ if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
+ selected = random.choice(loaded_and_free)
+ else:
+ selected = loaded_and_free[0]
+ else:
+ # 4️⃣ Endpoints among the candidates that simply have a free slot
+ endpoints_with_free_slot = [
+ ep for ep in candidate_endpoints
+ if tracking_usage(ep) < get_max_connections(ep)
]
- if loaded_and_free:
+ if endpoints_with_free_slot:
if config.priority_routing:
- # WRR: sort by config order first (stable), then by utilization ratio.
- # Stable sort preserves priority for equal-ratio endpoints.
- loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
- loaded_and_free.sort(key=utilization_ratio)
- selected = loaded_and_free[0]
+ endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
+ endpoints_with_free_slot.sort(key=utilization_ratio)
+ selected = endpoints_with_free_slot[0]
else:
- # Sort ascending for load balancing — all endpoints here already have the
- # model loaded, so there is no model-switching cost to optimise for.
- loaded_and_free.sort(key=tracking_usage)
- # When all candidates are equally idle, randomise to avoid always picking
- # the first entry in a stable sort.
- if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
- selected = random.choice(loaded_and_free)
+ # Sort by total endpoint load (ascending) to prefer idle endpoints.
+ endpoints_with_free_slot.sort(
+ key=lambda ep: sum(usage_counts.get(ep, {}).values())
+ )
+ if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
+ selected = random.choice(endpoints_with_free_slot)
else:
- selected = loaded_and_free[0]
- else:
- # 4️⃣ Endpoints among the candidates that simply have a free slot
- endpoints_with_free_slot = [
- ep for ep in candidate_endpoints
- if tracking_usage(ep) < get_max_connections(ep)
- ]
-
- if endpoints_with_free_slot:
- if config.priority_routing:
- endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
- endpoints_with_free_slot.sort(key=utilization_ratio)
selected = endpoints_with_free_slot[0]
- else:
- # Sort by total endpoint load (ascending) to prefer idle endpoints.
- endpoints_with_free_slot.sort(
- key=lambda ep: sum(usage_counts.get(ep, {}).values())
- )
- if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
- selected = random.choice(endpoints_with_free_slot)
- else:
- selected = endpoints_with_free_slot[0]
+ else:
+ # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
+ if config.priority_routing:
+ selected = min(
+ candidate_endpoints,
+ key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
+ )
else:
- # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
- if config.priority_routing:
- selected = min(
- candidate_endpoints,
- key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
- )
- else:
- selected = min(candidate_endpoints, key=tracking_usage)
+ selected = min(candidate_endpoints, key=tracking_usage)
tracking_model = get_tracking_model(selected, model)
snapshot = None
@@ -1992,15 +1875,6 @@ async def choose_endpoint(model: str, reserve: bool = True,
snapshot = _capture_snapshot()
if snapshot is not None:
await _distribute_snapshot(snapshot)
- # Record / refresh affinity *after* releasing usage_lock.
- if reserve and config.conversation_affinity and affinity_key:
- expires_at = time.monotonic() + config.conversation_affinity_ttl
- async with _affinity_lock:
- _affinity_map[affinity_key] = (selected, model, expires_at)
- if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
- now = time.monotonic()
- for k in [k for k, v in _affinity_map.items() if v[2] < now]:
- _affinity_map.pop(k, None)
return selected, tracking_model
# -------------------------------------------------------------
@@ -2051,8 +1925,7 @@ async def proxy(request: Request):
yield _cached
return StreamingResponse(_serve_cached_generate(), media_type="application/json")
- _affinity_key = _conversation_fingerprint(model, None, prompt)
- endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
+ endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
@@ -2222,8 +2095,7 @@ async def chat_proxy(request: Request):
opt = True
else:
opt = False
- _affinity_key = _conversation_fingerprint(model, messages, None)
- endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
+ endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
@@ -3138,43 +3010,6 @@ async def ps_details_proxy(request: Request):
return JSONResponse(content={"models": models}, status_code=200)
-# -------------------------------------------------------------
-# 18b. Conversation-affinity stats – feeds the PS-table dot matrix
-# -------------------------------------------------------------
-@app.get("/api/affinity_stats")
-async def affinity_stats(request: Request):
- """
- Aggregate live conversation-affinity pins, one entry per pinned conversation.
- Each entry exposes only the endpoint, model, and remaining TTL in seconds —
- no fingerprints or content. When conversation_affinity is disabled the
- `entries` list is always empty.
- """
- if not config.conversation_affinity:
- return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []}
-
- now = time.monotonic()
- entries: list[dict] = []
- llama_eps = set(config.llama_server_endpoints)
- async with _affinity_lock:
- for fp, (ep, mdl, expires_at) in list(_affinity_map.items()):
- remaining = expires_at - now
- if remaining <= 0:
- _affinity_map.pop(fp, None)
- continue
- # Mirror the normalisation used by /api/ps_details so the dashboard
- # can join affinity entries to PS rows by (endpoint, model).
- display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl
- entries.append({
- "endpoint": ep,
- "model": display_model,
- "remaining": round(remaining, 2),
- })
- return {
- "enabled": True,
- "ttl": config.conversation_affinity_ttl,
- "entries": entries,
- }
-
# -------------------------------------------------------------
# 19. Proxy usage route – for monitoring
# -------------------------------------------------------------
@@ -3188,103 +3023,44 @@ async def usage_proxy(request: Request):
"token_usage_counts": token_usage_counts}
# -------------------------------------------------------------
-# 20. Endpoint health probes (shared by /api/config and /health)
-# -------------------------------------------------------------
-async def _raw_probe(
- ep: str,
- route: str,
- api_key: Optional[str] = None,
- timeout: Optional[float] = None,
-) -> tuple[bool, object]:
- """Direct HTTP probe that distinguishes success from failure
- (unlike `fetch.endpoint_details`, which returns [] on either).
- Returns `(ok, payload_or_error_message)`.
- """
- headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
- if api_key is not None:
- headers["Authorization"] = "Bearer " + api_key
- url = f"{ep.rstrip('/')}/{route.lstrip('/')}"
- req_kwargs = {}
- if timeout is not None:
- req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
- try:
- client: aiohttp.ClientSession = get_session(ep)
- async with client.get(url, headers=headers, **req_kwargs) as resp:
- await _ensure_success(resp)
- data = await resp.json()
- return True, data
- except Exception as exc:
- return False, _format_connection_issue(url, exc)
-
-
-async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
- """Probe an endpoint and return `{status, version?, detail?}`.
-
- Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so
- that a daemon which is reachable but has a broken model-introspection
- path (issue #83) is reported as `error` rather than `ok`.
- OpenAI-compatible endpoints use a single `/models` probe.
- """
- if is_openai_compatible(ep):
- ok, payload = await _raw_probe(
- ep, "/models", config.api_keys.get(ep), timeout=timeout,
- )
- if ok:
- return {"status": "ok", "version": "latest"}
- return {"status": "error", "detail": str(payload)}
-
- (version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather(
- _raw_probe(ep, "/api/version", timeout=timeout),
- _raw_probe(ep, "/api/ps", timeout=timeout),
- )
-
- version_value = (
- version_payload.get("version")
- if version_ok and isinstance(version_payload, dict)
- else None
- )
-
- if version_ok and ps_ok:
- return {"status": "ok", "version": version_value}
- if not version_ok and not ps_ok:
- return {"status": "error", "detail": str(version_payload)}
- # Partial failure — daemon reachable but one probe failed. Report
- # as "error" so callers can surface the issue; include `version` so
- # the operator knows the daemon itself is alive.
- if not ps_ok:
- return {
- "status": "error",
- "version": version_value,
- "detail": f"/api/ps: {ps_payload}",
- }
- return {
- "status": "error",
- "detail": f"/api/version: {version_payload}",
- }
-
-
-# -------------------------------------------------------------
-# 20b. Proxy config route – for monitoring and frontend usage
+# 20. Proxy config route – for monitoring and frontent usage
# -------------------------------------------------------------
@app.get("/api/config")
async def config_proxy(request: Request):
"""
Return a simple JSON object that contains the configured
- Ollama endpoints and llama_server_endpoints. The front‑end uses this
- to display which endpoints are being proxied and their health.
- Status is "error" when either liveness (/api/version) or routing
- health (/api/ps) fails — see issue #83.
+ Ollama endpoints and llama_server_endpoints. The front‑end uses this to display
+ which endpoints are being proxied.
"""
- async def check(url: str) -> dict:
- return {"url": url, **(await _endpoint_health(url, timeout=5))}
+ async def check_endpoint(url: str):
+ client: aiohttp.ClientSession = get_session(url)
+ headers = None
+ if "/v1" in url:
+ headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
+ target_url = f"{url}/models"
+ else:
+ target_url = f"{url}/api/version"
- ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
+ try:
+ async with client.get(target_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp:
+ await _ensure_success(resp)
+ data = await resp.json()
+ if "/v1" in url:
+ return {"url": url, "status": "ok", "version": "latest"}
+ else:
+ return {"url": url, "status": "ok", "version": data.get("version")}
+ except Exception as e:
+ detail = _format_connection_issue(target_url, e)
+ return {"url": url, "status": "error", "detail": detail}
+
+ # Check Ollama endpoints
+ ollama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
+
+ # Check llama-server endpoints
llama_results = []
if config.llama_server_endpoints:
- llama_results = await asyncio.gather(
- *[check(ep) for ep in config.llama_server_endpoints]
- )
-
+ llama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.llama_server_endpoints])
+
return {
"endpoints": ollama_results,
"llama_server_endpoints": llama_results,
@@ -3452,8 +3228,7 @@ async def openai_chat_completions_proxy(request: Request):
return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")
# 2. Endpoint logic
- _affinity_key = _conversation_fingerprint(model, messages, None)
- endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
+ endpoint, tracking_model = await choose_endpoint(model)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Helpers and API call — done in handler scope so try/except works reliably
async def _normalize_images_in_messages(msgs: list) -> list:
@@ -3763,8 +3538,7 @@ async def openai_completions_proxy(request: Request):
return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")
# 2. Endpoint logic
- _affinity_key = _conversation_fingerprint(model, None, prompt)
- endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
+ endpoint, tracking_model = await choose_endpoint(model)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
@@ -4096,30 +3870,44 @@ async def health_proxy(request: Request):
"""
Health‑check endpoint for monitoring the proxy.
- * Queries each configured endpoint for both liveness and routing health:
- Ollama endpoints are probed at `/api/version` AND `/api/ps`,
- OpenAI-compatible endpoints at `/models`.
+ * Queries each configured endpoint for its `/api/version` response.
* Returns a JSON object containing:
- - `status`: "ok" if every endpoint replied to every probe, otherwise "error".
+ - `status`: "ok" if every endpoint replied, otherwise "error".
- `endpoints`: a mapping of endpoint URL → `{status, version|detail}`.
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
"""
# Run all health checks in parallel.
- # Ollama endpoints expose /api/version (liveness) and /api/ps (routing
- # health — required by `choose_endpoint`). OpenAI-compatible endpoints
- # (vLLM, llama-server, external) expose /models, which serves both
- # purposes. Probing /api/version alone would miss the case where the
- # Ollama process is up but /api/ps is failing — see issue #83.
+ # Ollama endpoints expose /api/version; OpenAI-compatible endpoints (vLLM,
+ # llama-server, external) expose /models. Using /api/version against an
+ # OpenAI-compatible endpoint yields a 404 and noisy log output.
all_endpoints = list(config.endpoints)
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
all_endpoints += llama_eps_extra
- probe_results = await asyncio.gather(
- *(_endpoint_health(ep) for ep in all_endpoints),
- )
+ tasks = []
+ for ep in all_endpoints:
+ if is_openai_compatible(ep):
+ tasks.append(fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True))
+ else:
+ tasks.append(fetch.endpoint_details(ep, "/api/version", "version", skip_error_cache=True))
- health_summary = dict(zip(all_endpoints, probe_results))
- overall_ok = all(entry.get("status") == "ok" for entry in probe_results)
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+
+ health_summary = {}
+ overall_ok = True
+
+ for ep, result in zip(all_endpoints, results):
+ if isinstance(result, Exception):
+ # Endpoint did not respond / returned an error
+ health_summary[ep] = {"status": "error", "detail": str(result)}
+ overall_ok = False
+ else:
+ # Successful response – report the reported version (Ollama) or
+ # indicate the endpoint is reachable (OpenAI-compatible).
+ if is_openai_compatible(ep):
+ health_summary[ep] = {"status": "ok"}
+ else:
+ health_summary[ep] = {"status": "ok", "version": result}
response_payload = {
"status": "ok" if overall_ok else "error",
@@ -4240,16 +4028,6 @@ async def startup_event() -> None:
@app.on_event("shutdown")
async def shutdown_event() -> None:
await close_all_sse_queues()
-
- # Stop background tasks first so they stop touching the DB before we close it.
- for t in (token_worker_task, flush_task):
- if t is not None:
- t.cancel()
- try:
- await t
- except (asyncio.CancelledError, Exception):
- pass
-
await flush_remaining_buffers()
await app_state["session"].close()
@@ -4269,11 +4047,7 @@ async def shutdown_event() -> None:
except Exception as e:
print(f"[shutdown] Error closing httpx client {ep}: {e}")
- # Close the aiosqlite connection last — its worker thread is non-daemon
- # and would otherwise keep the interpreter alive after lifespan completes.
- if db is not None:
- try:
- await db.close()
- print("[shutdown] Closed token DB connection.")
- except Exception as e:
- print(f"[shutdown] Error closing DB: {e}")
+ if token_worker_task is not None:
+ token_worker_task.cancel()
+ if flush_task is not None:
+ flush_task.cancel()
diff --git a/static/index.html b/static/index.html
index 8c0b16c..419d7bb 100644
--- a/static/index.html
+++ b/static/index.html
@@ -121,45 +121,6 @@
.ps-subrow + .ps-subrow {
margin-top: 2px;
}
- #ps-table .affinity-col,
- #ps-table .affinity-cell {
- display: none;
- }
- #ps-table.affinity-on .affinity-col,
- #ps-table.affinity-on .affinity-cell {
- display: table-cell;
- width: 90px;
- text-align: center;
- padding-left: 6px;
- padding-right: 6px;
- }
- #ps-table.affinity-on .affinity-dots {
- max-width: 78px;
- }
- .affinity-dots {
- display: inline-flex;
- flex-wrap: wrap;
- gap: 3px;
- align-items: center;
- line-height: 1;
- }
- .affinity-dot {
- width: 8px;
- height: 8px;
- border-radius: 50%;
- background: #2e7d32;
- display: inline-block;
- transition: opacity 1s linear;
- }
- .affinity-overflow {
- font-size: 10px;
- color: #555;
- margin-left: 2px;
- }
- .affinity-empty {
- color: #bbb;
- font-size: 11px;
- }
#ps-table {
width: max-content;
min-width: 100%;
@@ -170,13 +131,13 @@
max-width: 300px;
white-space: nowrap;
}
- /* Optimize narrow columns (Params / Quant / Ctx) */
+ /* Optimize narrow columns */
+ #ps-table th:nth-child(3),
+ #ps-table td:nth-child(3),
#ps-table th:nth-child(4),
#ps-table td:nth-child(4),
#ps-table th:nth-child(5),
- #ps-table td:nth-child(5),
- #ps-table th:nth-child(6),
- #ps-table td:nth-child(6) {
+ #ps-table td:nth-child(5) {
width: 80px;
text-align: center;
}
@@ -192,10 +153,6 @@
color: #8b0000;
font-weight: bold;
}
- .status-error[title] {
- cursor: help;
- text-decoration: underline dotted;
- }
.copy-link,
.delete-link,
.show-link,
@@ -438,7 +395,6 @@
| Model |
Endpoint |
- Affinity |
Params |
Quant |
Ctx |
@@ -450,7 +406,7 @@
- | Loading… |
+ Loading… |
@@ -740,16 +696,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return await resp.json();
}
- function escapeHtml(value) {
- if (value === null || value === undefined) return "";
- return String(value)
- .replace(/&/g, "&")
- .replace(//g, ">")
- .replace(/"/g, """)
- .replace(/'/g, "'");
- }
-
function toggleDarkMode() {
document.documentElement.classList.toggle("dark-mode");
}
@@ -766,24 +712,40 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
// Build HTML for both endpoints and llama_server_endpoints
let html = "";
- const renderRow = (e) => {
- const statusClass =
- e.status === "ok" ? "status-ok" : "status-error";
- const version = e.version || "N/A";
- const titleAttr = e.detail
- ? ` title="${escapeHtml(e.detail)}"`
- : "";
- return `
+ // Add Ollama endpoints
+ html += data.endpoints
+ .map((e) => {
+ const statusClass =
+ e.status === "ok"
+ ? "status-ok"
+ : "status-error";
+ const version = e.version || "N/A";
+ return `
- | ${escapeHtml(e.url)} |
- ${escapeHtml(e.status)} |
- ${escapeHtml(version)} |
+ ${e.url} |
+ ${e.status} |
+ ${version} |
`;
- };
-
- html += data.endpoints.map(renderRow).join("");
+ })
+ .join("");
+
+ // Add llama-server endpoints
if (data.llama_server_endpoints && data.llama_server_endpoints.length > 0) {
- html += data.llama_server_endpoints.map(renderRow).join("");
+ html += data.llama_server_endpoints
+ .map((e) => {
+ const statusClass =
+ e.status === "ok"
+ ? "status-ok"
+ : "status-error";
+ const version = e.version || "N/A";
+ return `
+
+ | ${e.url} |
+ ${e.status} |
+ ${version} |
+
`;
+ })
+ .join("");
}
body.innerHTML = html;
@@ -970,14 +932,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return items.map((item) => `${item || ""}
`).join("");
};
- const escapeAttr = (s) => String(s).replace(/&/g, "&").replace(/"/g, """).replace(//g, ">");
- const renderAffinitySlots = (endpoints, modelName) => {
- if (!endpoints.length) return "";
- return endpoints
- .map((ep) => `
`)
- .join("");
- };
-
body.innerHTML = Array.from(grouped.entries())
.map(([modelName, modelInstances]) => {
const existingRow = psRows.get(modelName);
@@ -1001,7 +955,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return `
| ${modelName} stats |
${renderInstanceList(endpoints)} |
- ${renderAffinitySlots(endpoints, modelName)} |
${params} |
${quant} |
${ctx} |
@@ -1019,83 +972,11 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
const model = row.dataset.model;
if (model) psRows.set(model, row);
});
- renderAffinityDots();
} catch (e) {
console.error(e);
}
}
- /* ---------- Conversation-affinity dots ---------- */
- const AFFINITY_MAX_DOTS = 12;
- let affinityIndex = new Map(); // `${endpoint}|${model}` -> array of {expiresAt}
- let affinityTtl = 300;
- let affinityEnabled = false;
-
- async function loadAffinity() {
- try {
- const data = await fetchJSON("/api/affinity_stats");
- affinityEnabled = !!data.enabled;
- affinityTtl = Number(data.ttl) || 300;
- const now = Date.now() / 1000;
- const idx = new Map();
- for (const e of data.entries || []) {
- const key = `${e.endpoint}|${e.model}`;
- if (!idx.has(key)) idx.set(key, []);
- idx.get(key).push({ expiresAt: now + Number(e.remaining) });
- }
- affinityIndex = idx;
- applyAffinityColumnVisibility();
- renderAffinityDots();
- } catch (err) {
- // Endpoint may 404 on older deployments — silently degrade.
- affinityEnabled = false;
- affinityIndex = new Map();
- applyAffinityColumnVisibility();
- renderAffinityDots();
- }
- }
-
- function applyAffinityColumnVisibility() {
- const table = document.getElementById("ps-table");
- if (!table) return;
- table.classList.toggle("affinity-on", affinityEnabled);
- }
-
- function renderAffinityDots() {
- const spans = document.querySelectorAll(".affinity-dots");
- if (!spans.length) return;
- const now = Date.now() / 1000;
- spans.forEach((span) => {
- const ep = span.dataset.endpoint;
- const mdl = span.dataset.model;
- const key = `${ep}|${mdl}`;
- const pins = (affinityIndex.get(key) || []).filter((p) => p.expiresAt > now);
- if (pins.length !== (affinityIndex.get(key) || []).length) {
- if (pins.length) affinityIndex.set(key, pins);
- else affinityIndex.delete(key);
- }
- if (!pins.length) {
- span.innerHTML = affinityEnabled
- ? `—`
- : "";
- return;
- }
- // Sort freshest first so visible dots are the most "recent".
- pins.sort((a, b) => b.expiresAt - a.expiresAt);
- const visible = pins.slice(0, AFFINITY_MAX_DOTS);
- const overflow = pins.length - visible.length;
- const dotsHtml = visible
- .map((p) => {
- const remaining = Math.max(0, p.expiresAt - now);
- const opacity = Math.max(0.15, Math.min(1, remaining / affinityTtl));
- const secs = Math.round(remaining);
- return ``;
- })
- .join("");
- span.innerHTML = dotsHtml + (overflow > 0 ? `+${overflow}` : "");
- });
- }
-
/* ---------- Usage Chart (stacked‑percentage) ---------- */
function getColor(seed) {
const h = Math.abs(hashString(seed) % 360);
@@ -1292,13 +1173,10 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
loadEndpoints();
loadTags();
loadPS();
- loadAffinity();
loadUsage();
initHeaderChart();
setInterval(tickTpsChart, 1000);
setInterval(loadPS, 60_000);
- setInterval(loadAffinity, 15_000);
- setInterval(renderAffinityDots, 2_000);
setInterval(loadEndpoints, 300_000);
/* show logic */
diff --git a/test/config_test.yaml b/test/config_test.yaml
deleted file mode 100644
index 30f2fa3..0000000
--- a/test/config_test.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-endpoints:
- - http://192.168.0.51:12434
-
-llama_server_endpoints:
- - http://192.168.0.51:12434/v1
-
-max_concurrent_connections: 2
-
-api_keys:
- "http://192.168.0.51:12434": "ollama"
- "http://192.168.0.51:12434/v1": "llama"
-
-cache_enabled: false
diff --git a/test/conftest.py b/test/conftest.py
deleted file mode 100644
index c5142da..0000000
--- a/test/conftest.py
+++ /dev/null
@@ -1,233 +0,0 @@
-"""
-Test configuration for nomyo-router.
-
-Run from project root:
- pytest test/ -v
- pytest test/ -m "not integration" # skip real-server tests
- pytest test/ -m integration -v # only real-server tests
-
-Environment variables:
- NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.50:12434)
- NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.50:12434/v1)
- NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset)
- NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset)
-"""
-import asyncio
-import os
-import ssl
-import sys
-import tempfile
-from pathlib import Path
-from unittest.mock import MagicMock, patch
-
-import aiohttp
-import httpx
-import pytest
-
-_TEST_DIR = Path(__file__).parent
-# Must be set before importing router so module-level Config.from_yaml + Config field
-# defaults pick these up. db_path is intentionally absent from config_test.yaml so the
-# env-var default wins — keeps tests portable across CI runners (Linux/macOS/Windows).
-os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml"))
-os.environ.setdefault(
- "NOMYO_ROUTER_DB_PATH",
- str(Path(tempfile.gettempdir()) / "nomyo_router_test_tokens.db"),
-)
-
-sys.path.insert(0, str(_TEST_DIR.parent))
-
-import router # noqa: E402
-
-TEST_OLLAMA = os.getenv("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434")
-TEST_LLAMA = os.getenv("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1")
-
-
-def pytest_configure(config):
- config.addinivalue_line(
- "markers",
- "integration: tests that require a real backend at 192.168.0.50:12434",
- )
-
-
-# ── Config mocks ─────────────────────────────────────────────────────────────
-
-@pytest.fixture
-def mock_config():
- """Minimal config pointing at TEST_OLLAMA / TEST_LLAMA."""
- cfg = MagicMock()
- cfg.endpoints = [TEST_OLLAMA]
- cfg.llama_server_endpoints = [TEST_LLAMA]
- cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"}
- cfg.max_concurrent_connections = 2
- cfg.router_api_key = None
- cfg.cache_enabled = False
- return cfg
-
-
-@pytest.fixture
-def mock_config_no_llama():
- """Config with Ollama only, no llama-server."""
- cfg = MagicMock()
- cfg.endpoints = [TEST_OLLAMA]
- cfg.llama_server_endpoints = []
- cfg.api_keys = {TEST_OLLAMA: "ollama"}
- cfg.max_concurrent_connections = 2
- cfg.router_api_key = None
- cfg.cache_enabled = False
- return cfg
-
-
-@pytest.fixture
-def mock_config_with_key():
- """Config with router_api_key set (enables auth middleware)."""
- cfg = MagicMock()
- cfg.endpoints = [TEST_OLLAMA]
- cfg.llama_server_endpoints = []
- cfg.api_keys = {}
- cfg.max_concurrent_connections = 2
- cfg.router_api_key = "test-secret-key"
- cfg.cache_enabled = False
- return cfg
-
-
-# ── aiohttp session (used by fetch tests + choose_endpoint tests) ─────────────
-
-@pytest.fixture
-async def aio_session():
- """Real aiohttp session stored in app_state; intercepted by aioresponses."""
- ssl_ctx = ssl.create_default_context()
- conn = aiohttp.TCPConnector(ssl=ssl_ctx)
- session = aiohttp.ClientSession(connector=conn)
- router.app_state["session"] = session
-
- # Clear caches to prevent test bleed
- router._models_cache.clear()
- router._loaded_models_cache.clear()
- router._available_error_cache.clear()
- router._loaded_error_cache.clear()
- router._inflight_available_models.clear()
- router._inflight_loaded_models.clear()
- router._bg_refresh_available.clear()
- router._bg_refresh_loaded.clear()
-
- yield session
-
- await session.close()
- router.app_state["session"] = None
-
-
-# ── Validation-only HTTP client (no real backend needed) ──────────────────────
-
-@pytest.fixture
-async def client(mock_config, tmp_path):
- """httpx client for validation/auth tests — no real backend calls made."""
- from db import TokenDatabase
-
- ssl_ctx = ssl.create_default_context()
- conn = aiohttp.TCPConnector(ssl=ssl_ctx)
- session = aiohttp.ClientSession(connector=conn)
-
- db_inst = TokenDatabase(str(tmp_path / "test.db"))
- await db_inst.init_db()
-
- old_session = router.app_state.get("session")
- old_db = router.db
-
- router.app_state["session"] = session
- router.db = db_inst
-
- with patch.object(router, "config", mock_config):
- transport = httpx.ASGITransport(app=router.app)
- async with httpx.AsyncClient(
- transport=transport, base_url="http://test", timeout=10.0
- ) as c:
- yield c
-
- await session.close()
- router.app_state["session"] = old_session
- router.db = old_db
-
-
-@pytest.fixture
-async def client_auth(mock_config_with_key, tmp_path):
- """httpx client with router_api_key configured (for auth middleware tests)."""
- from db import TokenDatabase
-
- ssl_ctx = ssl.create_default_context()
- conn = aiohttp.TCPConnector(ssl=ssl_ctx)
- session = aiohttp.ClientSession(connector=conn)
-
- db_inst = TokenDatabase(str(tmp_path / "test_auth.db"))
- await db_inst.init_db()
-
- old_session = router.app_state.get("session")
- old_db = router.db
-
- router.app_state["session"] = session
- router.db = db_inst
-
- with patch.object(router, "config", mock_config_with_key):
- transport = httpx.ASGITransport(app=router.app)
- async with httpx.AsyncClient(
- transport=transport, base_url="http://test", timeout=10.0
- ) as c:
- yield c
-
- await session.close()
- router.app_state["session"] = old_session
- router.db = old_db
-
-
-# ── Integration client (full startup with real backend) ──────────────────────
-
-@pytest.fixture(scope="module")
-async def integration_client():
- """Full app startup pointing at the real test server."""
- await router.startup_event()
- transport = httpx.ASGITransport(app=router.app)
- async with httpx.AsyncClient(
- transport=transport,
- base_url="http://test",
- timeout=httpx.Timeout(60.0),
- ) as c:
- yield c
- await router.shutdown_event()
-
-
-# ── Model discovery fixtures ──────────────────────────────────────────────────
-
-@pytest.fixture(scope="module")
-async def chat_model(integration_client):
- """Return a chat/generation model name available on the test server."""
- env_model = os.getenv("NOMYO_TEST_MODEL_CHAT")
- if env_model:
- return env_model
- resp = await integration_client.get("/api/tags")
- if resp.status_code != 200:
- pytest.skip("Cannot reach test server")
- models = resp.json().get("models", [])
- # Prefer small models for faster tests
- for m in models:
- name = m.get("name", "")
- if any(x in name.lower() for x in ["0.5b", "1b", "3b", "1.5b", "2b"]):
- return name
- if models:
- return models[0]["name"]
- pytest.skip("No chat models available on test server")
-
-
-@pytest.fixture(scope="module")
-async def embed_model(integration_client):
- """Return an embedding model name available on the test server."""
- env_model = os.getenv("NOMYO_TEST_EMBED_MODEL")
- if env_model:
- return env_model
- resp = await integration_client.get("/api/tags")
- if resp.status_code != 200:
- pytest.skip("Cannot reach test server")
- models = resp.json().get("models", [])
- for m in models:
- name = m.get("name", "")
- if any(x in name.lower() for x in ["embed", "nomic", "minilm", "bge", "e5"]):
- return name
- pytest.skip("No embedding model available on test server")
diff --git a/test/pytest.ini b/test/pytest.ini
deleted file mode 100644
index 1d05e6d..0000000
--- a/test/pytest.ini
+++ /dev/null
@@ -1,7 +0,0 @@
-[pytest]
-asyncio_mode = auto
-markers =
- integration: tests that require a real backend at 192.168.0.51:12434
-testpaths = .
-filterwarnings =
- ignore::pytest.PytestUnhandledThreadExceptionWarning
diff --git a/test/requirements_test.txt b/test/requirements_test.txt
deleted file mode 100644
index 8c7c53f..0000000
--- a/test/requirements_test.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-pytest>=8.0
-pytest-asyncio>=0.24
-pytest-cov>=5.0
-aioresponses>=0.7
diff --git a/test/test.md b/test/test.md
deleted file mode 100644
index 533a3ee..0000000
--- a/test/test.md
+++ /dev/null
@@ -1,60 +0,0 @@
-# Testing nomyo-router
-
-## Setup
-
-Install test dependencies (from the project root):
-
-```bash
-pip install -r test/requirements_test.txt
-```
-
-## Running tests
-
-All commands run from the `test/` directory:
-
-```bash
-cd test
-```
-
-**All non-integration tests** (no backend required):
-```bash
-pytest -m "not integration" -v
-```
-
-**Integration tests only** (requires backend at `192.168.0.51:12434`):
-```bash
-pytest -m integration -v
-```
-
-**Everything:**
-```bash
-pytest -v
-```
-
-## Test structure
-
-| File | What it covers | Backend needed |
-|---|---|---|
-| `test_unit_helpers.py` | Pure helper functions (`_mask_secrets`, `_is_fresh`, `ep2base`, etc.) | No |
-| `test_unit_transforms.py` | Message transform functions (tool calls, image stripping, etc.) | No |
-| `test_unit_context.py` | Context window trimming logic | No |
-| `test_fetch.py` | `fetch.available_models` / `fetch.loaded_models` with mocked HTTP | No |
-| `test_choose_endpoint.py` | `choose_endpoint` routing logic with mocked fetch layer | No |
-| `test_api_validation.py` | HTTP 400/401/403 validation and auth middleware (in-process app) | No |
-| `test_api_integration.py` | Full request/response against a real Ollama/llama-server backend | **Yes** |
-
-## Integration test backend
-
-Integration tests start the router in-process via `startup_event()` and route traffic
-through `httpx.ASGITransport` — no separately running router instance is needed.
-
-They do require a reachable Ollama or llama-server backend. Override the defaults via
-environment variables:
-
-```bash
-export NOMYO_TEST_OLLAMA=http://192.168.0.51:12434
-export NOMYO_TEST_EMBED_MODEL=nomic-embed-text # optional, auto-discovered otherwise
-export NOMYO_TEST_MODEL_CHAT=llama3.2 # optional, auto-discovered otherwise
-```
-
-If the backend is unreachable, integration tests are automatically skipped.
diff --git a/test/test_api_integration.py b/test/test_api_integration.py
deleted file mode 100644
index 6c40fdc..0000000
--- a/test/test_api_integration.py
+++ /dev/null
@@ -1,304 +0,0 @@
-"""
-Integration tests against the real backend at 192.168.0.50:12434.
-
-Run with:
- pytest test/test_api_integration.py -v -m integration
-
-All tests in this file are marked @pytest.mark.integration.
-They require the test server to be reachable and to have at least one
-chat model and one embedding model available.
-
-Env vars to pin specific models:
- NOMYO_TEST_MODEL_CHAT e.g. qwen2.5:1.5b
- NOMYO_TEST_EMBED_MODEL e.g. nomic-embed-text:latest
-"""
-import json
-
-import pytest
-
-
-pytestmark = pytest.mark.integration
-
-
-# ── Health / discovery routes ─────────────────────────────────────────────────
-
-class TestDiscoveryRoutes:
- async def test_version(self, integration_client):
- resp = await integration_client.get("/api/version")
- assert resp.status_code == 200
- data = resp.json()
- assert "version" in data
- assert isinstance(data["version"], str)
-
- async def test_tags_returns_models(self, integration_client):
- resp = await integration_client.get("/api/tags")
- assert resp.status_code == 200
- data = resp.json()
- assert "models" in data
- assert isinstance(data["models"], list)
- assert len(data["models"]) > 0
-
- async def test_ps_returns_list(self, integration_client):
- resp = await integration_client.get("/api/ps")
- assert resp.status_code == 200
- data = resp.json()
- assert "models" in data
- assert isinstance(data["models"], list)
-
- async def test_v1_models_returns_data(self, integration_client):
- resp = await integration_client.get("/v1/models")
- assert resp.status_code == 200
- data = resp.json()
- assert "data" in data
- assert isinstance(data["data"], list)
-
- async def test_usage_returns_counts(self, integration_client):
- resp = await integration_client.get("/api/usage")
- assert resp.status_code == 200
- data = resp.json()
- assert "usage_counts" in data
- assert "token_usage_counts" in data
-
- async def test_config_returns_endpoints(self, integration_client):
- resp = await integration_client.get("/api/config")
- assert resp.status_code == 200
- data = resp.json()
- assert "endpoints" in data
-
- async def test_hostname(self, integration_client):
- resp = await integration_client.get("/api/hostname")
- assert resp.status_code == 200
- assert "hostname" in resp.json()
-
- async def test_health(self, integration_client):
- resp = await integration_client.get("/health")
- assert resp.status_code in (200, 503)
- data = resp.json()
- assert data["status"] in ("ok", "error")
- assert "endpoints" in data
-
- async def test_cache_stats(self, integration_client):
- resp = await integration_client.get("/api/cache/stats")
- assert resp.status_code == 200
- data = resp.json()
- assert "enabled" in data
-
-
-# ── /api/chat ─────────────────────────────────────────────────────────────────
-
-class TestApiChat:
- async def test_non_streaming(self, integration_client, chat_model):
- resp = await integration_client.post(
- "/api/chat",
- json={
- "model": chat_model,
- "stream": False,
- "messages": [{"role": "user", "content": "Reply with exactly: OK"}],
- "options": {"num_predict": 10},
- },
- )
- assert resp.status_code == 200
- data = resp.json()
- assert "message" in data
- assert "content" in data["message"]
-
- async def test_streaming_ndjson(self, integration_client, chat_model):
- resp = await integration_client.post(
- "/api/chat",
- json={
- "model": chat_model,
- "stream": True,
- "messages": [{"role": "user", "content": "Say hi"}],
- "options": {"num_predict": 5},
- },
- )
- assert resp.status_code == 200
- lines = [l for l in resp.text.strip().split("\n") if l.strip()]
- assert len(lines) >= 1
- for line in lines:
- obj = json.loads(line)
- assert "model" in obj
-
- async def test_non_streaming_has_token_counts(self, integration_client, chat_model):
- resp = await integration_client.post(
- "/api/chat",
- json={
- "model": chat_model,
- "stream": False,
- "messages": [{"role": "user", "content": "Count to 3"}],
- "options": {"num_predict": 20},
- },
- )
- assert resp.status_code == 200
- data = resp.json()
- assert data.get("done") is True
- # Token counts should be present in the final chunk
- assert data.get("prompt_eval_count", 0) >= 0
-
- async def test_system_message_honoured(self, integration_client, chat_model):
- resp = await integration_client.post(
- "/api/chat",
- json={
- "model": chat_model,
- "stream": False,
- "messages": [
- {"role": "system", "content": "You are a helpful assistant. Always reply with exactly: PONG"},
- {"role": "user", "content": "PING"},
- ],
- "options": {"num_predict": 10},
- },
- )
- assert resp.status_code == 200
- content = resp.json()["message"]["content"]
- assert isinstance(content, str)
- assert len(content) > 0
-
-
-# ── /api/generate ─────────────────────────────────────────────────────────────
-
-class TestApiGenerate:
- async def test_non_streaming(self, integration_client, chat_model):
- resp = await integration_client.post(
- "/api/generate",
- json={
- "model": chat_model,
- "prompt": "Complete: The sky is",
- "stream": False,
- "options": {"num_predict": 5},
- },
- )
- assert resp.status_code == 200
- data = resp.json()
- assert "response" in data
-
- async def test_streaming(self, integration_client, chat_model):
- resp = await integration_client.post(
- "/api/generate",
- json={
- "model": chat_model,
- "prompt": "One plus one equals",
- "stream": True,
- "options": {"num_predict": 5},
- },
- )
- assert resp.status_code == 200
- lines = [l for l in resp.text.strip().split("\n") if l.strip()]
- assert len(lines) >= 1
-
-
-# ── /api/embed ────────────────────────────────────────────────────────────────
-
-class TestApiEmbed:
- async def test_embed_single_string(self, integration_client, embed_model):
- resp = await integration_client.post(
- "/api/embed",
- json={"model": embed_model, "input": "The quick brown fox"},
- )
- assert resp.status_code == 200
- data = resp.json()
- assert "embeddings" in data
- assert isinstance(data["embeddings"], list)
- assert len(data["embeddings"]) == 1
- assert len(data["embeddings"][0]) > 0
-
- async def test_embed_multiple_inputs(self, integration_client, embed_model):
- resp = await integration_client.post(
- "/api/embed",
- json={"model": embed_model, "input": ["sentence one", "sentence two"]},
- )
- assert resp.status_code == 200
- data = resp.json()
- assert "embeddings" in data
- assert len(data["embeddings"]) == 2
-
-
-# ── /v1/chat/completions ──────────────────────────────────────────────────────
-
-class TestOpenAIChatCompletions:
- async def test_non_streaming(self, integration_client, chat_model):
- model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
- resp = await integration_client.post(
- "/v1/chat/completions",
- json={
- "model": model,
- "messages": [{"role": "user", "content": "Reply OK"}],
- "max_tokens": 10,
- "stream": False,
- },
- )
- assert resp.status_code == 200
- data = resp.json()
- assert "choices" in data
- assert len(data["choices"]) > 0
- assert "message" in data["choices"][0]
-
- async def test_streaming_sse(self, integration_client, chat_model):
- model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
- resp = await integration_client.post(
- "/v1/chat/completions",
- json={
- "model": model,
- "messages": [{"role": "user", "content": "Hi"}],
- "max_tokens": 5,
- "stream": True,
- },
- )
- assert resp.status_code == 200
- # Response should be SSE format
- assert "data:" in resp.text or "[DONE]" in resp.text
-
- async def test_non_streaming_has_usage(self, integration_client, chat_model):
- model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
- resp = await integration_client.post(
- "/v1/chat/completions",
- json={
- "model": model,
- "messages": [{"role": "user", "content": "Say yes"}],
- "max_tokens": 5,
- "stream": False,
- },
- )
- assert resp.status_code == 200
- data = resp.json()
- if "usage" in data and data["usage"]:
- assert data["usage"].get("prompt_tokens", 0) >= 0
-
-
-# ── /v1/embeddings ────────────────────────────────────────────────────────────
-
-class TestOpenAIEmbeddings:
- async def test_single_input(self, integration_client, embed_model):
- model = embed_model.replace(":latest", "") if ":latest" in embed_model else embed_model
- resp = await integration_client.post(
- "/v1/embeddings",
- json={"model": model, "input": "Test sentence"},
- )
- assert resp.status_code == 200
- data = resp.json()
- assert "data" in data
- assert len(data["data"]) > 0
- embedding = data["data"][0].get("embedding")
- assert isinstance(embedding, list)
- assert len(embedding) > 0
-
-
-# ── Token counts (database-backed) ───────────────────────────────────────────
-
-class TestTokenCounts:
- async def test_token_counts_endpoint(self, integration_client):
- resp = await integration_client.get("/api/token_counts")
- assert resp.status_code == 200
- data = resp.json()
- assert "total_tokens" in data
- assert "breakdown" in data
-
-
-# ── ps_details (extended ps) ─────────────────────────────────────────────────
-
-class TestPsDetails:
- async def test_ps_details_returns_models(self, integration_client):
- resp = await integration_client.get("/api/ps_details")
- assert resp.status_code == 200
- data = resp.json()
- assert "models" in data
- assert isinstance(data["models"], list)
diff --git a/test/test_api_validation.py b/test/test_api_validation.py
deleted file mode 100644
index 5d2b52d..0000000
--- a/test/test_api_validation.py
+++ /dev/null
@@ -1,230 +0,0 @@
-"""
-HTTP-level validation and auth middleware tests.
-
-These tests use an in-process httpx client and never reach a real backend:
-all requests are rejected at the validation or auth layer before any
-endpoint-selection or upstream HTTP calls occur.
-"""
-import pytest
-
-
-class TestChatValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post(
- "/api/chat",
- json={"messages": [{"role": "user", "content": "hello"}]},
- )
- assert resp.status_code == 400
- assert "model" in resp.json()["detail"].lower()
-
- async def test_missing_messages_returns_400(self, client):
- resp = await client.post("/api/chat", json={"model": "llama3.2"})
- assert resp.status_code == 400
-
- async def test_invalid_json_returns_400(self, client):
- resp = await client.post(
- "/api/chat",
- content=b"not-json",
- headers={"Content-Type": "application/json"},
- )
- assert resp.status_code == 400
-
- async def test_messages_not_list_returns_400(self, client):
- resp = await client.post(
- "/api/chat",
- json={"model": "m", "messages": "not-a-list"},
- )
- assert resp.status_code == 400
-
- async def test_options_not_dict_returns_400(self, client):
- resp = await client.post(
- "/api/chat",
- json={"model": "m", "messages": [{"role": "user", "content": "hi"}], "options": "bad"},
- )
- assert resp.status_code == 400
-
-
-class TestGenerateValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post("/api/generate", json={"prompt": "hello"})
- assert resp.status_code == 400
- assert "model" in resp.json()["detail"].lower()
-
- async def test_missing_prompt_returns_400(self, client):
- resp = await client.post("/api/generate", json={"model": "m"})
- assert resp.status_code == 400
- assert "prompt" in resp.json()["detail"].lower()
-
- async def test_invalid_json_returns_400(self, client):
- resp = await client.post(
- "/api/generate",
- content=b"{bad-json",
- headers={"Content-Type": "application/json"},
- )
- assert resp.status_code == 400
-
-
-class TestEmbedValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post("/api/embed", json={"input": "hello"})
- assert resp.status_code == 400
-
- async def test_missing_input_returns_400(self, client):
- resp = await client.post("/api/embed", json={"model": "nomic-embed-text"})
- assert resp.status_code == 400
-
-
-class TestEmbeddingsValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post("/api/embeddings", json={"prompt": "hello"})
- assert resp.status_code == 400
-
- async def test_missing_prompt_returns_400(self, client):
- resp = await client.post("/api/embeddings", json={"model": "nomic-embed-text"})
- assert resp.status_code == 400
-
-
-class TestOpenAIChatValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post(
- "/v1/chat/completions",
- json={"messages": [{"role": "user", "content": "hello"}]},
- )
- assert resp.status_code == 400
-
- async def test_missing_messages_returns_400(self, client):
- resp = await client.post(
- "/v1/chat/completions",
- json={"model": "gpt-4o"},
- )
- assert resp.status_code == 400
-
- async def test_invalid_json_returns_400(self, client):
- resp = await client.post(
- "/v1/chat/completions",
- content=b"}{",
- headers={"Content-Type": "application/json"},
- )
- assert resp.status_code == 400
-
- async def test_svg_image_rejected(self, client):
- resp = await client.post(
- "/v1/chat/completions",
- json={
- "model": "vision-model",
- "messages": [{
- "role": "user",
- "content": [
- {"type": "text", "text": "describe"},
- {"type": "image_url", "image_url": {"url": "data:image/svg+xml;base64,abc"}},
- ],
- }],
- },
- )
- assert resp.status_code == 400
- assert "svg" in resp.json()["detail"].lower()
-
-
-class TestOpenAICompletionsValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post("/v1/completions", json={"prompt": "hello"})
- assert resp.status_code == 400
-
- async def test_missing_prompt_returns_400(self, client):
- resp = await client.post("/v1/completions", json={"model": "m"})
- assert resp.status_code == 400
-
-
-class TestRerankValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post(
- "/v1/rerank",
- json={"query": "search query", "documents": ["doc1"]},
- )
- assert resp.status_code == 400
-
- async def test_missing_query_returns_400(self, client):
- resp = await client.post(
- "/v1/rerank",
- json={"model": "reranker", "documents": ["doc1"]},
- )
- assert resp.status_code == 400
-
- async def test_empty_documents_returns_400(self, client):
- resp = await client.post(
- "/v1/rerank",
- json={"model": "reranker", "query": "search", "documents": []},
- )
- assert resp.status_code == 400
-
-
-class TestShowValidation:
- async def test_missing_model_returns_400(self, client):
- resp = await client.post("/api/show", json={})
- assert resp.status_code == 400
-
-
-class TestCopyValidation:
- async def test_missing_source_returns_400(self, client):
- resp = await client.post("/api/copy", json={"destination": "dst"})
- assert resp.status_code == 400
-
- async def test_missing_destination_returns_400(self, client):
- resp = await client.post("/api/copy", json={"source": "src"})
- assert resp.status_code == 400
-
-
-class TestDeleteValidation:
- async def test_missing_model_returns_400(self, client):
- import json as _json
- resp = await client.request(
- "DELETE",
- "/api/delete",
- content=_json.dumps({}).encode(),
- headers={"Content-Type": "application/json"},
- )
- assert resp.status_code == 400
-
-
-class TestAuthMiddleware:
- async def test_no_key_returns_401(self, client_auth):
- resp = await client_auth.post(
- "/api/chat",
- json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
- )
- assert resp.status_code == 401
- assert "Missing" in resp.json()["detail"]
-
- async def test_invalid_key_returns_403(self, client_auth):
- resp = await client_auth.post(
- "/api/chat",
- headers={"Authorization": "Bearer wrong-key"},
- json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
- )
- assert resp.status_code == 403
- assert "Invalid" in resp.json()["detail"]
-
- async def test_valid_key_passes_middleware(self, client_auth):
- # /api/usage reads in-memory counters only — no backend call needed
- resp = await client_auth.get(
- "/api/usage",
- headers={"Authorization": "Bearer test-secret-key"},
- )
- assert resp.status_code == 200
-
- async def test_key_via_query_param(self, client_auth):
- resp = await client_auth.get("/api/usage?api_key=test-secret-key")
- assert resp.status_code == 200
-
- async def test_options_bypasses_auth(self, client_auth):
- resp = await client_auth.options("/api/chat")
- assert resp.status_code not in (401, 403)
-
- async def test_root_path_bypasses_auth(self, client_auth):
- resp = await client_auth.get("/")
- assert resp.status_code not in (401, 403)
-
- async def test_favicon_bypasses_auth(self, client_auth):
- resp = await client_auth.get("/favicon.ico")
- # Should not be blocked by auth (may 404 in test but not 401/403)
- assert resp.status_code not in (401, 403)
diff --git a/test/test_cache.py b/test/test_cache.py
deleted file mode 100644
index f2ce1a9..0000000
--- a/test/test_cache.py
+++ /dev/null
@@ -1,333 +0,0 @@
-"""Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed)."""
-import tempfile
-from pathlib import Path
-from types import SimpleNamespace
-
-import orjson
-import pytest
-
-import cache as cache_mod
-from cache import (
- LLMCache,
- _bm25_weighted_text,
- get_llm_cache,
- init_llm_cache,
- openai_nonstream_to_sse,
-)
-
-_CACHE_DB_PATH = str(Path(tempfile.gettempdir()) / "nomyo_test_cache.db")
-
-
-def _exact_cfg(backend: str = "memory") -> SimpleNamespace:
- """Config for exact-match mode — similarity=1.0 avoids embedding deps."""
- return SimpleNamespace(
- cache_enabled=True,
- cache_backend=backend,
- cache_similarity=1.0,
- cache_history_weight=0.3,
- cache_ttl=300,
- cache_db_path=_CACHE_DB_PATH,
- cache_redis_url="redis://localhost:6379",
- )
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# Pure helpers
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestBM25WeightedText:
- def test_empty_history(self):
- assert _bm25_weighted_text([]) == ""
-
- def test_history_without_content(self):
- assert _bm25_weighted_text([{"role": "user"}, {"role": "assistant"}]) == ""
-
- def test_repeats_high_idf_terms(self):
- history = [
- {"role": "user", "content": "Tell me about quantum entanglement"},
- {"role": "assistant", "content": "Quantum entanglement is a phenomenon"},
- {"role": "user", "content": "How does entanglement work?"},
- ]
- out = _bm25_weighted_text(history)
- # Rare/domain term ("entanglement") should appear; short stopwords (<=2 chars) dropped
- assert "entanglement" in out
- assert "is" not in out.split()
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# openai_nonstream_to_sse
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestOpenAINonstreamToSSE:
- def test_valid_chat_completion(self):
- chat = {
- "id": "x1",
- "created": 123,
- "model": "gpt-4o",
- "choices": [{"message": {"role": "assistant", "content": "hello"}}],
- "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
- }
- out = openai_nonstream_to_sse(orjson.dumps(chat), "gpt-4o")
- text = out.decode()
- assert text.startswith("data: ")
- assert text.endswith("data: [DONE]\n\n")
- # First chunk contains the original content
- first = text.split("\n\n")[0][len("data: "):]
- parsed = orjson.loads(first)
- assert parsed["choices"][0]["delta"]["content"] == "hello"
- assert parsed["usage"]["total_tokens"] == 3
-
- def test_corrupt_bytes_return_done_only(self):
- out = openai_nonstream_to_sse(b"not-json", "m")
- assert out == b"data: [DONE]\n\n"
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# LLMCache internal helpers
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestLLMCacheParsing:
- def test_namespace_is_stable_and_isolated(self):
- c = LLMCache(_exact_cfg())
- a = c._namespace("chat", "m1", "system A")
- b = c._namespace("chat", "m1", "system A")
- assert a == b
- assert c._namespace("chat", "m1", "system B") != a
- assert c._namespace("generate", "m1", "system A") != a
- assert len(a) == 16
-
- def test_parse_messages_flat_strings(self):
- c = LLMCache(_exact_cfg())
- sys, hist, last = c._parse_messages([
- {"role": "system", "content": "be helpful"},
- {"role": "user", "content": "hi"},
- {"role": "assistant", "content": "hello"},
- {"role": "user", "content": "what is 2+2?"},
- ])
- assert sys == "be helpful"
- assert last == "what is 2+2?"
- assert hist == [
- {"role": "user", "content": "hi"},
- {"role": "assistant", "content": "hello"},
- ]
-
- def test_parse_messages_multimodal_content(self):
- c = LLMCache(_exact_cfg())
- sys, _hist, last = c._parse_messages([
- {"role": "system", "content": "sys"},
- {"role": "user", "content": [
- {"type": "text", "text": "describe"},
- {"type": "image_url", "image_url": {"url": "data:..."}},
- ]},
- ])
- assert sys == "sys"
- assert last == "describe"
-
- def test_parse_messages_no_user_message(self):
- c = LLMCache(_exact_cfg())
- sys, hist, last = c._parse_messages([
- {"role": "system", "content": "sys only"},
- ])
- assert sys == "sys only"
- assert last == ""
- assert hist == []
-
-
-class TestPersonalTokenExtraction:
- def test_email_extracted(self):
- c = LLMCache(_exact_cfg())
- toks = c._extract_personal_tokens("Reach me at alice@example.com please")
- assert "alice@example.com" in toks
-
- def test_numeric_id_after_keyword(self):
- c = LLMCache(_exact_cfg())
- toks = c._extract_personal_tokens("User id: 123456")
- assert "123456" in toks
-
- def test_identity_tag_names_extracted(self):
- c = LLMCache(_exact_cfg())
- toks = c._extract_personal_tokens(
- "[Tags: identity] User's name is Andreas Schwibbe"
- )
- # Both name tokens should be extracted lowercased; stopwords dropped
- assert "andreas" in toks
- assert "schwibbe" in toks
- assert "name" not in toks # in _IDENTITY_STOPWORDS
- assert "user" not in toks
-
- def test_empty_system_returns_empty_set(self):
- c = LLMCache(_exact_cfg())
- assert c._extract_personal_tokens("") == frozenset()
-
-
-class TestResponseIsPersonalized:
- def _resp(self, content: str) -> bytes:
- return orjson.dumps({"choices": [{"message": {"content": content}}]})
-
- def test_email_in_response_is_personalized(self):
- c = LLMCache(_exact_cfg())
- assert c._response_is_personalized(self._resp("contact bob@x.com"), "")
-
- def test_uuid_in_response_is_personalized(self):
- c = LLMCache(_exact_cfg())
- uuid = "550e8400-e29b-41d4-a716-446655440000"
- assert c._response_is_personalized(self._resp(f"id={uuid}"), "")
-
- def test_long_numeric_id_in_response_is_personalized(self):
- c = LLMCache(_exact_cfg())
- assert c._response_is_personalized(self._resp("account 12345678"), "")
-
- def test_identity_token_from_system_echoed_in_response(self):
- c = LLMCache(_exact_cfg())
- system = "[Tags: identity] Andreas works here"
- assert c._response_is_personalized(
- self._resp("Yes, Andreas is logged in"), system
- )
-
- def test_generic_response_not_personalized(self):
- c = LLMCache(_exact_cfg())
- assert not c._response_is_personalized(
- self._resp("The capital of France is Paris."), "be helpful"
- )
-
- def test_ollama_message_format_parsed(self):
- c = LLMCache(_exact_cfg())
- body = orjson.dumps({"message": {"content": "alice@example.com"}})
- assert c._response_is_personalized(body, "")
-
- def test_unparseable_body_with_bytes_is_conservative(self):
- c = LLMCache(_exact_cfg())
- # Can't parse → returns True (err on the side of privacy)
- assert c._response_is_personalized(b"binary-junk", "")
-
- def test_empty_response_not_personalized(self):
- c = LLMCache(_exact_cfg())
- assert not c._response_is_personalized(b"", "anything")
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# End-to-end exact-match cache with the memory backend
-# ──────────────────────────────────────────────────────────────────────────────
-
-@pytest.fixture
-async def memcache():
- """LLMCache wired up with the in-memory backend (no external deps)."""
- c = LLMCache(_exact_cfg("memory"))
- await c.init()
- return c
-
-
-class TestExactMatchCache:
- async def test_miss_then_set_then_hit(self, memcache):
- msgs = [
- {"role": "system", "content": "be helpful"},
- {"role": "user", "content": "what is 2+2?"},
- ]
- resp = orjson.dumps({"choices": [{"message": {"content": "4"}}]})
-
- assert await memcache.get_chat("chat", "m1", msgs) is None
- await memcache.set_chat("chat", "m1", msgs, resp)
- hit = await memcache.get_chat("chat", "m1", msgs)
- assert hit == resp
-
- async def test_namespace_isolation_by_system(self, memcache):
- resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
- msgs_a = [
- {"role": "system", "content": "system A"},
- {"role": "user", "content": "same question"},
- ]
- msgs_b = [
- {"role": "system", "content": "system B"},
- {"role": "user", "content": "same question"},
- ]
- await memcache.set_chat("chat", "m", msgs_a, resp)
- # Same question + different system prompt = different namespace = miss
- assert await memcache.get_chat("chat", "m", msgs_b) is None
-
- async def test_namespace_isolation_by_route(self, memcache):
- resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
- msgs = [{"role": "user", "content": "ping"}]
- await memcache.set_chat("chat", "m", msgs, resp)
- assert await memcache.get_chat("openai_chat", "m", msgs) is None
-
- async def test_no_user_message_is_noop(self, memcache):
- msgs = [{"role": "system", "content": "sys only"}]
- resp = orjson.dumps({"choices": [{"message": {"content": "x"}}]})
- # Both get and set should silently no-op
- assert await memcache.get_chat("chat", "m", msgs) is None
- await memcache.set_chat("chat", "m", msgs, resp)
- assert await memcache.get_chat("chat", "m", msgs) is None
-
- async def test_personalized_response_generic_system_not_stored(self, memcache):
- msgs = [
- {"role": "system", "content": "be helpful"}, # generic
- {"role": "user", "content": "give me an email"},
- ]
- # Response contains an email → would leak across users sharing the
- # generic namespace → must NOT be stored at all
- resp = orjson.dumps({"choices": [{"message": {"content": "bob@x.com"}}]})
- await memcache.set_chat("chat", "m", msgs, resp)
- assert await memcache.get_chat("chat", "m", msgs) is None
-
- async def test_personalized_response_user_specific_system_stored(self, memcache):
- msgs = [
- {"role": "system", "content": "User id: 998877 prefers concise answers"},
- {"role": "user", "content": "what is my id?"},
- ]
- resp = orjson.dumps({"choices": [{"message": {"content": "Your id is 998877"}}]})
- await memcache.set_chat("chat", "m", msgs, resp)
- # User-specific namespace → exact-match within this user is OK
- assert await memcache.get_chat("chat", "m", msgs) == resp
-
- async def test_generate_convenience_wrappers(self, memcache):
- resp = orjson.dumps({"response": "blue"})
- await memcache.set_generate("m", "what color is the sky?", "", resp)
- assert await memcache.get_generate("m", "what color is the sky?") == resp
-
-
-class TestStatsAndClear:
- async def test_stats_tracks_hits_and_misses(self, memcache):
- msgs = [{"role": "user", "content": "hello"}]
- await memcache.get_chat("chat", "m", msgs) # miss
- resp = orjson.dumps({"choices": [{"message": {"content": "hi"}}]})
- await memcache.set_chat("chat", "m", msgs, resp)
- await memcache.get_chat("chat", "m", msgs) # hit
- s = memcache.stats()
- assert s["hits"] == 1
- assert s["misses"] == 1
- assert s["hit_rate"] == 0.5
- assert s["semantic"] is False
- assert s["backend"] == "memory"
-
- async def test_clear_resets_counters_and_storage(self, memcache):
- msgs = [{"role": "user", "content": "hi"}]
- resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
- await memcache.set_chat("chat", "m", msgs, resp)
- await memcache.get_chat("chat", "m", msgs)
- await memcache.clear()
- s = memcache.stats()
- assert s["hits"] == 0
- assert s["misses"] == 0
- assert await memcache.get_chat("chat", "m", msgs) is None
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# Module-level helpers
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestInitLLMCache:
- async def test_disabled_returns_none(self):
- cfg = _exact_cfg()
- cfg.cache_enabled = False
- result = await init_llm_cache(cfg)
- assert result is None
-
- async def test_enabled_returns_initialized_cache(self):
- cfg = _exact_cfg()
- try:
- result = await init_llm_cache(cfg)
- assert result is not None
- assert get_llm_cache() is result
- finally:
- # Reset singleton between tests
- cache_mod._cache = None
diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py
deleted file mode 100644
index ece609a..0000000
--- a/test/test_choose_endpoint.py
+++ /dev/null
@@ -1,399 +0,0 @@
-"""Tests for choose_endpoint routing logic with mocked fetch calls."""
-import time
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-
-import router
-
-EP1 = "http://ep1:11434"
-EP2 = "http://ep2:11434"
-EP3 = "http://ep3:11434"
-LLAMA_EP = "http://llama:8080/v1"
-
-
-def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
- cfg = MagicMock()
- cfg.endpoints = endpoints
- cfg.llama_server_endpoints = llama_eps or []
- cfg.api_keys = {}
- cfg.max_concurrent_connections = max_conn
- cfg.endpoint_config = endpoint_config or {}
- cfg.priority_routing = priority_routing
- cfg.router_api_key = None
- return cfg
-
-
-@pytest.fixture(autouse=True)
-def reset_usage():
- """Clear usage_counts and error caches between tests to prevent bleed."""
- router.usage_counts.clear()
- router._loaded_error_cache.clear()
- yield
- router.usage_counts.clear()
- router._loaded_error_cache.clear()
-
-
-class TestChooseEndpointBasic:
- async def test_selects_single_candidate(self):
- cfg = _make_cfg([EP1])
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(return_value={"llama3.2:latest"})),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
- ):
- ep, tracking = await router.choose_endpoint("llama3.2:latest")
- assert ep == EP1
- assert tracking == "llama3.2:latest"
-
- async def test_raises_when_no_endpoint_has_model(self):
- cfg = _make_cfg([EP1, EP2])
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(return_value=set())),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
- ):
- with pytest.raises(RuntimeError, match="advertise the model"):
- await router.choose_endpoint("unknown-model:latest")
-
- async def test_prefers_loaded_endpoint(self):
- cfg = _make_cfg([EP1, EP2])
- async def available(ep, *_):
- return {"llama3.2:latest"}
-
- async def loaded(ep):
- return {"llama3.2:latest"} if ep == EP2 else set()
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", side_effect=loaded),
- ):
- ep, _ = await router.choose_endpoint("llama3.2:latest")
- assert ep == EP2
-
- async def test_falls_back_to_free_slot(self):
- cfg = _make_cfg([EP1, EP2])
- async def available(ep, *_):
- return {"llama3.2:latest"}
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
- ):
- ep, _ = await router.choose_endpoint("llama3.2:latest")
- assert ep in (EP1, EP2)
-
- async def test_saturated_picks_least_busy(self):
- cfg = _make_cfg([EP1, EP2])
- cfg.max_concurrent_connections = 1
-
- async def available(ep, *_):
- return {"llama3.2:latest"}
-
- # Saturate EP1 with 2 active connections, EP2 with 1
- router.usage_counts[EP1]["llama3.2:latest"] = 2
- router.usage_counts[EP2]["llama3.2:latest"] = 1
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
- ):
- ep, _ = await router.choose_endpoint("llama3.2:latest")
- # Least-busy is EP2
- assert ep == EP2
-
- async def test_excludes_endpoint_with_recent_loaded_error(self):
- # Regression: issue #83 — when /api/ps fails for EP1 but EP1
- # still advertises the model via /api/tags, routing must not
- # fall back to EP1 just because it has a free slot.
- cfg = _make_cfg([EP1, EP2])
-
- async def available(ep, *_):
- return {"llama3.2:latest"}
-
- # EP1's /api/ps probe failed recently; EP2 is fine but the model
- # is not loaded there. Without the health filter, EP1 would be
- # picked by the free-slot fallback (step 4 in choose_endpoint).
- router._loaded_error_cache[EP1] = time.time()
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
- ):
- ep, _ = await router.choose_endpoint("llama3.2:latest")
- assert ep == EP2
-
- async def test_stale_loaded_error_does_not_exclude(self):
- # Errors older than the 300s window must not keep an endpoint
- # excluded forever.
- cfg = _make_cfg([EP1])
- router._loaded_error_cache[EP1] = time.time() - 301
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"m:latest"})),
- ):
- ep, _ = await router.choose_endpoint("m:latest")
- assert ep == EP1
-
- async def test_all_unhealthy_still_routes(self):
- # If every candidate has a fresh loaded-error we still try one
- # (it may have recovered between the cache write and now) rather
- # than refusing to route.
- cfg = _make_cfg([EP1])
- router._loaded_error_cache[EP1] = time.time()
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
- ):
- ep, _ = await router.choose_endpoint("m:latest")
- assert ep == EP1
-
- async def test_reserve_increments_usage(self):
- cfg = _make_cfg([EP1])
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(return_value={"model:latest"})),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"model:latest"})),
- ):
- ep, tracking = await router.choose_endpoint("model:latest", reserve=True)
- assert router.usage_counts[ep][tracking] == 1
-
-
-class TestChooseEndpointModelNaming:
- async def test_strips_latest_for_openai_endpoints(self):
- cfg = _make_cfg(endpoints=[], llama_eps=[LLAMA_EP])
- cfg.endpoints = []
-
- async def available(ep, *_):
- # llama-server advertises without :latest
- return {"gpt-4o"}
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"gpt-4o"})),
- ):
- ep, _ = await router.choose_endpoint("gpt-4o:latest")
- assert ep == LLAMA_EP
-
- async def test_adds_latest_for_ollama_when_bare_name(self):
- cfg = _make_cfg([EP1])
-
- async def available(ep, *_):
- return {"llama3.2:latest"}
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
- ):
- ep, _ = await router.choose_endpoint("llama3.2")
- assert ep == EP1
-
-
-class TestChooseEndpointLoadBalancing:
- async def test_random_selection_among_idle(self):
- cfg = _make_cfg([EP1, EP2, EP3])
- selected = set()
-
- async def available(ep, *_):
- return {"model:latest"}
-
- async def loaded(ep):
- return {"model:latest"}
-
- for _ in range(20):
- router.usage_counts.clear()
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", side_effect=loaded),
- ):
- ep, _ = await router.choose_endpoint("model:latest", reserve=False)
- selected.add(ep)
-
- # With 20 draws from 3 idle endpoints, all three should appear
- assert len(selected) > 1
-
- async def test_sort_by_load_ascending(self):
- cfg = _make_cfg([EP1, EP2])
- router.usage_counts[EP1]["model:latest"] = 1
- router.usage_counts[EP2]["model:latest"] = 0
-
- async def available(ep, *_):
- return {"model:latest"}
-
- async def loaded(ep):
- return {"model:latest"}
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", side_effect=available),
- patch.object(router.fetch, "loaded_models", side_effect=loaded),
- ):
- ep, _ = await router.choose_endpoint("model:latest", reserve=False)
- # EP2 has fewer active connections → should be selected
- assert ep == EP2
-
-
-# ---------------------------------------------------------------------------
-# get_max_connections unit tests
-# ---------------------------------------------------------------------------
-
-class TestGetMaxConnections:
- def test_returns_global_default_when_no_override(self):
- cfg = _make_cfg([EP1, EP2], max_conn=3)
- with patch.object(router, "config", cfg):
- assert router.get_max_connections(EP1) == 3
- assert router.get_max_connections(EP2) == 3
-
- def test_returns_per_endpoint_override(self):
- cfg = _make_cfg(
- [EP1, EP2],
- max_conn=2,
- endpoint_config={EP1: {"max_concurrent_connections": 5}},
- )
- with patch.object(router, "config", cfg):
- assert router.get_max_connections(EP1) == 5
- assert router.get_max_connections(EP2) == 2 # falls back to global
-
- def test_unrecognised_endpoint_falls_back_to_global(self):
- cfg = _make_cfg([EP1], max_conn=4, endpoint_config={EP2: {"max_concurrent_connections": 1}})
- with patch.object(router, "config", cfg):
- assert router.get_max_connections(EP3) == 4
-
-
-# ---------------------------------------------------------------------------
-# Priority / WRR routing tests
-# ---------------------------------------------------------------------------
-
-MODEL = "model:latest"
-
-
-def _all_loaded(ep):
- """Side-effect: every endpoint advertises and has MODEL loaded."""
- return {MODEL}
-
-
-class TestPriorityRouting:
- """Tests for priority_routing=True (WRR + config-order tiebreaking)."""
-
- async def test_idle_picks_first_in_config_order(self):
- """When all endpoints are idle, priority picks the first listed endpoint."""
- cfg = _make_cfg([EP1, EP2, EP3], priority_routing=True)
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
- patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
- ):
- ep, _ = await router.choose_endpoint(MODEL, reserve=False)
- assert ep == EP1
-
- async def test_lower_utilization_preferred_over_priority(self):
- """An endpoint with lower ratio is preferred even if it has lower priority."""
- cfg = _make_cfg([EP1, EP2], priority_routing=True)
- # EP1 (priority 0) is busier: 1/2 = 0.5; EP2 (priority 1) is idle: 0/2 = 0.0
- router.usage_counts[EP1][MODEL] = 1
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
- patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
- ):
- ep, _ = await router.choose_endpoint(MODEL, reserve=False)
- assert ep == EP2
-
- async def test_wrr_distribution_matches_expected_sequence(self):
- """
- Full WRR sequence with heterogeneous capacities, mirroring the issue example:
- EP1 max=2, EP2 max=2, EP3 max=1
-
- Expected routing order for 5 sequential requests:
- EP1, EP2, EP3, EP1, EP2
- """
- cfg = _make_cfg(
- [EP1, EP2, EP3],
- max_conn=2,
- endpoint_config={EP3: {"max_concurrent_connections": 1}},
- priority_routing=True,
- )
-
- expected = [EP1, EP2, EP3, EP1, EP2]
- actual = []
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
- patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
- ):
- for _ in expected:
- ep, _ = await router.choose_endpoint(MODEL, reserve=True)
- actual.append(ep)
-
- assert actual == expected
-
- async def test_saturated_picks_lowest_ratio_then_priority(self):
- """When all endpoints are saturated, pick lowest utilization ratio; break ties by priority."""
- cfg = _make_cfg(
- [EP1, EP2, EP3],
- max_conn=1,
- endpoint_config={EP3: {"max_concurrent_connections": 2}},
- priority_routing=True,
- )
- # EP1 usage=1/1=1.0, EP2 usage=1/1=1.0, EP3 usage=1/2=0.5 → EP3 wins
- router.usage_counts[EP1][MODEL] = 1
- router.usage_counts[EP2][MODEL] = 1
- router.usage_counts[EP3][MODEL] = 1
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
- patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
- ):
- ep, _ = await router.choose_endpoint(MODEL, reserve=False)
- assert ep == EP3
-
- async def test_saturated_ties_broken_by_priority(self):
- """When all are saturated with equal ratio, config order wins."""
- cfg = _make_cfg([EP1, EP2, EP3], max_conn=1, priority_routing=True)
- router.usage_counts[EP1][MODEL] = 1
- router.usage_counts[EP2][MODEL] = 1
- router.usage_counts[EP3][MODEL] = 1
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
- patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
- ):
- ep, _ = await router.choose_endpoint(MODEL, reserve=False)
- assert ep == EP1
-
-
-class TestPriorityRoutingDisabled:
- """Verify that priority_routing=False keeps the original random behaviour."""
-
- async def test_idle_endpoints_are_randomised(self):
- """Without priority routing, all-idle selection must eventually pick each endpoint."""
- cfg = _make_cfg([EP1, EP2, EP3], priority_routing=False)
- selected = set()
-
- with (
- patch.object(router, "config", cfg),
- patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
- patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
- ):
- for _ in range(30):
- router.usage_counts.clear()
- ep, _ = await router.choose_endpoint(MODEL, reserve=False)
- selected.add(ep)
-
- # With 30 draws from 3 equally-idle endpoints, all three must appear
- assert selected == {EP1, EP2, EP3}
diff --git a/test/test_db.py b/test/test_db.py
deleted file mode 100644
index 833b375..0000000
--- a/test/test_db.py
+++ /dev/null
@@ -1,197 +0,0 @@
-"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
-from datetime import datetime, timezone
-
-import pytest
-
-from db import TokenDatabase
-
-
-@pytest.fixture
-async def db(tmp_path):
- inst = TokenDatabase(str(tmp_path / "tokens.db"))
- await inst.init_db()
- yield inst
- await inst.close()
-
-
-class TestInit:
- async def test_init_creates_tables(self, db):
- # Re-init must be idempotent
- await db.init_db()
- # Insert + read confirms tables exist
- await db.update_token_counts("http://ep", "m", 1, 2)
- rows = [r async for r in db.load_token_counts()]
- assert len(rows) == 1
-
- async def test_creates_parent_directory(self, tmp_path):
- nested = tmp_path / "nested" / "subdir" / "x.db"
- inst = TokenDatabase(str(nested))
- await inst.init_db()
- try:
- assert nested.parent.exists()
- finally:
- await inst.close()
-
-
-class TestUpdateTokenCounts:
- async def test_insert_then_update_aggregates(self, db):
- await db.update_token_counts("http://ep", "m1", 10, 20)
- await db.update_token_counts("http://ep", "m1", 5, 7)
- rows = [r async for r in db.load_token_counts()]
- assert len(rows) == 1
- r = rows[0]
- assert r["endpoint"] == "http://ep"
- assert r["model"] == "m1"
- assert r["input_tokens"] == 15
- assert r["output_tokens"] == 27
- assert r["total_tokens"] == 42
-
- async def test_independent_endpoint_model_pairs(self, db):
- await db.update_token_counts("http://ep1", "m1", 1, 1)
- await db.update_token_counts("http://ep1", "m2", 2, 2)
- await db.update_token_counts("http://ep2", "m1", 3, 3)
- rows = [r async for r in db.load_token_counts()]
- assert len(rows) == 3
- totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
- assert totals == {
- ("http://ep1", "m1"): 2,
- ("http://ep1", "m2"): 4,
- ("http://ep2", "m1"): 6,
- }
-
-
-class TestBatchedCounts:
- async def test_update_batched_counts(self, db):
- counts = {
- "http://a": {"m": (4, 6)},
- "http://b": {"m": (1, 1), "n": (10, 0)},
- }
- await db.update_batched_counts(counts)
- rows = [r async for r in db.load_token_counts()]
- totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
- assert totals == {
- ("http://a", "m"): 10,
- ("http://b", "m"): 2,
- ("http://b", "n"): 10,
- }
-
- async def test_empty_batch_is_noop(self, db):
- await db.update_batched_counts({})
- rows = [r async for r in db.load_token_counts()]
- assert rows == []
-
-
-class TestTimeSeries:
- async def test_add_time_series_entry(self, db):
- # The aggregate FK requires the (endpoint,model) row to exist first
- await db.update_token_counts("http://ep", "m", 0, 0)
- await db.add_time_series_entry("http://ep", "m", 3, 4)
- await db.add_time_series_entry("http://ep", "m", 1, 1)
- rows = [r async for r in db.get_latest_time_series(limit=10)]
- assert len(rows) == 2
- # Newest-first ordering; both timestamps are within the same minute,
- # so just check totals are present and well-formed
- for r in rows:
- assert r["endpoint"] == "http://ep"
- assert r["model"] == "m"
- assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"]
-
- async def test_add_batched_time_series(self, db):
- await db.update_token_counts("http://ep", "m", 0, 0)
- now = int(datetime.now(tz=timezone.utc).timestamp())
- entries = [
- {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
- "output_tokens": 2, "total_tokens": 3, "timestamp": now - 60},
- {"endpoint": "http://ep", "model": "m", "input_tokens": 4,
- "output_tokens": 5, "total_tokens": 9, "timestamp": now},
- ]
- await db.add_batched_time_series(entries)
- rows = [r async for r in db.get_latest_time_series(limit=10)]
- assert len(rows) == 2
- assert rows[0]["timestamp"] >= rows[1]["timestamp"]
-
- async def test_get_time_series_for_model_filters(self, db):
- await db.update_token_counts("http://ep", "m1", 0, 0)
- await db.update_token_counts("http://ep", "m2", 0, 0)
- now = int(datetime.now(tz=timezone.utc).timestamp())
- await db.add_batched_time_series([
- {"endpoint": "http://ep", "model": "m1", "input_tokens": 1,
- "output_tokens": 1, "total_tokens": 2, "timestamp": now},
- {"endpoint": "http://ep", "model": "m2", "input_tokens": 9,
- "output_tokens": 9, "total_tokens": 18, "timestamp": now},
- ])
- rows = [r async for r in db.get_time_series_for_model("m1")]
- assert len(rows) == 1
- assert rows[0]["total_tokens"] == 2
-
- async def test_endpoint_distribution_for_model(self, db):
- await db.update_token_counts("http://a", "m", 0, 0)
- await db.update_token_counts("http://b", "m", 0, 0)
- now = int(datetime.now(tz=timezone.utc).timestamp())
- await db.add_batched_time_series([
- {"endpoint": "http://a", "model": "m", "input_tokens": 1,
- "output_tokens": 1, "total_tokens": 2, "timestamp": now},
- {"endpoint": "http://a", "model": "m", "input_tokens": 1,
- "output_tokens": 1, "total_tokens": 2, "timestamp": now},
- {"endpoint": "http://b", "model": "m", "input_tokens": 5,
- "output_tokens": 5, "total_tokens": 10, "timestamp": now},
- ])
- dist = await db.get_endpoint_distribution_for_model("m")
- assert dist == {"http://a": 4, "http://b": 10}
-
-
-class TestGetTokenCountsForModel:
- async def test_aggregates_across_endpoints(self, db):
- await db.update_token_counts("http://a", "m", 1, 2)
- await db.update_token_counts("http://b", "m", 3, 4)
- result = await db.get_token_counts_for_model("m")
- assert result is not None
- assert result["endpoint"] == "aggregated"
- assert result["model"] == "m"
- assert result["input_tokens"] == 4
- assert result["output_tokens"] == 6
- assert result["total_tokens"] == 10
-
- async def test_unknown_model_returns_zero_aggregate(self, db):
- # SUM(...) WHERE no-match returns one row with NULLs — exposed as zeros
- result = await db.get_token_counts_for_model("nope")
- assert result is not None
- assert result["input_tokens"] in (0, None)
-
-
-class TestAggregateTimeSeriesOlderThan:
- async def test_aggregates_old_entries_by_day(self, db):
- await db.update_token_counts("http://ep", "m", 0, 0)
- now = int(datetime.now(tz=timezone.utc).timestamp())
- old = now - (40 * 86400) # 40 days ago
- await db.add_batched_time_series([
- {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
- "output_tokens": 1, "total_tokens": 2, "timestamp": old},
- {"endpoint": "http://ep", "model": "m", "input_tokens": 3,
- "output_tokens": 3, "total_tokens": 6, "timestamp": old + 60},
- {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
- "output_tokens": 99, "total_tokens": 198, "timestamp": now},
- ])
- n = await db.aggregate_time_series_older_than(30, trim_old=False)
- assert n == 1 # one (endpoint, model, day) group rolled up
-
- async def test_invalid_days_falls_back_to_30(self, db):
- # Just ensure it doesn't blow up with a bogus value
- n = await db.aggregate_time_series_older_than(0)
- assert n == 0
-
- async def test_trim_old_removes_aggregated_rows(self, db):
- await db.update_token_counts("http://ep", "m", 0, 0)
- now = int(datetime.now(tz=timezone.utc).timestamp())
- old = now - (40 * 86400)
- await db.add_batched_time_series([
- {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
- "output_tokens": 1, "total_tokens": 2, "timestamp": old},
- {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
- "output_tokens": 99, "total_tokens": 198, "timestamp": now},
- ])
- await db.aggregate_time_series_older_than(30, trim_old=True)
- remaining = [r async for r in db.get_latest_time_series(limit=10)]
- # Only the recent (within-cutoff) row should remain
- assert len(remaining) == 1
- assert remaining[0]["total_tokens"] == 198
diff --git a/test/test_fetch.py b/test/test_fetch.py
deleted file mode 100644
index 6f2ed50..0000000
--- a/test/test_fetch.py
+++ /dev/null
@@ -1,210 +0,0 @@
-"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking."""
-import time
-from unittest.mock import patch, MagicMock
-
-import pytest
-from aioresponses import aioresponses
-
-import router
-from conftest import TEST_OLLAMA, TEST_LLAMA
-
-MOCK_OLLAMA_EP = "http://mock-ollama:11434"
-MOCK_LLAMA_EP = "http://mock-llama:8080/v1"
-
-
-def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
- cfg = MagicMock()
- cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP]
- cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP]
- cfg.api_keys = api_keys or {}
- cfg.max_concurrent_connections = 2
- cfg.router_api_key = None
- return cfg
-
-
-@pytest.fixture(autouse=True)
-def clear_caches(aio_session):
- """aio_session fixture already clears caches and sets up app_state."""
- yield
-
-
-class TestFetchAvailableModels:
- async def test_ollama_tags(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/tags",
- payload={"models": [
- {"name": "llama3.2:latest"},
- {"name": "qwen2.5:7b"},
- ]},
- )
- models = await router.fetch.available_models(MOCK_OLLAMA_EP)
- assert models == {"llama3.2:latest", "qwen2.5:7b"}
-
- async def test_openai_compatible_models_endpoint(self):
- cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_LLAMA_EP}/models",
- payload={"data": [{"id": "unsloth/model:Q8_0"}]},
- )
- models = await router.fetch.available_models(MOCK_LLAMA_EP, api_key="tok")
- assert "unsloth/model:Q8_0" in models
-
- async def test_caches_successful_result(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/tags",
- payload={"models": [{"name": "llama3.2:latest"}]},
- )
- first = await router.fetch.available_models(MOCK_OLLAMA_EP)
- second = await router.fetch.available_models(MOCK_OLLAMA_EP)
- # second call must be served from cache without a second HTTP request
- assert first == second == {"llama3.2:latest"}
-
- async def test_returns_empty_on_http_500(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
- models = await router.fetch.available_models(MOCK_OLLAMA_EP)
- assert models == set()
-
- async def test_returns_empty_on_connection_error(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- import aiohttp
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/tags",
- exception=aiohttp.ClientConnectorError(
- connection_key=MagicMock(host="mock-ollama", port=11434),
- os_error=OSError(111, "refused"),
- ),
- )
- models = await router.fetch.available_models(MOCK_OLLAMA_EP)
- assert models == set()
-
- async def test_stale_cache_returned_while_refresh_runs(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/tags",
- payload={"models": [{"name": "llama3.2:latest"}]},
- )
- await router.fetch.available_models(MOCK_OLLAMA_EP)
-
- # Manually age cache into stale-but-valid window (300-600s)
- async with router._models_cache_lock:
- models, _ = router._models_cache[MOCK_OLLAMA_EP]
- router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400)
-
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/tags",
- payload={"models": [{"name": "llama3.2:latest"}]},
- )
- # Should return stale data immediately
- stale = await router.fetch.available_models(MOCK_OLLAMA_EP)
- assert "llama3.2:latest" in stale
-
- async def test_error_cache_short_circuits(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- # Seed error cache with a very recent error
- async with router._available_error_cache_lock:
- router._available_error_cache[MOCK_OLLAMA_EP] = time.time()
-
- with patch.object(router, "config", cfg), aioresponses():
- # No HTTP mock registered — if a call happens it will raise
- models = await router.fetch.available_models(MOCK_OLLAMA_EP)
- assert models == set()
-
-
-class TestFetchLoadedModels:
- async def test_ollama_ps(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/ps",
- payload={"models": [{"name": "llama3.2:latest"}]},
- )
- models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
- assert models == {"llama3.2:latest"}
-
- async def test_llama_server_filters_loaded(self):
- cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_LLAMA_EP}/models",
- payload={"data": [
- {"id": "model-a", "status": {"value": "loaded"}},
- {"id": "model-b", "status": {"value": "unloaded"}},
- ]},
- )
- models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
- assert models == {"model-a"}
-
- async def test_llama_server_no_status_field_always_loaded(self):
- cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_LLAMA_EP}/models",
- payload={"data": [{"id": "always-on-model"}]},
- )
- models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
- assert "always-on-model" in models
-
- async def test_returns_empty_on_error(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
- models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
- assert models == set()
-
- async def test_ext_openai_always_empty(self):
- ext_ep = "https://api.openai.com/v1"
- cfg = _make_cfg(ollama_eps=[ext_ep], llama_eps=[])
- with patch.object(router, "config", cfg):
- models = await router.fetch.loaded_models(ext_ep)
- assert models == set()
-
- async def test_caches_result(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/ps",
- payload={"models": [{"name": "qwen:7b"}]},
- )
- first = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
- second = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
- assert first == second
-
- async def test_records_error_in_loaded_error_cache_on_failure(self):
- # Regression: issue #83 — /api/ps failures must be recorded so
- # `choose_endpoint` can exclude unhealthy backends from routing.
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={})
- await router.fetch.loaded_models(MOCK_OLLAMA_EP)
- assert MOCK_OLLAMA_EP in router._loaded_error_cache
-
- async def test_records_error_for_llama_server_on_failure(self):
- cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP])
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(f"{MOCK_LLAMA_EP}/models", status=502, payload={})
- await router.fetch.loaded_models(MOCK_LLAMA_EP)
- assert MOCK_LLAMA_EP in router._loaded_error_cache
-
- async def test_clears_error_cache_on_subsequent_success(self):
- cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
- # Pre-seed an old error so loaded_models() falls through to the
- # network probe instead of short-circuiting on the error cache.
- async with router._loaded_error_cache_lock:
- router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301
- with patch.object(router, "config", cfg), aioresponses() as m:
- m.get(
- f"{MOCK_OLLAMA_EP}/api/ps",
- payload={"models": [{"name": "qwen:7b"}]},
- )
- await router.fetch.loaded_models(MOCK_OLLAMA_EP)
- assert MOCK_OLLAMA_EP not in router._loaded_error_cache
diff --git a/test/test_openai_proxies.py b/test/test_openai_proxies.py
deleted file mode 100644
index 8a56c91..0000000
--- a/test/test_openai_proxies.py
+++ /dev/null
@@ -1,181 +0,0 @@
-"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes.
-
-These tests verify that when the LLM cache reports a hit, the route returns
-the cached payload *without* selecting an endpoint or contacting any backend.
-"""
-from unittest.mock import AsyncMock, patch
-
-import orjson
-import pytest
-from fastapi import HTTPException
-
-import router
-
-
-_BYPASS = HTTPException(status_code=599, detail="bypassed")
-
-
-class _FakeCache:
- """Minimal stand-in for cache.LLMCache.get_chat."""
- def __init__(self, response_bytes: bytes | None):
- self._resp = response_bytes
- self.calls: list[tuple] = []
-
- async def get_chat(self, route, model, messages):
- self.calls.append((route, model, messages))
- return self._resp
-
-
-@pytest.fixture
-def cache_hit_payload():
- return orjson.dumps({
- "id": "cmpl-xyz",
- "created": 1,
- "model": "test-model",
- "choices": [{"message": {"role": "assistant", "content": "from-cache"}}],
- "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
- })
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# /v1/chat/completions
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestOpenAIChatCompletionsCacheHit:
- async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
- fake = _FakeCache(cache_hit_payload)
- # Patch the route's references to both helpers — they're imported by name
- # into router's namespace at module load time.
- with (
- patch.object(router, "get_llm_cache", return_value=fake),
- patch.object(router, "choose_endpoint",
- AsyncMock(side_effect=AssertionError("backend must not be reached"))),
- ):
- resp = await client.post(
- "/v1/chat/completions",
- json={
- "model": "test-model",
- "messages": [{"role": "user", "content": "ping"}],
- "stream": False,
- "nomyo": {"cache": True},
- },
- )
- assert resp.status_code == 200
- # Body is streamed; collect it
- body = resp.content
- parsed = orjson.loads(body)
- assert parsed["choices"][0]["message"]["content"] == "from-cache"
- assert fake.calls and fake.calls[0][0] == "openai_chat"
-
- async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
- fake = _FakeCache(cache_hit_payload)
- with (
- patch.object(router, "get_llm_cache", return_value=fake),
- patch.object(router, "choose_endpoint",
- AsyncMock(side_effect=AssertionError("backend must not be reached"))),
- ):
- resp = await client.post(
- "/v1/chat/completions",
- json={
- "model": "test-model",
- "messages": [{"role": "user", "content": "ping"}],
- "stream": True,
- "nomyo": {"cache": True},
- },
- )
- assert resp.status_code == 200
- assert resp.headers["content-type"].startswith("text/event-stream")
- text = resp.content.decode()
- # First SSE frame contains the cached content as a delta
- first_frame = text.split("\n\n")[0]
- assert first_frame.startswith("data: ")
- chunk = orjson.loads(first_frame[len("data: "):])
- assert chunk["choices"][0]["delta"]["content"] == "from-cache"
- # Stream is terminated with [DONE]
- assert "data: [DONE]" in text
-
- async def test_cache_disabled_in_payload_bypasses_cache_check(self, client):
- """When nomyo.cache=False, get_chat is never called even if a cache exists."""
- fake = _FakeCache(b"") # has a response, but should never be consulted
- with (
- patch.object(router, "get_llm_cache", return_value=fake),
- patch.object(router, "choose_endpoint",
- AsyncMock(side_effect=_BYPASS)),
- ):
- resp = await client.post(
- "/v1/chat/completions",
- json={
- "model": "m",
- "messages": [{"role": "user", "content": "hi"}],
- "nomyo": {"cache": False},
- },
- )
- # Got past the cache short-circuit → endpoint selection invoked
- assert resp.status_code == 599
- assert fake.calls == []
-
- async def test_no_cache_configured_bypasses_cache_check(self, client):
- """get_llm_cache() returning None should not break the route."""
- with (
- patch.object(router, "get_llm_cache", return_value=None),
- patch.object(router, "choose_endpoint",
- AsyncMock(side_effect=_BYPASS)),
- ):
- resp = await client.post(
- "/v1/chat/completions",
- json={
- "model": "m",
- "messages": [{"role": "user", "content": "hi"}],
- "nomyo": {"cache": True},
- },
- )
- assert resp.status_code == 599
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# /v1/completions
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestOpenAICompletionsCacheHit:
- async def test_nonstream_cache_hit(self, client, cache_hit_payload):
- fake = _FakeCache(cache_hit_payload)
- with (
- patch.object(router, "get_llm_cache", return_value=fake),
- patch.object(router, "choose_endpoint",
- AsyncMock(side_effect=AssertionError("backend must not be reached"))),
- ):
- resp = await client.post(
- "/v1/completions",
- json={
- "model": "test-model",
- "prompt": "Tell me a joke",
- "stream": False,
- "nomyo": {"cache": True},
- },
- )
- assert resp.status_code == 200
- # Prompt-style cache lookup is namespaced under "openai_completions"
- assert fake.calls[0][0] == "openai_completions"
- # Cache lookup receives the prompt as a single user message
- cached_msgs = fake.calls[0][2]
- assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}]
-
- async def test_stream_cache_hit(self, client, cache_hit_payload):
- fake = _FakeCache(cache_hit_payload)
- with (
- patch.object(router, "get_llm_cache", return_value=fake),
- patch.object(router, "choose_endpoint",
- AsyncMock(side_effect=AssertionError("backend must not be reached"))),
- ):
- resp = await client.post(
- "/v1/completions",
- json={
- "model": "test-model",
- "prompt": "What is 2+2?",
- "stream": True,
- "nomyo": {"cache": True},
- },
- )
- assert resp.status_code == 200
- assert resp.headers["content-type"].startswith("text/event-stream")
- assert "data: [DONE]" in resp.content.decode()
diff --git a/test/test_unit_context.py b/test/test_unit_context.py
deleted file mode 100644
index de2b98a..0000000
--- a/test/test_unit_context.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""Unit tests for context-window trimming logic."""
-import pytest
-import router
-
-
-def _msgs(roles_contents):
- return [{"role": r, "content": c} for r, c in roles_contents]
-
-
-class TestCountMessageTokens:
- def test_returns_int(self):
- msgs = _msgs([("user", "hello")])
- assert isinstance(router._count_message_tokens(msgs), int)
-
- def test_empty_list(self):
- assert router._count_message_tokens([]) >= 0
-
- def test_longer_content_more_tokens(self):
- short = _msgs([("user", "hi")])
- long_ = _msgs([("user", "a " * 500)])
- assert router._count_message_tokens(long_) > router._count_message_tokens(short)
-
- def test_list_content(self):
- msgs = [{"role": "user", "content": [
- {"type": "text", "text": "what do you see?"},
- ]}]
- tokens = router._count_message_tokens(msgs)
- assert tokens > 0
-
- def test_multiple_messages(self):
- msgs = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")])
- assert router._count_message_tokens(msgs) > 10
-
-
-class TestTrimMessagesForContext:
- def test_short_history_unchanged(self):
- msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
- result = router._trim_messages_for_context(msgs, n_ctx=4096)
- assert result == msgs
-
- def test_system_messages_always_kept(self):
- msgs = (
- _msgs([("system", "you are helpful")])
- + _msgs([("user", f"msg {i}") for i in range(50)])
- + _msgs([("user", "final question")])
- )
- result = router._trim_messages_for_context(msgs, n_ctx=512)
- system_msgs = [m for m in result if m["role"] == "system"]
- assert len(system_msgs) == 1
- assert system_msgs[0]["content"] == "you are helpful"
-
- def test_last_user_message_always_kept(self):
- msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")])
- result = router._trim_messages_for_context(msgs, n_ctx=256)
- assert result[-1]["content"] == "very important last question"
-
- def test_oldest_dropped_first(self):
- msgs = _msgs([
- ("user", "oldest msg"),
- ("assistant", "oldest reply"),
- ("user", "newer msg"),
- ("assistant", "newer reply"),
- ("user", "newest"),
- ])
- # Use very small target to force trimming
- result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
- contents = [m["content"] for m in result]
- # "oldest msg" should be dropped before "newest"
- if "oldest msg" in contents:
- assert "newest" in contents
- else:
- assert "newest" in contents
-
- def test_result_starts_with_user(self):
- msgs = _msgs([
- ("assistant", "leftover assistant"),
- ("user", "question"),
- ])
- result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
- if result:
- assert result[0]["role"] == "user"
-
- def test_target_tokens_overrides_safety_margin(self):
- msgs = _msgs([("user", "a " * 200)])
- result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
- result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
- # Both should return at least the last message
- assert len(result_small) >= 1
- assert len(result_large) >= 1
-
-
-class TestCalibratedTrimTarget:
- def test_returns_positive_int(self):
- msgs = [{"role": "user", "content": "hello " * 100}]
- result = router._calibrated_trim_target(msgs, n_ctx=4096, actual_tokens=3000)
- assert isinstance(result, int)
- assert result >= 1
-
- def test_over_limit_reduces_target(self):
- msgs = [{"role": "user", "content": "a " * 500}]
- # actual_tokens > n_ctx means we need to shed more
- target = router._calibrated_trim_target(msgs, n_ctx=2048, actual_tokens=2500)
- assert target < router._count_message_tokens(msgs)
-
- def test_well_within_limit_returns_current(self):
- msgs = [{"role": "user", "content": "hi"}]
- # actual_tokens << n_ctx means nothing to shed
- target = router._calibrated_trim_target(msgs, n_ctx=16384, actual_tokens=50)
- # Should return cur_tiktoken since to_shed == 0
- assert target == max(1, router._count_message_tokens(msgs))
-
- def test_minimum_is_one(self):
- # Even if we need to shed everything, result is at least 1
- msgs = [{"role": "user", "content": "hello"}]
- target = router._calibrated_trim_target(msgs, n_ctx=100, actual_tokens=99999)
- assert target >= 1
diff --git a/test/test_unit_helpers.py b/test/test_unit_helpers.py
deleted file mode 100644
index d38eb37..0000000
--- a/test/test_unit_helpers.py
+++ /dev/null
@@ -1,279 +0,0 @@
-"""Unit tests for pure helper functions in router.py (no network, no app)."""
-import time
-import asyncio
-from unittest.mock import MagicMock, patch
-
-import aiohttp
-import pytest
-
-import router
-
-
-class TestMaskSecrets:
- def test_masks_openai_key(self):
- text = "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ"
- result = router._mask_secrets(text)
- assert "sk-***redacted***" in result
- assert "sk-abcd1234" not in result
-
- def test_masks_api_key_assignment(self):
- result = router._mask_secrets("api_key: supersecretvalue123")
- assert "supersecretvalue123" not in result
- assert "***redacted***" in result
-
- def test_masks_api_key_with_colon(self):
- result = router._mask_secrets("api-key: mykey")
- assert "mykey" not in result
-
- def test_empty_string_returns_empty(self):
- assert router._mask_secrets("") == ""
-
- def test_none_returns_none(self):
- assert router._mask_secrets(None) is None
-
- def test_no_secrets_unchanged(self):
- text = "this is a normal log line"
- assert router._mask_secrets(text) == text
-
-
-class TestIsFresh:
- def test_fresh_within_ttl(self):
- cached_at = time.time() - 10
- assert router._is_fresh(cached_at, 300) is True
-
- def test_expired_beyond_ttl(self):
- cached_at = time.time() - 400
- assert router._is_fresh(cached_at, 300) is False
-
- def test_exactly_at_boundary(self):
- cached_at = time.time() - 300
- # May be True or False depending on timing, just verify it runs
- result = router._is_fresh(cached_at, 300)
- assert isinstance(result, bool)
-
- def test_just_cached(self):
- assert router._is_fresh(time.time(), 1) is True
-
-
-class TestNormalizeLlamaModelName:
- def test_strips_hf_prefix(self):
- assert router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF") == "gpt-oss-20b-GGUF"
-
- def test_strips_quant_suffix(self):
- assert router._normalize_llama_model_name("model:Q8_0") == "model"
-
- def test_strips_both(self):
- result = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16")
- assert result == "gpt-oss-20b-GGUF"
-
- def test_no_prefix_no_suffix(self):
- assert router._normalize_llama_model_name("plain-model") == "plain-model"
-
- def test_multiple_slashes(self):
- result = router._normalize_llama_model_name("org/user/model-name:Q4_K_M")
- assert result == "model-name"
-
-
-class TestExtractLlamaQuant:
- def test_extracts_quant(self):
- assert router._extract_llama_quant("unsloth/model:Q8_0") == "Q8_0"
-
- def test_no_quant_returns_empty(self):
- assert router._extract_llama_quant("plain-model") == ""
-
- def test_f16(self):
- assert router._extract_llama_quant("model:F16") == "F16"
-
- def test_q4_k_m(self):
- assert router._extract_llama_quant("model:Q4_K_M") == "Q4_K_M"
-
-
-class TestIsUnixSocketEndpoint:
- def test_sock_endpoint_detected(self):
- assert router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1") is True
-
- def test_regular_http_not_sock(self):
- assert router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1") is False
-
- def test_ollama_not_sock(self):
- assert router._is_unix_socket_endpoint("http://localhost:11434") is False
-
- def test_dot_sock_in_host_detected(self):
- assert router._is_unix_socket_endpoint("http://llama.sock/v1") is True
-
-
-class TestGetSocketPath:
- def test_returns_run_user_path(self):
- import os
- path = router._get_socket_path("http://192.168.0.52.sock/v1")
- uid = os.getuid()
- assert path == f"/run/user/{uid}/192.168.0.52.sock"
-
-
-class TestIsBase64:
- def test_valid_base64(self):
- import base64
- data = base64.b64encode(b"hello world").decode()
- assert router.is_base64(data) is True
-
- def test_invalid_base64(self):
- assert router.is_base64("not-base64!@#$") is False
-
- def test_empty_string(self):
- # Empty string is valid base64 (decodes to empty bytes)
- assert router.is_base64("") is True
-
- def test_non_string(self):
- # Non-strings fall through without returning True (returns None)
- assert not router.is_base64(12345)
-
-
-class TestIsLlamaModelLoaded:
- def test_status_dict_loaded(self):
- assert router._is_llama_model_loaded({"id": "m", "status": {"value": "loaded"}}) is True
-
- def test_status_dict_unloaded(self):
- assert router._is_llama_model_loaded({"id": "m", "status": {"value": "unloaded"}}) is False
-
- def test_status_string_loaded(self):
- assert router._is_llama_model_loaded({"id": "m", "status": "loaded"}) is True
-
- def test_status_string_unloaded(self):
- assert router._is_llama_model_loaded({"id": "m", "status": "unloaded"}) is False
-
- def test_no_status_field_always_loaded(self):
- # No status field → always available (single-model server)
- assert router._is_llama_model_loaded({"id": "m"}) is True
-
- def test_status_none_always_loaded(self):
- assert router._is_llama_model_loaded({"id": "m", "status": None}) is True
-
-
-class TestEp2Base:
- def test_adds_v1_to_ollama(self):
- assert router.ep2base("http://localhost:11434") == "http://localhost:11434/v1"
-
- def test_keeps_v1_if_present(self):
- assert router.ep2base("http://host/v1") == "http://host/v1"
-
- def test_llama_server_endpoint_unchanged(self):
- ep = "http://192.168.0.50:8889/v1"
- assert router.ep2base(ep) == ep
-
-
-class TestDedupeOnKeys:
- def test_removes_duplicate_by_single_key(self):
- items = [{"name": "a", "x": 1}, {"name": "b", "x": 2}, {"name": "a", "x": 3}]
- result = router.dedupe_on_keys(items, ["name"])
- assert len(result) == 2
- assert result[0]["name"] == "a"
- assert result[1]["name"] == "b"
-
- def test_removes_duplicate_by_two_keys(self):
- items = [
- {"digest": "abc", "name": "m1"},
- {"digest": "abc", "name": "m1"},
- {"digest": "def", "name": "m2"},
- ]
- result = router.dedupe_on_keys(items, ["digest", "name"])
- assert len(result) == 2
-
- def test_empty_list(self):
- assert router.dedupe_on_keys([], ["name"]) == []
-
- def test_no_duplicates(self):
- items = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
- assert len(router.dedupe_on_keys(items, ["name"])) == 3
-
-
-class TestFormatConnectionIssue:
- def test_connector_error_message(self):
- err = aiohttp.ClientConnectorError(
- connection_key=MagicMock(host="localhost", port=11434),
- os_error=OSError(111, "Connection refused"),
- )
- msg = router._format_connection_issue("http://localhost:11434", err)
- assert "localhost" in msg
- assert "Connection refused" in msg or "111" in msg
-
- def test_timeout_error_message(self):
- msg = router._format_connection_issue("http://host:1234", asyncio.TimeoutError())
- assert "Timed out" in msg
- assert "host:1234" in msg
-
- def test_generic_error(self):
- msg = router._format_connection_issue("http://host:1234", ValueError("boom"))
- assert "host:1234" in msg
- assert "boom" in msg
-
-
-class TestIsExtOpenaiEndpoint:
- def test_openai_com_is_ext(self):
- cfg = MagicMock()
- cfg.endpoints = []
- cfg.llama_server_endpoints = []
- with patch.object(router, "config", cfg):
- assert router.is_ext_openai_endpoint("https://api.openai.com/v1") is True
-
- def test_ollama_default_port_not_ext(self):
- cfg = MagicMock()
- cfg.endpoints = ["http://host:11434"]
- cfg.llama_server_endpoints = []
- with patch.object(router, "config", cfg):
- assert router.is_ext_openai_endpoint("http://host:11434") is False
-
- def test_llama_server_not_ext(self):
- cfg = MagicMock()
- cfg.endpoints = []
- cfg.llama_server_endpoints = ["http://host:8080/v1"]
- with patch.object(router, "config", cfg):
- assert router.is_ext_openai_endpoint("http://host:8080/v1") is False
-
- def test_no_v1_not_ext(self):
- cfg = MagicMock()
- cfg.endpoints = ["http://host:11434"]
- cfg.llama_server_endpoints = []
- with patch.object(router, "config", cfg):
- assert router.is_ext_openai_endpoint("http://host:11434") is False
-
-
-class TestIsOpenaiCompatible:
- def test_v1_endpoint_compatible(self):
- cfg = MagicMock()
- cfg.llama_server_endpoints = []
- with patch.object(router, "config", cfg):
- assert router.is_openai_compatible("http://host/v1") is True
-
- def test_ollama_not_compatible(self):
- cfg = MagicMock()
- cfg.llama_server_endpoints = []
- with patch.object(router, "config", cfg):
- assert router.is_openai_compatible("http://localhost:11434") is False
-
- def test_llama_server_in_list_compatible(self):
- cfg = MagicMock()
- cfg.llama_server_endpoints = ["http://host:8080"]
- with patch.object(router, "config", cfg):
- assert router.is_openai_compatible("http://host:8080") is True
-
-
-class TestGetTrackingModel:
- def test_ollama_adds_latest(self):
- cfg = MagicMock()
- cfg.llama_server_endpoints = []
- with patch.object(router, "config", cfg):
- assert router.get_tracking_model("http://ollama:11434", "llama3.2") == "llama3.2:latest"
-
- def test_ollama_keeps_existing_tag(self):
- cfg = MagicMock()
- cfg.llama_server_endpoints = []
- with patch.object(router, "config", cfg):
- assert router.get_tracking_model("http://ollama:11434", "llama3.2:7b") == "llama3.2:7b"
-
- def test_llama_server_normalizes(self):
- ep = "http://host:8080/v1"
- cfg = MagicMock()
- cfg.llama_server_endpoints = [ep]
- with patch.object(router, "config", cfg):
- result = router.get_tracking_model(ep, "unsloth/model:Q8_0")
- assert result == "model"
diff --git a/test/test_unit_rechunk.py b/test/test_unit_rechunk.py
deleted file mode 100644
index e0d01c9..0000000
--- a/test/test_unit_rechunk.py
+++ /dev/null
@@ -1,173 +0,0 @@
-"""Unit tests for router.rechunk — OpenAI ↔ Ollama chunk shape conversion."""
-import time
-from types import SimpleNamespace
-
-import ollama
-
-import router
-
-
-def _ns(**kw):
- return SimpleNamespace(**kw)
-
-
-def _stream_chunk(content="hi", role="assistant", finish_reason=None,
- usage=None, model="m"):
- """Build a SimpleNamespace mimicking a streaming OpenAI chunk."""
- delta = _ns(content=content, role=role, reasoning=None, reasoning_content=None,
- tool_calls=None)
- choice = _ns(delta=delta, finish_reason=finish_reason, logprobs=None)
- return _ns(model=model, choices=[choice], usage=usage)
-
-
-def _nonstream_chunk(content="hi", role="assistant", finish_reason="stop",
- usage=None, model="m", tool_calls=None):
- """Build a SimpleNamespace mimicking a non-streaming OpenAI ChatCompletion."""
- message = _ns(content=content, role=role, reasoning=None, reasoning_content=None,
- tool_calls=tool_calls)
- choice = _ns(message=message, finish_reason=finish_reason, logprobs=None)
- return _ns(model=model, choices=[choice], usage=usage)
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# openai_chat_completion2ollama
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestChatCompletionToOllama:
- def test_streaming_content_chunk(self):
- chunk = _stream_chunk(content="hello", finish_reason=None, usage=None)
- out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
- assert isinstance(out, ollama.ChatResponse)
- assert out.message.role == "assistant"
- assert out.message.content == "hello"
- assert out.done is False # usage is None → not done yet
- assert out.model == "m"
-
- def test_streaming_empty_content_defaults(self):
- # Some chunks have content=None — should coerce to empty string
- chunk = _stream_chunk(content=None, role=None)
- out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
- assert out.message.role == "assistant" # role defaulted
- assert out.message.content == ""
-
- def test_final_usage_only_chunk_marks_done(self):
- usage = _ns(prompt_tokens=10, completion_tokens=5, total_tokens=15)
- chunk = _ns(model="m", choices=[], usage=usage)
- out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
- assert out.done is True
- assert out.done_reason == "stop"
- assert out.prompt_eval_count == 10
- assert out.eval_count == 5
- assert out.message.content == ""
-
- def test_nonstreaming_with_content(self):
- usage = _ns(prompt_tokens=2, completion_tokens=3, total_tokens=5)
- chunk = _nonstream_chunk(content="response text", finish_reason="stop", usage=usage)
- out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
- assert out.done is True
- assert out.message.content == "response text"
- assert out.prompt_eval_count == 2
- assert out.eval_count == 3
-
- def test_nonstreaming_tool_calls_converted(self):
- """Tool calls with JSON string arguments are parsed into dicts."""
- tc = _ns(function=_ns(name="get_weather", arguments='{"city": "Paris"}'))
- usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
- chunk = _nonstream_chunk(
- content="", finish_reason="tool_calls", usage=usage, tool_calls=[tc]
- )
- out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
- assert out.message.tool_calls is not None
- assert len(out.message.tool_calls) == 1
- first = out.message.tool_calls[0]
- assert first.function.name == "get_weather"
- assert first.function.arguments == {"city": "Paris"}
-
- def test_nonstreaming_tool_calls_with_invalid_json_fall_back_to_empty(self):
- tc = _ns(function=_ns(name="f", arguments="not-json"))
- usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
- chunk = _nonstream_chunk(content="", usage=usage, tool_calls=[tc])
- out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
- assert out.message.tool_calls[0].function.arguments == {}
-
- def test_streaming_tool_calls_in_delta_are_skipped(self):
- """Streaming mode must not assemble tool calls (caller handles it)."""
- chunk = _stream_chunk(content="x", finish_reason=None)
- # Even if a chunk somehow carried tool_calls in the delta, streaming
- # mode should ignore them.
- out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
- assert out.message.tool_calls is None
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# openai_completion2ollama
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestCompletionToOllama:
- def test_streaming_text_chunk(self):
- choice = _ns(text="word", finish_reason=None, reasoning=None)
- chunk = _ns(model="m", choices=[choice], usage=None)
- out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter())
- assert isinstance(out, ollama.GenerateResponse)
- assert out.response == "word"
- assert out.done is False
-
- def test_final_chunk_with_usage(self):
- usage = _ns(prompt_tokens=4, completion_tokens=6, total_tokens=10)
- choice = _ns(text="end", finish_reason="stop", reasoning=None)
- chunk = _ns(model="m", choices=[choice], usage=usage)
- out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter())
- assert out.done is True
- assert out.prompt_eval_count == 4
- assert out.eval_count == 6
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# embeddings / embed
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestEmbeddingConversions:
- def test_openai_embeddings2ollama(self):
- chunk = _ns(data=[_ns(embedding=[0.1, 0.2, 0.3])])
- out = router.rechunk.openai_embeddings2ollama(chunk)
- assert isinstance(out, ollama.EmbeddingsResponse)
- assert list(out.embedding) == [0.1, 0.2, 0.3]
-
- def test_openai_embed2ollama(self):
- chunk = _ns(data=[_ns(embedding=[0.5, 0.6])])
- out = router.rechunk.openai_embed2ollama(chunk, "my-embed-model")
- assert isinstance(out, ollama.EmbedResponse)
- assert out.model == "my-embed-model"
- assert list(out.embeddings[0]) == [0.5, 0.6]
-
-
-# ──────────────────────────────────────────────────────────────────────────────
-# extract_usage_from_llama_timings
-# ──────────────────────────────────────────────────────────────────────────────
-
-class TestExtractUsageFromLlamaTimings:
- def test_none_when_no_timings_attr(self):
- obj = _ns()
- assert router.rechunk.extract_usage_from_llama_timings(obj) is None
-
- def test_prompt_plus_cache_sums(self):
- obj = _ns(timings={"prompt_n": 1, "cache_n": 236, "predicted_n": 35})
- prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
- assert prompt == 237
- assert completion == 35
-
- def test_missing_keys_default_to_zero(self):
- obj = _ns(timings={"predicted_n": 12})
- prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
- assert prompt == 0
- assert completion == 12
-
- def test_null_values_treated_as_zero(self):
- obj = _ns(timings={"prompt_n": None, "cache_n": None, "predicted_n": None})
- prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
- assert prompt == 0
- assert completion == 0
-
- def test_non_dict_timings_returns_none(self):
- obj = _ns(timings="not-a-dict")
- assert router.rechunk.extract_usage_from_llama_timings(obj) is None
diff --git a/test/test_unit_transforms.py b/test/test_unit_transforms.py
deleted file mode 100644
index 51160a0..0000000
--- a/test/test_unit_transforms.py
+++ /dev/null
@@ -1,200 +0,0 @@
-"""Unit tests for message transformation functions."""
-from unittest.mock import MagicMock
-
-import pytest
-
-import router
-
-
-class TestStripAssistantPrefill:
- def test_removes_trailing_assistant(self):
- msgs = [
- {"role": "user", "content": "hello"},
- {"role": "assistant", "content": "prefill"},
- ]
- result = router._strip_assistant_prefill(msgs)
- assert len(result) == 1
- assert result[0]["role"] == "user"
-
- def test_keeps_non_trailing_assistant(self):
- msgs = [
- {"role": "user", "content": "hello"},
- {"role": "assistant", "content": "response"},
- {"role": "user", "content": "follow-up"},
- ]
- result = router._strip_assistant_prefill(msgs)
- assert len(result) == 3
-
- def test_empty_list_unchanged(self):
- assert router._strip_assistant_prefill([]) == []
-
- def test_single_user_message_unchanged(self):
- msgs = [{"role": "user", "content": "hi"}]
- assert router._strip_assistant_prefill(msgs) == msgs
-
-
-class TestTransformToolCallsToOpenAI:
- def test_adds_type_function(self):
- msgs = [{"role": "assistant", "tool_calls": [
- {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
- ]}]
- result = router.transform_tool_calls_to_openai(msgs)
- tc = result[0]["tool_calls"][0]
- assert tc["type"] == "function"
-
- def test_adds_id_when_missing(self):
- msgs = [{"role": "assistant", "tool_calls": [
- {"function": {"name": "fn", "arguments": {}}}
- ]}]
- result = router.transform_tool_calls_to_openai(msgs)
- assert "id" in result[0]["tool_calls"][0]
-
- def test_converts_dict_arguments_to_string(self):
- msgs = [{"role": "assistant", "tool_calls": [
- {"function": {"name": "fn", "arguments": {"key": "val"}}}
- ]}]
- result = router.transform_tool_calls_to_openai(msgs)
- args = result[0]["tool_calls"][0]["function"]["arguments"]
- assert isinstance(args, str)
- import orjson
- parsed = orjson.loads(args)
- assert parsed == {"key": "val"}
-
- def test_keeps_string_arguments_unchanged(self):
- msgs = [{"role": "assistant", "tool_calls": [
- {"function": {"name": "fn", "arguments": '{"key": "val"}'}}
- ]}]
- result = router.transform_tool_calls_to_openai(msgs)
- args = result[0]["tool_calls"][0]["function"]["arguments"]
- assert args == '{"key": "val"}'
-
- def test_links_tool_call_id_to_tool_response(self):
- msgs = [
- {"role": "assistant", "tool_calls": [
- {"function": {"name": "get_weather", "arguments": {}}}
- ]},
- {"role": "tool", "name": "get_weather", "content": "sunny"},
- ]
- result = router.transform_tool_calls_to_openai(msgs)
- tc_id = result[0]["tool_calls"][0]["id"]
- assert result[1].get("tool_call_id") == tc_id
-
- def test_non_tool_messages_unchanged(self):
- msgs = [{"role": "user", "content": "hello"}]
- result = router.transform_tool_calls_to_openai(msgs)
- assert result == msgs
-
-
-class TestStripImagesFromMessages:
- def test_removes_image_url_parts(self):
- msgs = [{"role": "user", "content": [
- {"type": "text", "text": "what is this?"},
- {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
- ]}]
- result = router._strip_images_from_messages(msgs)
- content = result[0]["content"]
- assert content == "what is this?"
-
- def test_keeps_text_only_messages(self):
- msgs = [{"role": "user", "content": "plain text"}]
- result = router._strip_images_from_messages(msgs)
- assert result[0]["content"] == "plain text"
-
- def test_multiple_text_parts_kept_as_list(self):
- msgs = [{"role": "user", "content": [
- {"type": "text", "text": "part one"},
- {"type": "text", "text": "part two"},
- {"type": "image_url", "image_url": {"url": "data:..."}},
- ]}]
- result = router._strip_images_from_messages(msgs)
- content = result[0]["content"]
- assert isinstance(content, list)
- assert len(content) == 2
-
- def test_all_images_removed_empty_list(self):
- msgs = [{"role": "user", "content": [
- {"type": "image_url", "image_url": {"url": "data:..."}},
- ]}]
- result = router._strip_images_from_messages(msgs)
- # Image-only content becomes empty list
- content = result[0]["content"]
- assert content == []
-
-
-class TestAccumulateOpenAITcDelta:
- def _make_chunk(self, index, name=None, args_fragment="", tc_id=None):
- delta = MagicMock()
- tc = MagicMock()
- tc.index = index
- tc.id = tc_id
- tc.function = MagicMock()
- tc.function.name = name
- tc.function.arguments = args_fragment
- delta.tool_calls = [tc]
- chunk = MagicMock()
- chunk.choices = [MagicMock(delta=delta)]
- return chunk
-
- def test_first_delta_creates_entry(self):
- acc = {}
- chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"')
- router._accumulate_openai_tc_delta(chunk, acc)
- assert 0 in acc
- assert acc[0]["name"] == "my_fn"
- assert acc[0]["arguments"] == '{"k"'
-
- def test_subsequent_deltas_concatenate_args(self):
- acc = {}
- router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc)
- router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc)
- assert acc[0]["arguments"] == '{"k": "v"}'
-
- def test_multiple_tool_calls_tracked_separately(self):
- acc = {}
- c1 = self._make_chunk(0, name="fn1", args_fragment="{}")
- c2 = self._make_chunk(1, name="fn2", args_fragment="{}")
- chunk = MagicMock()
- tc1 = MagicMock()
- tc1.index = 0
- tc1.id = "id1"
- tc1.function = MagicMock(name="fn1", arguments="{}")
- tc2 = MagicMock()
- tc2.index = 1
- tc2.id = "id2"
- tc2.function = MagicMock(name="fn2", arguments="{}")
- chunk.choices = [MagicMock(delta=MagicMock(tool_calls=[tc1, tc2]))]
- router._accumulate_openai_tc_delta(chunk, acc)
- assert 0 in acc and 1 in acc
-
- def test_no_choices_is_noop(self):
- acc = {}
- chunk = MagicMock(choices=[])
- router._accumulate_openai_tc_delta(chunk, acc)
- assert acc == {}
-
-
-class TestBuildOllamaToolCalls:
- def test_builds_from_accumulator(self):
- acc = {0: {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'}}
- result = router._build_ollama_tool_calls(acc)
- assert result is not None
- assert len(result) == 1
- assert result[0].function.name == "get_weather"
- assert result[0].function.arguments == {"city": "Berlin"}
-
- def test_invalid_json_args_becomes_empty_dict(self):
- acc = {0: {"id": "c1", "name": "fn", "arguments": "not-json"}}
- result = router._build_ollama_tool_calls(acc)
- assert result[0].function.arguments == {}
-
- def test_empty_accumulator_returns_none(self):
- assert router._build_ollama_tool_calls({}) is None
-
- def test_preserves_order_by_index(self):
- acc = {
- 1: {"id": "c2", "name": "fn2", "arguments": "{}"},
- 0: {"id": "c1", "name": "fn1", "arguments": "{}"},
- }
- result = router._build_ollama_tool_calls(acc)
- assert result[0].function.name == "fn1"
- assert result[1].function.name == "fn2"