diff --git a/.forgejo/workflows/docker-publish-semantic.yml b/.forgejo/workflows/docker-publish-semantic.yml
index ebbfd36..163f1a1 100644
--- a/.forgejo/workflows/docker-publish-semantic.yml
+++ b/.forgejo/workflows/docker-publish-semantic.yml
@@ -86,7 +86,7 @@ jobs:
provenance: false
build-args: |
SEMANTIC_CACHE=true
- tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}
+ tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}-${{ github.run_id }}
merge:
runs-on: docker-amd64
@@ -142,6 +142,6 @@ jobs:
run: |
docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64 \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64-${{ github.run_id }} \
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64-${{ github.run_id }}
diff --git a/.forgejo/workflows/docker-publish.yml b/.forgejo/workflows/docker-publish.yml
index 27cd879..09e145c 100644
--- a/.forgejo/workflows/docker-publish.yml
+++ b/.forgejo/workflows/docker-publish.yml
@@ -77,7 +77,7 @@ jobs:
platforms: ${{ matrix.platform }}
push: true
provenance: false
- tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}
+ tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}-${{ github.run_id }}
merge:
runs-on: docker-amd64
@@ -133,6 +133,6 @@ jobs:
run: |
docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64 \
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64-${{ github.run_id }} \
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64-${{ github.run_id }}
diff --git a/.forgejo/workflows/pr-tests.yml b/.forgejo/workflows/pr-tests.yml
new file mode 100644
index 0000000..aa96b84
--- /dev/null
+++ b/.forgejo/workflows/pr-tests.yml
@@ -0,0 +1,39 @@
+# Runs the non-integration pytest suite (with coverage) for every pull request.
+name: PR Tests
+on: [pull_request]
+jobs:
+  test:
+    runs-on: docker-arm64
+    container:
+      image: python:3.12-slim
+    env:
+      CMAKE_BUILD_PARALLEL_LEVEL: "4"
+    steps:
+      - name: Install system deps
+        run: |
+          apt-get update
+          apt-get install -y --no-install-recommends \
+            git ca-certificates \
+            build-essential pkg-config
+          rm -rf /var/lib/apt/lists/*
+      - name: Checkout
+        run: |
+          git config --global --add safe.directory "$PWD"
+          git clone --depth=1 \
+            "https://oauth2:${{ github.token }}@bitfreedom.net/code/${{ github.repository }}.git" .
+          git fetch --depth=1 origin "+${{ github.event.pull_request.head.sha }}:pr"
+          git checkout pr
+          # The clone URL embeds the job token and git stores it in .git/config.
+          # Strip it before any PR-controlled code (pip install, pytest) runs.
+          git remote set-url origin "https://bitfreedom.net/code/${{ github.repository }}.git"
+      - name: Fetch action source
+        run: |
+          git clone --depth=1 --branch master \
+            "https://oauth2:${{ github.token }}@bitfreedom.net/code/nomyo-ai/actions.git" \
+            ./.run-tests
+          # Same reason as above: do not leave the token in the clone's config.
+          git -C ./.run-tests remote set-url origin "https://bitfreedom.net/code/nomyo-ai/actions.git"
+      - uses: ./.run-tests/run-tests
+        with:
+          setup: |
+            python -m pip install --upgrade pip
+            pip install -r requirements.txt
+            pip install -r test/requirements_test.txt
+          command: pytest test/ -m "not integration" --cov=router --cov=cache --cov=db --cov=enhance --cov-fail-under=45 --cov-report=term-missing --cov-report=xml --junitxml=report.xml
+          artifacts-path: |
+            report.xml
+            coverage.xml
diff --git a/.gitignore b/.gitignore
index cfce37c..7cd8431 100644
--- a/.gitignore
+++ b/.gitignore
@@ -66,7 +66,4 @@ config.yaml
# SQLite
*.db*
-*settings.json
-
-# Test suite (local only, not committed yet)
-test/
\ No newline at end of file
+*settings.json
\ No newline at end of file
diff --git a/doc/architecture.md b/doc/architecture.md
index f725573..c2408d8 100644
--- a/doc/architecture.md
+++ b/doc/architecture.md
@@ -206,6 +206,8 @@ The `/health` endpoint provides comprehensive health status:
}
```
+For Ollama endpoints the probe is a parallel check of `/api/version` (liveness) and `/api/ps` (the route used by `choose_endpoint` when selecting a backend for a request). Reporting `ok` only when both succeed prevents the router from advertising an endpoint as healthy while completion calls dead-end on `/api/ps`. The same dual probe backs `/api/config`, which the dashboard uses to render endpoint health.
+
## Database Schema
The router uses SQLite for persistent storage:
diff --git a/doc/monitoring.md b/doc/monitoring.md
index ab75d25..9ce25ec 100644
--- a/doc/monitoring.md
+++ b/doc/monitoring.md
@@ -29,6 +29,10 @@ Response:
- `200`: All endpoints healthy
- `503`: One or more endpoints unhealthy
+**Probe scope per endpoint**:
+- **Ollama endpoints** are probed at both `/api/version` (liveness) and `/api/ps` (model-introspection used by the router). If either fails the endpoint is reported as `error`; the response still includes `version` when the daemon is reachable so operators can tell a partial failure from a full outage. The `detail` field names the failing probe, e.g. `"/api/ps: 502 …"`.
+- **OpenAI-compatible / llama-server endpoints** are probed at `/models`.
+
### Current Usage
```bash
@@ -133,6 +137,8 @@ Response:
}
```
+Uses the same dual-probe logic as `/health` (Ollama: `/api/version` + `/api/ps`; OpenAI-compatible: `/models`). An endpoint will report `error` whenever either probe fails. The dashboard renders the `detail` field as a tooltip on the status cell.
+
### Cache Statistics
```bash
diff --git a/requirements.txt b/requirements.txt
index c187485..8353c44 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ anyio==4.13.0
async-timeout==5.0.1
attrs==26.1.0
certifi==2026.4.22
-click==8.3.3
+click==8.4.0
distro==1.9.0
exceptiongroup==1.3.1
fastapi==0.136.1
@@ -16,10 +16,10 @@ h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
idna==3.15
-jiter==0.14.0
+jiter==0.15.0
multidict==6.7.1
ollama==0.6.2
-openai==1.109.1
+openai==2.37.0
orjson>=3.11.5
numpy>=1.26
pillow==12.2.0
@@ -30,7 +30,7 @@ pydantic_core==2.46.4
python-dotenv==1.2.2
PyYAML==6.0.3
sniffio==1.3.1
-starlette==0.52.1
+starlette==1.0.1
truststore==0.10.4
tiktoken==0.13.0
tqdm==4.67.3
diff --git a/router.py b/router.py
index 08225fb..c465ebc 100644
--- a/router.py
+++ b/router.py
@@ -1000,7 +1000,7 @@ class fetch:
async with client.get(f"{endpoint}/models") as resp:
await _ensure_success(resp)
data = await resp.json()
-
+
# Filter for loaded models only
items = data.get("data", [])
models = {
@@ -1012,11 +1012,19 @@ class fetch:
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
+ # Probe succeeded — clear any stale error so the endpoint
+ # becomes routable again.
+ async with _loaded_error_cache_lock:
+ _loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/models", e)
print(f"[fetch.loaded_models] {message}")
+ # Record the failure so `choose_endpoint` can avoid routing
+ # to an unhealthy backend and repeated probes short-circuit.
+ async with _loaded_error_cache_lock:
+ _loaded_error_cache[endpoint] = time.time()
return set()
else:
# Original Ollama /api/ps logic
@@ -1031,11 +1039,15 @@ class fetch:
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
+ async with _loaded_error_cache_lock:
+ _loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/api/ps", e)
print(f"[fetch.loaded_models] {message}")
+ async with _loaded_error_cache_lock:
+ _loaded_error_cache[endpoint] = time.time()
return set()
async def _refresh_loaded_models(endpoint: str) -> None:
@@ -1853,6 +1865,28 @@ async def choose_endpoint(model: str, reserve: bool = True,
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
+ # 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing
+ # recently. Without this filter, an endpoint where `/api/ps` returns 5xx
+ # would appear with an empty loaded set but pass through to the
+ # free-slot fallback (step 4) — sending completion calls to an
+ # unhealthy backend. See issue #83.
+ async with _loaded_error_cache_lock:
+ unhealthy = {
+ ep for ep, ts in _loaded_error_cache.items()
+ if _is_fresh(ts, 300)
+ }
+ if unhealthy:
+ filtered = [
+ (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
+ if ep not in unhealthy
+ ]
+ if filtered:
+ candidate_endpoints = [ep for ep, _ in filtered]
+ loaded_sets = [models for _, models in filtered]
+ # If *every* candidate is unhealthy we still fall through with the
+ # original list — refusing to route is worse than retrying a
+ # possibly-recovered backend.
+
# Look up a possible affinity hint *before* taking usage_lock. The two
# locks are never held together to avoid lock-ordering issues.
affine_ep: Optional[str] = None
@@ -3154,44 +3188,103 @@ async def usage_proxy(request: Request):
"token_usage_counts": token_usage_counts}
# -------------------------------------------------------------
-# 20. Proxy config route – for monitoring and frontent usage
+# 20. Endpoint health probes (shared by /api/config and /health)
+# -------------------------------------------------------------
+async def _raw_probe(
+ ep: str,
+ route: str,
+ api_key: Optional[str] = None,
+ timeout: Optional[float] = None,
+) -> tuple[bool, object]:
+ """Direct HTTP probe that distinguishes success from failure
+ (unlike `fetch.endpoint_details`, which returns [] on either).
+ Returns `(ok, payload_or_error_message)`.
+ """
+ headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
+ if api_key is not None:
+ headers["Authorization"] = "Bearer " + api_key
+ url = f"{ep.rstrip('/')}/{route.lstrip('/')}"
+ req_kwargs = {}
+ if timeout is not None:
+ req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
+ try:
+ client: aiohttp.ClientSession = get_session(ep)
+ async with client.get(url, headers=headers, **req_kwargs) as resp:
+ await _ensure_success(resp)
+ data = await resp.json()
+ return True, data
+ except Exception as exc:
+ return False, _format_connection_issue(url, exc)
+
+
+async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
+ """Probe an endpoint and return `{status, version?, detail?}`.
+
+ Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so
+ that a daemon which is reachable but has a broken model-introspection
+ path (issue #83) is reported as `error` rather than `ok`.
+ OpenAI-compatible endpoints use a single `/models` probe.
+ """
+ if is_openai_compatible(ep):
+ ok, payload = await _raw_probe(
+ ep, "/models", config.api_keys.get(ep), timeout=timeout,
+ )
+ if ok:
+ return {"status": "ok", "version": "latest"}
+ return {"status": "error", "detail": str(payload)}
+
+ (version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather(
+ _raw_probe(ep, "/api/version", timeout=timeout),
+ _raw_probe(ep, "/api/ps", timeout=timeout),
+ )
+
+ version_value = (
+ version_payload.get("version")
+ if version_ok and isinstance(version_payload, dict)
+ else None
+ )
+
+ if version_ok and ps_ok:
+ return {"status": "ok", "version": version_value}
+ if not version_ok and not ps_ok:
+ return {"status": "error", "detail": str(version_payload)}
+ # Partial failure — daemon reachable but one probe failed. Report
+ # as "error" so callers can surface the issue; include `version` so
+ # the operator knows the daemon itself is alive.
+ if not ps_ok:
+ return {
+ "status": "error",
+ "version": version_value,
+ "detail": f"/api/ps: {ps_payload}",
+ }
+ return {
+ "status": "error",
+ "detail": f"/api/version: {version_payload}",
+ }
+
+
+# -------------------------------------------------------------
+# 20b. Proxy config route – for monitoring and frontend usage
# -------------------------------------------------------------
@app.get("/api/config")
async def config_proxy(request: Request):
"""
Return a simple JSON object that contains the configured
- Ollama endpoints and llama_server_endpoints. The front‑end uses this to display
- which endpoints are being proxied.
+ Ollama endpoints and llama_server_endpoints. The front‑end uses this
+ to display which endpoints are being proxied and their health.
+ Status is "error" when either liveness (/api/version) or routing
+ health (/api/ps) fails — see issue #83.
"""
- async def check_endpoint(url: str):
- client: aiohttp.ClientSession = get_session(url)
- headers = None
- if "/v1" in url:
- headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
- target_url = f"{url}/models"
- else:
- target_url = f"{url}/api/version"
+ async def check(url: str) -> dict:
+ return {"url": url, **(await _endpoint_health(url, timeout=5))}
- try:
- async with client.get(target_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp:
- await _ensure_success(resp)
- data = await resp.json()
- if "/v1" in url:
- return {"url": url, "status": "ok", "version": "latest"}
- else:
- return {"url": url, "status": "ok", "version": data.get("version")}
- except Exception as e:
- detail = _format_connection_issue(target_url, e)
- return {"url": url, "status": "error", "detail": detail}
-
- # Check Ollama endpoints
- ollama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
-
- # Check llama-server endpoints
+ ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
llama_results = []
if config.llama_server_endpoints:
- llama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.llama_server_endpoints])
-
+ llama_results = await asyncio.gather(
+ *[check(ep) for ep in config.llama_server_endpoints]
+ )
+
return {
"endpoints": ollama_results,
"llama_server_endpoints": llama_results,
@@ -4003,44 +4096,30 @@ async def health_proxy(request: Request):
"""
Health‑check endpoint for monitoring the proxy.
- * Queries each configured endpoint for its `/api/version` response.
+ * Queries each configured endpoint for both liveness and routing health:
+ Ollama endpoints are probed at `/api/version` AND `/api/ps`,
+ OpenAI-compatible endpoints at `/models`.
* Returns a JSON object containing:
- - `status`: "ok" if every endpoint replied, otherwise "error".
+ - `status`: "ok" if every endpoint replied to every probe, otherwise "error".
- `endpoints`: a mapping of endpoint URL → `{status, version|detail}`.
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
"""
# Run all health checks in parallel.
- # Ollama endpoints expose /api/version; OpenAI-compatible endpoints (vLLM,
- # llama-server, external) expose /models. Using /api/version against an
- # OpenAI-compatible endpoint yields a 404 and noisy log output.
+ # Ollama endpoints expose /api/version (liveness) and /api/ps (routing
+ # health — required by `choose_endpoint`). OpenAI-compatible endpoints
+ # (vLLM, llama-server, external) expose /models, which serves both
+ # purposes. Probing /api/version alone would miss the case where the
+ # Ollama process is up but /api/ps is failing — see issue #83.
all_endpoints = list(config.endpoints)
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
all_endpoints += llama_eps_extra
- tasks = []
- for ep in all_endpoints:
- if is_openai_compatible(ep):
- tasks.append(fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True))
- else:
- tasks.append(fetch.endpoint_details(ep, "/api/version", "version", skip_error_cache=True))
+ probe_results = await asyncio.gather(
+ *(_endpoint_health(ep) for ep in all_endpoints),
+ )
- results = await asyncio.gather(*tasks, return_exceptions=True)
-
- health_summary = {}
- overall_ok = True
-
- for ep, result in zip(all_endpoints, results):
- if isinstance(result, Exception):
- # Endpoint did not respond / returned an error
- health_summary[ep] = {"status": "error", "detail": str(result)}
- overall_ok = False
- else:
- # Successful response – report the reported version (Ollama) or
- # indicate the endpoint is reachable (OpenAI-compatible).
- if is_openai_compatible(ep):
- health_summary[ep] = {"status": "ok"}
- else:
- health_summary[ep] = {"status": "ok", "version": result}
+ health_summary = dict(zip(all_endpoints, probe_results))
+ overall_ok = all(entry.get("status") == "ok" for entry in probe_results)
response_payload = {
"status": "ok" if overall_ok else "error",
diff --git a/static/index.html b/static/index.html
index b29f22b..8c0b16c 100644
--- a/static/index.html
+++ b/static/index.html
@@ -192,6 +192,10 @@
color: #8b0000;
font-weight: bold;
}
+ .status-error[title] {
+ cursor: help;
+ text-decoration: underline dotted;
+ }
.copy-link,
.delete-link,
.show-link,
@@ -736,6 +740,16 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return await resp.json();
}
+ function escapeHtml(value) {
+ if (value === null || value === undefined) return "";
+ return String(value)
+ .replace(/&/g, "&amp;")
+ .replace(/</g, "&lt;")
+ .replace(/>/g, "&gt;")
+ .replace(/"/g, "&quot;")
+ .replace(/'/g, "&#39;");
+ }
+
function toggleDarkMode() {
document.documentElement.classList.toggle("dark-mode");
}
@@ -752,40 +766,24 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
// Build HTML for both endpoints and llama_server_endpoints
let html = "";
- // Add Ollama endpoints
- html += data.endpoints
- .map((e) => {
- const statusClass =
- e.status === "ok"
- ? "status-ok"
- : "status-error";
- const version = e.version || "N/A";
- return `
+ const renderRow = (e) => {
+ const statusClass =
+ e.status === "ok" ? "status-ok" : "status-error";
+ const version = e.version || "N/A";
+ const titleAttr = e.detail
+ ? ` title="${escapeHtml(e.detail)}"`
+ : "";
+ return `
- | ${e.url} |
- ${e.status} |
- ${version} |
+ ${escapeHtml(e.url)} |
+ ${escapeHtml(e.status)} |
+ ${escapeHtml(version)} |
`;
- })
- .join("");
-
- // Add llama-server endpoints
+ };
+
+ html += data.endpoints.map(renderRow).join("");
if (data.llama_server_endpoints && data.llama_server_endpoints.length > 0) {
- html += data.llama_server_endpoints
- .map((e) => {
- const statusClass =
- e.status === "ok"
- ? "status-ok"
- : "status-error";
- const version = e.version || "N/A";
- return `
-
- | ${e.url} |
- ${e.status} |
- ${version} |
-
`;
- })
- .join("");
+ html += data.llama_server_endpoints.map(renderRow).join("");
}
body.innerHTML = html;
diff --git a/test/config_test.yaml b/test/config_test.yaml
new file mode 100644
index 0000000..30f2fa3
--- /dev/null
+++ b/test/config_test.yaml
@@ -0,0 +1,13 @@
+endpoints:
+ - http://192.168.0.51:12434
+
+llama_server_endpoints:
+ - http://192.168.0.51:12434/v1
+
+max_concurrent_connections: 2
+
+api_keys:
+ "http://192.168.0.51:12434": "ollama"
+ "http://192.168.0.51:12434/v1": "llama"
+
+cache_enabled: false
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..c5142da
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,233 @@
+"""
+Test configuration for nomyo-router.
+
+Run from project root:
+ pytest test/ -v
+ pytest test/ -m "not integration" # skip real-server tests
+ pytest test/ -m integration -v # only real-server tests
+
+Environment variables:
+ NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.51:12434)
+ NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.51:12434/v1)
+ NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset)
+ NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset)
+"""
+import asyncio
+import os
+import ssl
+import sys
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import aiohttp
+import httpx
+import pytest
+
+_TEST_DIR = Path(__file__).parent
+# Must be set before importing router so module-level Config.from_yaml + Config field
+# defaults pick these up. db_path is intentionally absent from config_test.yaml so the
+# env-var default wins — keeps tests portable across CI runners (Linux/macOS/Windows).
+os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml"))
+os.environ.setdefault(
+ "NOMYO_ROUTER_DB_PATH",
+ str(Path(tempfile.gettempdir()) / "nomyo_router_test_tokens.db"),
+)
+
+sys.path.insert(0, str(_TEST_DIR.parent))
+
+import router # noqa: E402
+
+TEST_OLLAMA = os.getenv("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434")
+TEST_LLAMA = os.getenv("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1")
+
+
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers",
+ "integration: tests that require a real backend at 192.168.0.50:12434",
+ )
+
+
+# ── Config mocks ─────────────────────────────────────────────────────────────
+
+@pytest.fixture
+def mock_config():
+ """Minimal config pointing at TEST_OLLAMA / TEST_LLAMA."""
+ cfg = MagicMock()
+ cfg.endpoints = [TEST_OLLAMA]
+ cfg.llama_server_endpoints = [TEST_LLAMA]
+ cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"}
+ cfg.max_concurrent_connections = 2
+ cfg.router_api_key = None
+ cfg.cache_enabled = False
+ return cfg
+
+
+@pytest.fixture
+def mock_config_no_llama():
+ """Config with Ollama only, no llama-server."""
+ cfg = MagicMock()
+ cfg.endpoints = [TEST_OLLAMA]
+ cfg.llama_server_endpoints = []
+ cfg.api_keys = {TEST_OLLAMA: "ollama"}
+ cfg.max_concurrent_connections = 2
+ cfg.router_api_key = None
+ cfg.cache_enabled = False
+ return cfg
+
+
+@pytest.fixture
+def mock_config_with_key():
+ """Config with router_api_key set (enables auth middleware)."""
+ cfg = MagicMock()
+ cfg.endpoints = [TEST_OLLAMA]
+ cfg.llama_server_endpoints = []
+ cfg.api_keys = {}
+ cfg.max_concurrent_connections = 2
+ cfg.router_api_key = "test-secret-key"
+ cfg.cache_enabled = False
+ return cfg
+
+
+# ── aiohttp session (used by fetch tests + choose_endpoint tests) ─────────────
+
+@pytest.fixture
+async def aio_session():
+ """Real aiohttp session stored in app_state; intercepted by aioresponses."""
+ ssl_ctx = ssl.create_default_context()
+ conn = aiohttp.TCPConnector(ssl=ssl_ctx)
+ session = aiohttp.ClientSession(connector=conn)
+ router.app_state["session"] = session
+
+ # Clear caches to prevent test bleed
+ router._models_cache.clear()
+ router._loaded_models_cache.clear()
+ router._available_error_cache.clear()
+ router._loaded_error_cache.clear()
+ router._inflight_available_models.clear()
+ router._inflight_loaded_models.clear()
+ router._bg_refresh_available.clear()
+ router._bg_refresh_loaded.clear()
+
+ yield session
+
+ await session.close()
+ router.app_state["session"] = None
+
+
+# ── Validation-only HTTP client (no real backend needed) ──────────────────────
+
+@pytest.fixture
+async def client(mock_config, tmp_path):
+ """httpx client for validation/auth tests — no real backend calls made."""
+ from db import TokenDatabase
+
+ ssl_ctx = ssl.create_default_context()
+ conn = aiohttp.TCPConnector(ssl=ssl_ctx)
+ session = aiohttp.ClientSession(connector=conn)
+
+ db_inst = TokenDatabase(str(tmp_path / "test.db"))
+ await db_inst.init_db()
+
+ old_session = router.app_state.get("session")
+ old_db = router.db
+
+ router.app_state["session"] = session
+ router.db = db_inst
+
+ with patch.object(router, "config", mock_config):
+ transport = httpx.ASGITransport(app=router.app)
+ async with httpx.AsyncClient(
+ transport=transport, base_url="http://test", timeout=10.0
+ ) as c:
+ yield c
+
+ await session.close()
+ router.app_state["session"] = old_session
+ router.db = old_db
+
+
+@pytest.fixture
+async def client_auth(mock_config_with_key, tmp_path):
+ """httpx client with router_api_key configured (for auth middleware tests)."""
+ from db import TokenDatabase
+
+ ssl_ctx = ssl.create_default_context()
+ conn = aiohttp.TCPConnector(ssl=ssl_ctx)
+ session = aiohttp.ClientSession(connector=conn)
+
+ db_inst = TokenDatabase(str(tmp_path / "test_auth.db"))
+ await db_inst.init_db()
+
+ old_session = router.app_state.get("session")
+ old_db = router.db
+
+ router.app_state["session"] = session
+ router.db = db_inst
+
+ with patch.object(router, "config", mock_config_with_key):
+ transport = httpx.ASGITransport(app=router.app)
+ async with httpx.AsyncClient(
+ transport=transport, base_url="http://test", timeout=10.0
+ ) as c:
+ yield c
+
+ await session.close()
+ router.app_state["session"] = old_session
+ router.db = old_db
+
+
+# ── Integration client (full startup with real backend) ──────────────────────
+
+@pytest.fixture(scope="module")
+async def integration_client():
+ """Full app startup pointing at the real test server."""
+ await router.startup_event()
+ transport = httpx.ASGITransport(app=router.app)
+ async with httpx.AsyncClient(
+ transport=transport,
+ base_url="http://test",
+ timeout=httpx.Timeout(60.0),
+ ) as c:
+ yield c
+ await router.shutdown_event()
+
+
+# ── Model discovery fixtures ──────────────────────────────────────────────────
+
+@pytest.fixture(scope="module")
+async def chat_model(integration_client):
+ """Return a chat/generation model name available on the test server."""
+ env_model = os.getenv("NOMYO_TEST_MODEL_CHAT")
+ if env_model:
+ return env_model
+ resp = await integration_client.get("/api/tags")
+ if resp.status_code != 200:
+ pytest.skip("Cannot reach test server")
+ models = resp.json().get("models", [])
+ # Prefer small models for faster tests
+ for m in models:
+ name = m.get("name", "")
+ if any(x in name.lower() for x in ["0.5b", "1b", "3b", "1.5b", "2b"]):
+ return name
+ if models:
+ return models[0]["name"]
+ pytest.skip("No chat models available on test server")
+
+
+@pytest.fixture(scope="module")
+async def embed_model(integration_client):
+ """Return an embedding model name available on the test server."""
+ env_model = os.getenv("NOMYO_TEST_EMBED_MODEL")
+ if env_model:
+ return env_model
+ resp = await integration_client.get("/api/tags")
+ if resp.status_code != 200:
+ pytest.skip("Cannot reach test server")
+ models = resp.json().get("models", [])
+ for m in models:
+ name = m.get("name", "")
+ if any(x in name.lower() for x in ["embed", "nomic", "minilm", "bge", "e5"]):
+ return name
+ pytest.skip("No embedding model available on test server")
diff --git a/test/pytest.ini b/test/pytest.ini
new file mode 100644
index 0000000..1d05e6d
--- /dev/null
+++ b/test/pytest.ini
@@ -0,0 +1,7 @@
+[pytest]
+asyncio_mode = auto
+markers =
+ integration: tests that require a real backend at 192.168.0.51:12434
+testpaths = .
+filterwarnings =
+ ignore::pytest.PytestUnhandledThreadExceptionWarning
diff --git a/test/requirements_test.txt b/test/requirements_test.txt
new file mode 100644
index 0000000..8c7c53f
--- /dev/null
+++ b/test/requirements_test.txt
@@ -0,0 +1,4 @@
+pytest>=8.0
+pytest-asyncio>=0.24
+pytest-cov>=5.0
+aioresponses>=0.7
diff --git a/test/test.md b/test/test.md
new file mode 100644
index 0000000..533a3ee
--- /dev/null
+++ b/test/test.md
@@ -0,0 +1,60 @@
+# Testing nomyo-router
+
+## Setup
+
+Install test dependencies (from the project root):
+
+```bash
+pip install -r test/requirements_test.txt
+```
+
+## Running tests
+
+All commands run from the `test/` directory:
+
+```bash
+cd test
+```
+
+**All non-integration tests** (no backend required):
+```bash
+pytest -m "not integration" -v
+```
+
+**Integration tests only** (requires backend at `192.168.0.51:12434`):
+```bash
+pytest -m integration -v
+```
+
+**Everything:**
+```bash
+pytest -v
+```
+
+## Test structure
+
+| File | What it covers | Backend needed |
+|---|---|---|
+| `test_unit_helpers.py` | Pure helper functions (`_mask_secrets`, `_is_fresh`, `ep2base`, etc.) | No |
+| `test_unit_transforms.py` | Message transform functions (tool calls, image stripping, etc.) | No |
+| `test_unit_context.py` | Context window trimming logic | No |
+| `test_fetch.py` | `fetch.available_models` / `fetch.loaded_models` with mocked HTTP | No |
+| `test_choose_endpoint.py` | `choose_endpoint` routing logic with mocked fetch layer | No |
+| `test_api_validation.py` | HTTP 400/401/403 validation and auth middleware (in-process app) | No |
+| `test_api_integration.py` | Full request/response against a real Ollama/llama-server backend | **Yes** |
+
+## Integration test backend
+
+Integration tests start the router in-process via `startup_event()` and route traffic
+through `httpx.ASGITransport` — no separately running router instance is needed.
+
+They do require a reachable Ollama or llama-server backend. Override the defaults via
+environment variables:
+
+```bash
+export NOMYO_TEST_OLLAMA=http://192.168.0.51:12434
+export NOMYO_TEST_EMBED_MODEL=nomic-embed-text # optional, auto-discovered otherwise
+export NOMYO_TEST_MODEL_CHAT=llama3.2 # optional, auto-discovered otherwise
+```
+
+If the backend is unreachable, integration tests are automatically skipped.
diff --git a/test/test_api_integration.py b/test/test_api_integration.py
new file mode 100644
index 0000000..6c40fdc
--- /dev/null
+++ b/test/test_api_integration.py
@@ -0,0 +1,304 @@
+"""
+Integration tests against the real backend at 192.168.0.51:12434.
+
+Run with (from the project root):
+ pytest test/test_api_integration.py -v -m integration
+
+All tests in this file are marked @pytest.mark.integration.
+They require the test server to be reachable and to have at least one
+chat model and one embedding model available.
+
+Env vars to pin specific models:
+ NOMYO_TEST_MODEL_CHAT e.g. qwen2.5:1.5b
+ NOMYO_TEST_EMBED_MODEL e.g. nomic-embed-text:latest
+"""
+import json
+
+import pytest
+
+
+pytestmark = pytest.mark.integration
+
+
+# ── Health / discovery routes ─────────────────────────────────────────────────
+
+class TestDiscoveryRoutes:
+ async def test_version(self, integration_client):
+ resp = await integration_client.get("/api/version")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "version" in data
+ assert isinstance(data["version"], str)
+
+ async def test_tags_returns_models(self, integration_client):
+ resp = await integration_client.get("/api/tags")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "models" in data
+ assert isinstance(data["models"], list)
+ assert len(data["models"]) > 0
+
+ async def test_ps_returns_list(self, integration_client):
+ resp = await integration_client.get("/api/ps")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "models" in data
+ assert isinstance(data["models"], list)
+
+ async def test_v1_models_returns_data(self, integration_client):
+ resp = await integration_client.get("/v1/models")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "data" in data
+ assert isinstance(data["data"], list)
+
+ async def test_usage_returns_counts(self, integration_client):
+ resp = await integration_client.get("/api/usage")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "usage_counts" in data
+ assert "token_usage_counts" in data
+
+ async def test_config_returns_endpoints(self, integration_client):
+ resp = await integration_client.get("/api/config")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "endpoints" in data
+
+ async def test_hostname(self, integration_client):
+ resp = await integration_client.get("/api/hostname")
+ assert resp.status_code == 200
+ assert "hostname" in resp.json()
+
+ async def test_health(self, integration_client):
+ resp = await integration_client.get("/health")
+ assert resp.status_code in (200, 503)
+ data = resp.json()
+ assert data["status"] in ("ok", "error")
+ assert "endpoints" in data
+
+ async def test_cache_stats(self, integration_client):
+ resp = await integration_client.get("/api/cache/stats")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "enabled" in data
+
+
+# ── /api/chat ─────────────────────────────────────────────────────────────────
+
+class TestApiChat:  # Ollama-native /api/chat, exercised through the in-process router
+    async def test_non_streaming(self, integration_client, chat_model):
+        resp = await integration_client.post(
+            "/api/chat",
+            json={
+                "model": chat_model,
+                "stream": False,
+                "messages": [{"role": "user", "content": "Reply with exactly: OK"}],
+                "options": {"num_predict": 10},
+            },
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "message" in data
+        assert "content" in data["message"]
+
+    async def test_streaming_ndjson(self, integration_client, chat_model):
+        resp = await integration_client.post(
+            "/api/chat",
+            json={
+                "model": chat_model,
+                "stream": True,
+                "messages": [{"role": "user", "content": "Say hi"}],
+                "options": {"num_predict": 5},
+            },
+        )
+        assert resp.status_code == 200
+        lines = [raw for raw in resp.text.strip().split("\n") if raw.strip()]  # one JSON object per line
+        assert len(lines) >= 1
+        for line in lines:
+            obj = json.loads(line)
+            assert "model" in obj
+
+    async def test_non_streaming_has_token_counts(self, integration_client, chat_model):
+        resp = await integration_client.post(
+            "/api/chat",
+            json={
+                "model": chat_model,
+                "stream": False,
+                "messages": [{"role": "user", "content": "Count to 3"}],
+                "options": {"num_predict": 20},
+            },
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data.get("done") is True
+        # prompt_eval_count is not required here (a missing key defaults to 0); if reported it must be >= 0
+        assert data.get("prompt_eval_count", 0) >= 0
+
+    async def test_system_message_honoured(self, integration_client, chat_model):
+        resp = await integration_client.post(
+            "/api/chat",
+            json={
+                "model": chat_model,
+                "stream": False,
+                "messages": [
+                    {"role": "system", "content": "You are a helpful assistant. Always reply with exactly: PONG"},
+                    {"role": "user", "content": "PING"},
+                ],
+                "options": {"num_predict": 10},
+            },
+        )
+        assert resp.status_code == 200
+        content = resp.json()["message"]["content"]
+        assert isinstance(content, str)  # model output is not deterministic, so only its shape is checked
+        assert len(content) > 0
+
+
+# ── /api/generate ─────────────────────────────────────────────────────────────
+
+class TestApiGenerate:  # Ollama-native /api/generate, exercised through the in-process router
+    async def test_non_streaming(self, integration_client, chat_model):
+        resp = await integration_client.post(
+            "/api/generate",
+            json={
+                "model": chat_model,
+                "prompt": "Complete: The sky is",
+                "stream": False,
+                "options": {"num_predict": 5},
+            },
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "response" in data
+
+    async def test_streaming(self, integration_client, chat_model):
+        resp = await integration_client.post(
+            "/api/generate",
+            json={
+                "model": chat_model,
+                "prompt": "One plus one equals",
+                "stream": True,
+                "options": {"num_predict": 5},
+            },
+        )
+        assert resp.status_code == 200
+        lines = [raw for raw in resp.text.strip().split("\n") if raw.strip()]  # drop blank lines
+        assert len(lines) >= 1
+
+
+# ── /api/embed ────────────────────────────────────────────────────────────────
+
+class TestApiEmbed:
+ async def test_embed_single_string(self, integration_client, embed_model):
+ resp = await integration_client.post(
+ "/api/embed",
+ json={"model": embed_model, "input": "The quick brown fox"},
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "embeddings" in data
+ assert isinstance(data["embeddings"], list)
+ assert len(data["embeddings"]) == 1
+ assert len(data["embeddings"][0]) > 0
+
+ async def test_embed_multiple_inputs(self, integration_client, embed_model):
+ resp = await integration_client.post(
+ "/api/embed",
+ json={"model": embed_model, "input": ["sentence one", "sentence two"]},
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "embeddings" in data
+ assert len(data["embeddings"]) == 2
+
+
+# ── /v1/chat/completions ──────────────────────────────────────────────────────
+
+class TestOpenAIChatCompletions:  # OpenAI-compatible /v1/chat/completions
+    async def test_non_streaming(self, integration_client, chat_model):
+        model = chat_model.replace(":latest", "")  # no-op when the tag is absent
+        resp = await integration_client.post(
+            "/v1/chat/completions",
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": "Reply OK"}],
+                "max_tokens": 10,
+                "stream": False,
+            },
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "choices" in data
+        assert len(data["choices"]) > 0
+        assert "message" in data["choices"][0]
+
+    async def test_streaming_sse(self, integration_client, chat_model):
+        model = chat_model.replace(":latest", "")  # no-op when the tag is absent
+        resp = await integration_client.post(
+            "/v1/chat/completions",
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": "Hi"}],
+                "max_tokens": 5,
+                "stream": True,
+            },
+        )
+        assert resp.status_code == 200
+        # Response should be SSE format
+        assert "data:" in resp.text or "[DONE]" in resp.text
+
+    async def test_non_streaming_has_usage(self, integration_client, chat_model):
+        model = chat_model.replace(":latest", "")  # no-op when the tag is absent
+        resp = await integration_client.post(
+            "/v1/chat/completions",
+            json={
+                "model": model,
+                "messages": [{"role": "user", "content": "Say yes"}],
+                "max_tokens": 5,
+                "stream": False,
+            },
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        if data.get("usage"):  # usage is optional; only validated when present and non-empty
+            assert data["usage"].get("prompt_tokens", 0) >= 0
+
+
+# ── /v1/embeddings ────────────────────────────────────────────────────────────
+
+class TestOpenAIEmbeddings:  # OpenAI-compatible /v1/embeddings
+    async def test_single_input(self, integration_client, embed_model):
+        model = embed_model.replace(":latest", "")  # no-op when the tag is absent
+        resp = await integration_client.post(
+            "/v1/embeddings",
+            json={"model": model, "input": "Test sentence"},
+        )
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "data" in data
+        assert len(data["data"]) > 0
+        embedding = data["data"][0].get("embedding")
+        assert isinstance(embedding, list)
+        assert len(embedding) > 0
+
+
+# ── Token counts (database-backed) ───────────────────────────────────────────
+
+class TestTokenCounts:
+ async def test_token_counts_endpoint(self, integration_client):
+ resp = await integration_client.get("/api/token_counts")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "total_tokens" in data
+ assert "breakdown" in data
+
+
+# ── ps_details (extended ps) ─────────────────────────────────────────────────
+
+class TestPsDetails:
+ async def test_ps_details_returns_models(self, integration_client):
+ resp = await integration_client.get("/api/ps_details")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "models" in data
+ assert isinstance(data["models"], list)
diff --git a/test/test_api_validation.py b/test/test_api_validation.py
new file mode 100644
index 0000000..5d2b52d
--- /dev/null
+++ b/test/test_api_validation.py
@@ -0,0 +1,230 @@
+"""
+HTTP-level validation and auth middleware tests.
+
+These tests use an in-process httpx client and never reach a real backend:
+all requests are rejected at the validation or auth layer before any
+endpoint-selection or upstream HTTP calls occur.
+"""
+import pytest
+
+
+class TestChatValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post(
+ "/api/chat",
+ json={"messages": [{"role": "user", "content": "hello"}]},
+ )
+ assert resp.status_code == 400
+ assert "model" in resp.json()["detail"].lower()
+
+ async def test_missing_messages_returns_400(self, client):
+ resp = await client.post("/api/chat", json={"model": "llama3.2"})
+ assert resp.status_code == 400
+
+ async def test_invalid_json_returns_400(self, client):
+ resp = await client.post(
+ "/api/chat",
+ content=b"not-json",
+ headers={"Content-Type": "application/json"},
+ )
+ assert resp.status_code == 400
+
+ async def test_messages_not_list_returns_400(self, client):
+ resp = await client.post(
+ "/api/chat",
+ json={"model": "m", "messages": "not-a-list"},
+ )
+ assert resp.status_code == 400
+
+ async def test_options_not_dict_returns_400(self, client):
+ resp = await client.post(
+ "/api/chat",
+ json={"model": "m", "messages": [{"role": "user", "content": "hi"}], "options": "bad"},
+ )
+ assert resp.status_code == 400
+
+
+class TestGenerateValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post("/api/generate", json={"prompt": "hello"})
+ assert resp.status_code == 400
+ assert "model" in resp.json()["detail"].lower()
+
+ async def test_missing_prompt_returns_400(self, client):
+ resp = await client.post("/api/generate", json={"model": "m"})
+ assert resp.status_code == 400
+ assert "prompt" in resp.json()["detail"].lower()
+
+ async def test_invalid_json_returns_400(self, client):
+ resp = await client.post(
+ "/api/generate",
+ content=b"{bad-json",
+ headers={"Content-Type": "application/json"},
+ )
+ assert resp.status_code == 400
+
+
+class TestEmbedValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post("/api/embed", json={"input": "hello"})
+ assert resp.status_code == 400
+
+ async def test_missing_input_returns_400(self, client):
+ resp = await client.post("/api/embed", json={"model": "nomic-embed-text"})
+ assert resp.status_code == 400
+
+
+class TestEmbeddingsValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post("/api/embeddings", json={"prompt": "hello"})
+ assert resp.status_code == 400
+
+ async def test_missing_prompt_returns_400(self, client):
+ resp = await client.post("/api/embeddings", json={"model": "nomic-embed-text"})
+ assert resp.status_code == 400
+
+
+class TestOpenAIChatValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post(
+ "/v1/chat/completions",
+ json={"messages": [{"role": "user", "content": "hello"}]},
+ )
+ assert resp.status_code == 400
+
+ async def test_missing_messages_returns_400(self, client):
+ resp = await client.post(
+ "/v1/chat/completions",
+ json={"model": "gpt-4o"},
+ )
+ assert resp.status_code == 400
+
+ async def test_invalid_json_returns_400(self, client):
+ resp = await client.post(
+ "/v1/chat/completions",
+ content=b"}{",
+ headers={"Content-Type": "application/json"},
+ )
+ assert resp.status_code == 400
+
+ async def test_svg_image_rejected(self, client):
+ resp = await client.post(
+ "/v1/chat/completions",
+ json={
+ "model": "vision-model",
+ "messages": [{
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "describe"},
+ {"type": "image_url", "image_url": {"url": "data:image/svg+xml;base64,abc"}},
+ ],
+ }],
+ },
+ )
+ assert resp.status_code == 400
+ assert "svg" in resp.json()["detail"].lower()
+
+
+class TestOpenAICompletionsValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post("/v1/completions", json={"prompt": "hello"})
+ assert resp.status_code == 400
+
+ async def test_missing_prompt_returns_400(self, client):
+ resp = await client.post("/v1/completions", json={"model": "m"})
+ assert resp.status_code == 400
+
+
+class TestRerankValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post(
+ "/v1/rerank",
+ json={"query": "search query", "documents": ["doc1"]},
+ )
+ assert resp.status_code == 400
+
+ async def test_missing_query_returns_400(self, client):
+ resp = await client.post(
+ "/v1/rerank",
+ json={"model": "reranker", "documents": ["doc1"]},
+ )
+ assert resp.status_code == 400
+
+ async def test_empty_documents_returns_400(self, client):
+ resp = await client.post(
+ "/v1/rerank",
+ json={"model": "reranker", "query": "search", "documents": []},
+ )
+ assert resp.status_code == 400
+
+
+class TestShowValidation:
+ async def test_missing_model_returns_400(self, client):
+ resp = await client.post("/api/show", json={})
+ assert resp.status_code == 400
+
+
+class TestCopyValidation:
+ async def test_missing_source_returns_400(self, client):
+ resp = await client.post("/api/copy", json={"destination": "dst"})
+ assert resp.status_code == 400
+
+ async def test_missing_destination_returns_400(self, client):
+ resp = await client.post("/api/copy", json={"source": "src"})
+ assert resp.status_code == 400
+
+
+class TestDeleteValidation:
+    async def test_missing_model_returns_400(self, client):
+        # httpx's .delete() helper takes no request body, hence the generic .request()
+        resp = await client.request(
+            "DELETE",
+            "/api/delete",
+            content=b"{}",
+            headers={"Content-Type": "application/json"},
+        )
+        assert resp.status_code == 400
+
+
+class TestAuthMiddleware:
+ async def test_no_key_returns_401(self, client_auth):
+ resp = await client_auth.post(
+ "/api/chat",
+ json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert resp.status_code == 401
+ assert "Missing" in resp.json()["detail"]
+
+ async def test_invalid_key_returns_403(self, client_auth):
+ resp = await client_auth.post(
+ "/api/chat",
+ headers={"Authorization": "Bearer wrong-key"},
+ json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
+ )
+ assert resp.status_code == 403
+ assert "Invalid" in resp.json()["detail"]
+
+ async def test_valid_key_passes_middleware(self, client_auth):
+ # /api/usage reads in-memory counters only — no backend call needed
+ resp = await client_auth.get(
+ "/api/usage",
+ headers={"Authorization": "Bearer test-secret-key"},
+ )
+ assert resp.status_code == 200
+
+ async def test_key_via_query_param(self, client_auth):
+ resp = await client_auth.get("/api/usage?api_key=test-secret-key")
+ assert resp.status_code == 200
+
+ async def test_options_bypasses_auth(self, client_auth):
+ resp = await client_auth.options("/api/chat")
+ assert resp.status_code not in (401, 403)
+
+ async def test_root_path_bypasses_auth(self, client_auth):
+ resp = await client_auth.get("/")
+ assert resp.status_code not in (401, 403)
+
+ async def test_favicon_bypasses_auth(self, client_auth):
+ resp = await client_auth.get("/favicon.ico")
+ # Should not be blocked by auth (may 404 in test but not 401/403)
+ assert resp.status_code not in (401, 403)
diff --git a/test/test_cache.py b/test/test_cache.py
new file mode 100644
index 0000000..f2ce1a9
--- /dev/null
+++ b/test/test_cache.py
@@ -0,0 +1,333 @@
+"""Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed)."""
+import tempfile
+from pathlib import Path
+from types import SimpleNamespace
+
+import orjson
+import pytest
+
+import cache as cache_mod
+from cache import (
+ LLMCache,
+ _bm25_weighted_text,
+ get_llm_cache,
+ init_llm_cache,
+ openai_nonstream_to_sse,
+)
+
+_CACHE_DB_PATH = str(Path(tempfile.gettempdir()) / "nomyo_test_cache.db")
+
+
+def _exact_cfg(backend: str = "memory") -> SimpleNamespace:
+ """Config for exact-match mode — similarity=1.0 avoids embedding deps."""
+ return SimpleNamespace(
+ cache_enabled=True,
+ cache_backend=backend,
+ cache_similarity=1.0,
+ cache_history_weight=0.3,
+ cache_ttl=300,
+ cache_db_path=_CACHE_DB_PATH,
+ cache_redis_url="redis://localhost:6379",
+ )
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Pure helpers
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestBM25WeightedText:
+ def test_empty_history(self):
+ assert _bm25_weighted_text([]) == ""
+
+ def test_history_without_content(self):
+ assert _bm25_weighted_text([{"role": "user"}, {"role": "assistant"}]) == ""
+
+ def test_repeats_high_idf_terms(self):
+ history = [
+ {"role": "user", "content": "Tell me about quantum entanglement"},
+ {"role": "assistant", "content": "Quantum entanglement is a phenomenon"},
+ {"role": "user", "content": "How does entanglement work?"},
+ ]
+ out = _bm25_weighted_text(history)
+ # Rare/domain term ("entanglement") should appear; short stopwords (<=2 chars) dropped
+ assert "entanglement" in out
+ assert "is" not in out.split()
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# openai_nonstream_to_sse
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestOpenAINonstreamToSSE:
+ def test_valid_chat_completion(self):
+ chat = {
+ "id": "x1",
+ "created": 123,
+ "model": "gpt-4o",
+ "choices": [{"message": {"role": "assistant", "content": "hello"}}],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+ }
+ out = openai_nonstream_to_sse(orjson.dumps(chat), "gpt-4o")
+ text = out.decode()
+ assert text.startswith("data: ")
+ assert text.endswith("data: [DONE]\n\n")
+ # First chunk contains the original content
+ first = text.split("\n\n")[0][len("data: "):]
+ parsed = orjson.loads(first)
+ assert parsed["choices"][0]["delta"]["content"] == "hello"
+ assert parsed["usage"]["total_tokens"] == 3
+
+ def test_corrupt_bytes_return_done_only(self):
+ out = openai_nonstream_to_sse(b"not-json", "m")
+ assert out == b"data: [DONE]\n\n"
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# LLMCache internal helpers
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestLLMCacheParsing:
+ def test_namespace_is_stable_and_isolated(self):
+ c = LLMCache(_exact_cfg())
+ a = c._namespace("chat", "m1", "system A")
+ b = c._namespace("chat", "m1", "system A")
+ assert a == b
+ assert c._namespace("chat", "m1", "system B") != a
+ assert c._namespace("generate", "m1", "system A") != a
+ assert len(a) == 16
+
+ def test_parse_messages_flat_strings(self):
+ c = LLMCache(_exact_cfg())
+ sys, hist, last = c._parse_messages([
+ {"role": "system", "content": "be helpful"},
+ {"role": "user", "content": "hi"},
+ {"role": "assistant", "content": "hello"},
+ {"role": "user", "content": "what is 2+2?"},
+ ])
+ assert sys == "be helpful"
+ assert last == "what is 2+2?"
+ assert hist == [
+ {"role": "user", "content": "hi"},
+ {"role": "assistant", "content": "hello"},
+ ]
+
+ def test_parse_messages_multimodal_content(self):
+ c = LLMCache(_exact_cfg())
+ sys, _hist, last = c._parse_messages([
+ {"role": "system", "content": "sys"},
+ {"role": "user", "content": [
+ {"type": "text", "text": "describe"},
+ {"type": "image_url", "image_url": {"url": "data:..."}},
+ ]},
+ ])
+ assert sys == "sys"
+ assert last == "describe"
+
+ def test_parse_messages_no_user_message(self):
+ c = LLMCache(_exact_cfg())
+ sys, hist, last = c._parse_messages([
+ {"role": "system", "content": "sys only"},
+ ])
+ assert sys == "sys only"
+ assert last == ""
+ assert hist == []
+
+
+class TestPersonalTokenExtraction:
+ def test_email_extracted(self):
+ c = LLMCache(_exact_cfg())
+ toks = c._extract_personal_tokens("Reach me at alice@example.com please")
+ assert "alice@example.com" in toks
+
+ def test_numeric_id_after_keyword(self):
+ c = LLMCache(_exact_cfg())
+ toks = c._extract_personal_tokens("User id: 123456")
+ assert "123456" in toks
+
+ def test_identity_tag_names_extracted(self):
+ c = LLMCache(_exact_cfg())
+ toks = c._extract_personal_tokens(
+ "[Tags: identity] User's name is Andreas Schwibbe"
+ )
+ # Both name tokens should be extracted lowercased; stopwords dropped
+ assert "andreas" in toks
+ assert "schwibbe" in toks
+ assert "name" not in toks # in _IDENTITY_STOPWORDS
+ assert "user" not in toks
+
+ def test_empty_system_returns_empty_set(self):
+ c = LLMCache(_exact_cfg())
+ assert c._extract_personal_tokens("") == frozenset()
+
+
+class TestResponseIsPersonalized:
+ def _resp(self, content: str) -> bytes:
+ return orjson.dumps({"choices": [{"message": {"content": content}}]})
+
+ def test_email_in_response_is_personalized(self):
+ c = LLMCache(_exact_cfg())
+ assert c._response_is_personalized(self._resp("contact bob@x.com"), "")
+
+ def test_uuid_in_response_is_personalized(self):
+ c = LLMCache(_exact_cfg())
+ uuid = "550e8400-e29b-41d4-a716-446655440000"
+ assert c._response_is_personalized(self._resp(f"id={uuid}"), "")
+
+ def test_long_numeric_id_in_response_is_personalized(self):
+ c = LLMCache(_exact_cfg())
+ assert c._response_is_personalized(self._resp("account 12345678"), "")
+
+ def test_identity_token_from_system_echoed_in_response(self):
+ c = LLMCache(_exact_cfg())
+ system = "[Tags: identity] Andreas works here"
+ assert c._response_is_personalized(
+ self._resp("Yes, Andreas is logged in"), system
+ )
+
+ def test_generic_response_not_personalized(self):
+ c = LLMCache(_exact_cfg())
+ assert not c._response_is_personalized(
+ self._resp("The capital of France is Paris."), "be helpful"
+ )
+
+ def test_ollama_message_format_parsed(self):
+ c = LLMCache(_exact_cfg())
+ body = orjson.dumps({"message": {"content": "alice@example.com"}})
+ assert c._response_is_personalized(body, "")
+
+ def test_unparseable_body_with_bytes_is_conservative(self):
+ c = LLMCache(_exact_cfg())
+ # Can't parse → returns True (err on the side of privacy)
+ assert c._response_is_personalized(b"binary-junk", "")
+
+ def test_empty_response_not_personalized(self):
+ c = LLMCache(_exact_cfg())
+ assert not c._response_is_personalized(b"", "anything")
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# End-to-end exact-match cache with the memory backend
+# ──────────────────────────────────────────────────────────────────────────────
+
+@pytest.fixture
+async def memcache():
+ """LLMCache wired up with the in-memory backend (no external deps)."""
+ c = LLMCache(_exact_cfg("memory"))
+ await c.init()
+ return c
+
+
+class TestExactMatchCache:
+ async def test_miss_then_set_then_hit(self, memcache):
+ msgs = [
+ {"role": "system", "content": "be helpful"},
+ {"role": "user", "content": "what is 2+2?"},
+ ]
+ resp = orjson.dumps({"choices": [{"message": {"content": "4"}}]})
+
+ assert await memcache.get_chat("chat", "m1", msgs) is None
+ await memcache.set_chat("chat", "m1", msgs, resp)
+ hit = await memcache.get_chat("chat", "m1", msgs)
+ assert hit == resp
+
+ async def test_namespace_isolation_by_system(self, memcache):
+ resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
+ msgs_a = [
+ {"role": "system", "content": "system A"},
+ {"role": "user", "content": "same question"},
+ ]
+ msgs_b = [
+ {"role": "system", "content": "system B"},
+ {"role": "user", "content": "same question"},
+ ]
+ await memcache.set_chat("chat", "m", msgs_a, resp)
+ # Same question + different system prompt = different namespace = miss
+ assert await memcache.get_chat("chat", "m", msgs_b) is None
+
+ async def test_namespace_isolation_by_route(self, memcache):
+ resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
+ msgs = [{"role": "user", "content": "ping"}]
+ await memcache.set_chat("chat", "m", msgs, resp)
+ assert await memcache.get_chat("openai_chat", "m", msgs) is None
+
+ async def test_no_user_message_is_noop(self, memcache):
+ msgs = [{"role": "system", "content": "sys only"}]
+ resp = orjson.dumps({"choices": [{"message": {"content": "x"}}]})
+ # Both get and set should silently no-op
+ assert await memcache.get_chat("chat", "m", msgs) is None
+ await memcache.set_chat("chat", "m", msgs, resp)
+ assert await memcache.get_chat("chat", "m", msgs) is None
+
+ async def test_personalized_response_generic_system_not_stored(self, memcache):
+ msgs = [
+ {"role": "system", "content": "be helpful"}, # generic
+ {"role": "user", "content": "give me an email"},
+ ]
+ # Response contains an email → would leak across users sharing the
+ # generic namespace → must NOT be stored at all
+ resp = orjson.dumps({"choices": [{"message": {"content": "bob@x.com"}}]})
+ await memcache.set_chat("chat", "m", msgs, resp)
+ assert await memcache.get_chat("chat", "m", msgs) is None
+
+ async def test_personalized_response_user_specific_system_stored(self, memcache):
+ msgs = [
+ {"role": "system", "content": "User id: 998877 prefers concise answers"},
+ {"role": "user", "content": "what is my id?"},
+ ]
+ resp = orjson.dumps({"choices": [{"message": {"content": "Your id is 998877"}}]})
+ await memcache.set_chat("chat", "m", msgs, resp)
+ # User-specific namespace → exact-match within this user is OK
+ assert await memcache.get_chat("chat", "m", msgs) == resp
+
+ async def test_generate_convenience_wrappers(self, memcache):
+ resp = orjson.dumps({"response": "blue"})
+ await memcache.set_generate("m", "what color is the sky?", "", resp)
+ assert await memcache.get_generate("m", "what color is the sky?") == resp
+
+
+class TestStatsAndClear:
+ async def test_stats_tracks_hits_and_misses(self, memcache):
+ msgs = [{"role": "user", "content": "hello"}]
+ await memcache.get_chat("chat", "m", msgs) # miss
+ resp = orjson.dumps({"choices": [{"message": {"content": "hi"}}]})
+ await memcache.set_chat("chat", "m", msgs, resp)
+ await memcache.get_chat("chat", "m", msgs) # hit
+ s = memcache.stats()
+ assert s["hits"] == 1
+ assert s["misses"] == 1
+ assert s["hit_rate"] == 0.5
+ assert s["semantic"] is False
+ assert s["backend"] == "memory"
+
+ async def test_clear_resets_counters_and_storage(self, memcache):
+ msgs = [{"role": "user", "content": "hi"}]
+ resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
+ await memcache.set_chat("chat", "m", msgs, resp)
+ await memcache.get_chat("chat", "m", msgs)
+ await memcache.clear()
+ s = memcache.stats()
+ assert s["hits"] == 0
+ assert s["misses"] == 0
+ assert await memcache.get_chat("chat", "m", msgs) is None
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Module-level helpers
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestInitLLMCache:
+ async def test_disabled_returns_none(self):
+ cfg = _exact_cfg()
+ cfg.cache_enabled = False
+ result = await init_llm_cache(cfg)
+ assert result is None
+
+ async def test_enabled_returns_initialized_cache(self):
+ cfg = _exact_cfg()
+ try:
+ result = await init_llm_cache(cfg)
+ assert result is not None
+ assert get_llm_cache() is result
+ finally:
+ # Reset singleton between tests
+ cache_mod._cache = None
diff --git a/test/test_choose_endpoint.py b/test/test_choose_endpoint.py
new file mode 100644
index 0000000..ece609a
--- /dev/null
+++ b/test/test_choose_endpoint.py
@@ -0,0 +1,399 @@
+"""Tests for choose_endpoint routing logic with mocked fetch calls."""
+import time
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import router
+
+EP1 = "http://ep1:11434"
+EP2 = "http://ep2:11434"
+EP3 = "http://ep3:11434"
+LLAMA_EP = "http://llama:8080/v1"
+
+
+def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
+    """Return a MagicMock standing in for the router config, with only these attributes set."""
+    return MagicMock(
+        endpoints=endpoints,
+        llama_server_endpoints=llama_eps or [],
+        api_keys={},
+        max_concurrent_connections=max_conn,
+        endpoint_config=endpoint_config or {},
+        priority_routing=priority_routing,
+        router_api_key=None)
+
+
+@pytest.fixture(autouse=True)
+def reset_usage():
+ """Clear usage_counts and error caches between tests to prevent bleed."""
+ router.usage_counts.clear()
+ router._loaded_error_cache.clear()
+ yield
+ router.usage_counts.clear()
+ router._loaded_error_cache.clear()
+
+
+class TestChooseEndpointBasic:
+ async def test_selects_single_candidate(self):
+ cfg = _make_cfg([EP1])
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(return_value={"llama3.2:latest"})),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
+ ):
+ ep, tracking = await router.choose_endpoint("llama3.2:latest")
+ assert ep == EP1
+ assert tracking == "llama3.2:latest"
+
+ async def test_raises_when_no_endpoint_has_model(self):
+ cfg = _make_cfg([EP1, EP2])
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(return_value=set())),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
+ ):
+ with pytest.raises(RuntimeError, match="advertise the model"):
+ await router.choose_endpoint("unknown-model:latest")
+
+ async def test_prefers_loaded_endpoint(self):
+ cfg = _make_cfg([EP1, EP2])
+ async def available(ep, *_):
+ return {"llama3.2:latest"}
+
+ async def loaded(ep):
+ return {"llama3.2:latest"} if ep == EP2 else set()
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", side_effect=available),
+ patch.object(router.fetch, "loaded_models", side_effect=loaded),
+ ):
+ ep, _ = await router.choose_endpoint("llama3.2:latest")
+ assert ep == EP2
+
+ async def test_falls_back_to_free_slot(self):
+ cfg = _make_cfg([EP1, EP2])
+ async def available(ep, *_):
+ return {"llama3.2:latest"}
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", side_effect=available),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
+ ):
+ ep, _ = await router.choose_endpoint("llama3.2:latest")
+ assert ep in (EP1, EP2)
+
+ async def test_saturated_picks_least_busy(self):
+ cfg = _make_cfg([EP1, EP2])
+ cfg.max_concurrent_connections = 1
+
+ async def available(ep, *_):
+ return {"llama3.2:latest"}
+
+ # Saturate EP1 with 2 active connections, EP2 with 1
+ router.usage_counts[EP1]["llama3.2:latest"] = 2
+ router.usage_counts[EP2]["llama3.2:latest"] = 1
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", side_effect=available),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
+ ):
+ ep, _ = await router.choose_endpoint("llama3.2:latest")
+ # Least-busy is EP2
+ assert ep == EP2
+
+ async def test_excludes_endpoint_with_recent_loaded_error(self):
+ # Regression: issue #83 — when /api/ps fails for EP1 but EP1
+ # still advertises the model via /api/tags, routing must not
+ # fall back to EP1 just because it has a free slot.
+ cfg = _make_cfg([EP1, EP2])
+
+ async def available(ep, *_):
+ return {"llama3.2:latest"}
+
+ # EP1's /api/ps probe failed recently; EP2 is fine but the model
+ # is not loaded there. Without the health filter, EP1 would be
+ # picked by the free-slot fallback (step 4 in choose_endpoint).
+ router._loaded_error_cache[EP1] = time.time()
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", side_effect=available),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
+ ):
+ ep, _ = await router.choose_endpoint("llama3.2:latest")
+ assert ep == EP2
+
+ async def test_stale_loaded_error_does_not_exclude(self):
+ # Errors older than the 300s window must not keep an endpoint
+ # excluded forever.
+ cfg = _make_cfg([EP1])
+ router._loaded_error_cache[EP1] = time.time() - 301
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"m:latest"})),
+ ):
+ ep, _ = await router.choose_endpoint("m:latest")
+ assert ep == EP1
+
+ async def test_all_unhealthy_still_routes(self):
+ # If every candidate has a fresh loaded-error we still try one
+ # (it may have recovered between the cache write and now) rather
+ # than refusing to route.
+ cfg = _make_cfg([EP1])
+ router._loaded_error_cache[EP1] = time.time()
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
+ ):
+ ep, _ = await router.choose_endpoint("m:latest")
+ assert ep == EP1
+
+ async def test_reserve_increments_usage(self):
+ cfg = _make_cfg([EP1])
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(return_value={"model:latest"})),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"model:latest"})),
+ ):
+ ep, tracking = await router.choose_endpoint("model:latest", reserve=True)
+ assert router.usage_counts[ep][tracking] == 1
+
+
+class TestChooseEndpointModelNaming:
+    async def test_strips_latest_for_openai_endpoints(self):
+        # Only a llama-server endpoint is configured; the Ollama endpoint list stays empty.
+        cfg = _make_cfg(endpoints=[], llama_eps=[LLAMA_EP])
+
+        async def available(ep, *_):
+            # llama-server advertises without :latest
+            return {"gpt-4o"}
+
+        with (
+            patch.object(router, "config", cfg),
+            patch.object(router.fetch, "available_models", side_effect=available),
+            patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"gpt-4o"})),
+        ):
+            ep, _ = await router.choose_endpoint("gpt-4o:latest")
+        assert ep == LLAMA_EP
+
+ async def test_adds_latest_for_ollama_when_bare_name(self):
+ cfg = _make_cfg([EP1])
+
+ async def available(ep, *_):
+ return {"llama3.2:latest"}
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", side_effect=available),
+ patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
+ ):
+ ep, _ = await router.choose_endpoint("llama3.2")
+ assert ep == EP1
+
+
+class TestChooseEndpointLoadBalancing:
+    async def test_random_selection_among_idle(self):
+        cfg = _make_cfg([EP1, EP2, EP3])
+        selected = set()
+
+        async def available(ep, *_):
+            return {"model:latest"}
+
+        async def loaded(ep):
+            return {"model:latest"}
+
+        for _ in range(20):
+            router.usage_counts.clear()
+            with (
+                patch.object(router, "config", cfg),
+                patch.object(router.fetch, "available_models", side_effect=available),
+                patch.object(router.fetch, "loaded_models", side_effect=loaded),
+            ):
+                ep, _ = await router.choose_endpoint("model:latest", reserve=False)
+            selected.add(ep)
+
+        # With 20 draws from 3 idle endpoints, at least two distinct endpoints should appear
+        assert len(selected) > 1
+
+ async def test_sort_by_load_ascending(self):
+ cfg = _make_cfg([EP1, EP2])
+ router.usage_counts[EP1]["model:latest"] = 1
+ router.usage_counts[EP2]["model:latest"] = 0
+
+ async def available(ep, *_):
+ return {"model:latest"}
+
+ async def loaded(ep):
+ return {"model:latest"}
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", side_effect=available),
+ patch.object(router.fetch, "loaded_models", side_effect=loaded),
+ ):
+ ep, _ = await router.choose_endpoint("model:latest", reserve=False)
+ # EP2 has fewer active connections → should be selected
+ assert ep == EP2
+
+
+# ---------------------------------------------------------------------------
+# get_max_connections unit tests
+# ---------------------------------------------------------------------------
+
+class TestGetMaxConnections:
+ def test_returns_global_default_when_no_override(self):
+ cfg = _make_cfg([EP1, EP2], max_conn=3)
+ with patch.object(router, "config", cfg):
+ assert router.get_max_connections(EP1) == 3
+ assert router.get_max_connections(EP2) == 3
+
+ def test_returns_per_endpoint_override(self):
+ cfg = _make_cfg(
+ [EP1, EP2],
+ max_conn=2,
+ endpoint_config={EP1: {"max_concurrent_connections": 5}},
+ )
+ with patch.object(router, "config", cfg):
+ assert router.get_max_connections(EP1) == 5
+ assert router.get_max_connections(EP2) == 2 # falls back to global
+
+ def test_unrecognised_endpoint_falls_back_to_global(self):
+ cfg = _make_cfg([EP1], max_conn=4, endpoint_config={EP2: {"max_concurrent_connections": 1}})
+ with patch.object(router, "config", cfg):
+ assert router.get_max_connections(EP3) == 4
+
+
+# ---------------------------------------------------------------------------
+# Priority / WRR routing tests
+# ---------------------------------------------------------------------------
+
+MODEL = "model:latest"
+
+
+def _all_loaded(ep, *_args, **_kwargs):
+    """Side-effect: every endpoint advertises and has MODEL loaded (extra call args are ignored)."""
+    return {MODEL}
+
+
+class TestPriorityRouting:
+ """Tests for priority_routing=True (WRR + config-order tiebreaking)."""
+
+ async def test_idle_picks_first_in_config_order(self):
+ """When all endpoints are idle, priority picks the first listed endpoint."""
+ cfg = _make_cfg([EP1, EP2, EP3], priority_routing=True)
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
+ patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
+ ):
+ ep, _ = await router.choose_endpoint(MODEL, reserve=False)
+ assert ep == EP1
+
+ async def test_lower_utilization_preferred_over_priority(self):
+ """An endpoint with lower ratio is preferred even if it has lower priority."""
+ cfg = _make_cfg([EP1, EP2], priority_routing=True)
+ # EP1 (priority 0) is busier: 1/2 = 0.5; EP2 (priority 1) is idle: 0/2 = 0.0
+ router.usage_counts[EP1][MODEL] = 1
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
+ patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
+ ):
+ ep, _ = await router.choose_endpoint(MODEL, reserve=False)
+ assert ep == EP2
+
+ async def test_wrr_distribution_matches_expected_sequence(self):
+ """
+ Full WRR sequence with heterogeneous capacities, mirroring the issue example:
+ EP1 max=2, EP2 max=2, EP3 max=1
+
+ Expected routing order for 5 sequential requests:
+ EP1, EP2, EP3, EP1, EP2
+ """
+ cfg = _make_cfg(
+ [EP1, EP2, EP3],
+ max_conn=2,
+ endpoint_config={EP3: {"max_concurrent_connections": 1}},
+ priority_routing=True,
+ )
+
+ expected = [EP1, EP2, EP3, EP1, EP2]
+ actual = []
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
+ patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
+ ):
+ for _ in expected:
+ ep, _ = await router.choose_endpoint(MODEL, reserve=True)
+ actual.append(ep)
+
+ assert actual == expected
+
+ async def test_saturated_picks_lowest_ratio_then_priority(self):
+ """When all endpoints are saturated, pick lowest utilization ratio; break ties by priority."""
+ cfg = _make_cfg(
+ [EP1, EP2, EP3],
+ max_conn=1,
+ endpoint_config={EP3: {"max_concurrent_connections": 2}},
+ priority_routing=True,
+ )
+ # EP1 usage=1/1=1.0, EP2 usage=1/1=1.0, EP3 usage=1/2=0.5 → EP3 wins
+ router.usage_counts[EP1][MODEL] = 1
+ router.usage_counts[EP2][MODEL] = 1
+ router.usage_counts[EP3][MODEL] = 1
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
+ patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
+ ):
+ ep, _ = await router.choose_endpoint(MODEL, reserve=False)
+ assert ep == EP3
+
+ async def test_saturated_ties_broken_by_priority(self):
+ """When all are saturated with equal ratio, config order wins."""
+ cfg = _make_cfg([EP1, EP2, EP3], max_conn=1, priority_routing=True)
+ router.usage_counts[EP1][MODEL] = 1
+ router.usage_counts[EP2][MODEL] = 1
+ router.usage_counts[EP3][MODEL] = 1
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
+ patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
+ ):
+ ep, _ = await router.choose_endpoint(MODEL, reserve=False)
+ assert ep == EP1
+
+
+class TestPriorityRoutingDisabled:
+ """Verify that priority_routing=False keeps the original random behaviour."""
+
+ async def test_idle_endpoints_are_randomised(self):
+ """Without priority routing, all-idle selection must eventually pick each endpoint."""
+ cfg = _make_cfg([EP1, EP2, EP3], priority_routing=False)
+ selected = set()
+
+ with (
+ patch.object(router, "config", cfg),
+ patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
+ patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
+ ):
+ for _ in range(30):
+ router.usage_counts.clear()
+ ep, _ = await router.choose_endpoint(MODEL, reserve=False)
+ selected.add(ep)
+
+ # With 30 draws from 3 equally-idle endpoints, all three must appear
+ assert selected == {EP1, EP2, EP3}
diff --git a/test/test_db.py b/test/test_db.py
new file mode 100644
index 0000000..833b375
--- /dev/null
+++ b/test/test_db.py
@@ -0,0 +1,197 @@
+"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
+from datetime import datetime, timezone
+
+import pytest
+
+from db import TokenDatabase
+
+
+@pytest.fixture
+async def db(tmp_path):
+ inst = TokenDatabase(str(tmp_path / "tokens.db"))
+ await inst.init_db()
+ yield inst
+ await inst.close()
+
+
+class TestInit:
+ async def test_init_creates_tables(self, db):
+ # Re-init must be idempotent
+ await db.init_db()
+ # Insert + read confirms tables exist
+ await db.update_token_counts("http://ep", "m", 1, 2)
+ rows = [r async for r in db.load_token_counts()]
+ assert len(rows) == 1
+
+ async def test_creates_parent_directory(self, tmp_path):
+ nested = tmp_path / "nested" / "subdir" / "x.db"
+ inst = TokenDatabase(str(nested))
+ await inst.init_db()
+ try:
+ assert nested.parent.exists()
+ finally:
+ await inst.close()
+
+
+class TestUpdateTokenCounts:
+ async def test_insert_then_update_aggregates(self, db):
+ await db.update_token_counts("http://ep", "m1", 10, 20)
+ await db.update_token_counts("http://ep", "m1", 5, 7)
+ rows = [r async for r in db.load_token_counts()]
+ assert len(rows) == 1
+ r = rows[0]
+ assert r["endpoint"] == "http://ep"
+ assert r["model"] == "m1"
+ assert r["input_tokens"] == 15
+ assert r["output_tokens"] == 27
+ assert r["total_tokens"] == 42
+
+ async def test_independent_endpoint_model_pairs(self, db):
+ await db.update_token_counts("http://ep1", "m1", 1, 1)
+ await db.update_token_counts("http://ep1", "m2", 2, 2)
+ await db.update_token_counts("http://ep2", "m1", 3, 3)
+ rows = [r async for r in db.load_token_counts()]
+ assert len(rows) == 3
+ totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
+ assert totals == {
+ ("http://ep1", "m1"): 2,
+ ("http://ep1", "m2"): 4,
+ ("http://ep2", "m1"): 6,
+ }
+
+
+class TestBatchedCounts:
+ async def test_update_batched_counts(self, db):
+ counts = {
+ "http://a": {"m": (4, 6)},
+ "http://b": {"m": (1, 1), "n": (10, 0)},
+ }
+ await db.update_batched_counts(counts)
+ rows = [r async for r in db.load_token_counts()]
+ totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
+ assert totals == {
+ ("http://a", "m"): 10,
+ ("http://b", "m"): 2,
+ ("http://b", "n"): 10,
+ }
+
+ async def test_empty_batch_is_noop(self, db):
+ await db.update_batched_counts({})
+ rows = [r async for r in db.load_token_counts()]
+ assert rows == []
+
+
+class TestTimeSeries:
+ async def test_add_time_series_entry(self, db):
+ # The aggregate FK requires the (endpoint,model) row to exist first
+ await db.update_token_counts("http://ep", "m", 0, 0)
+ await db.add_time_series_entry("http://ep", "m", 3, 4)
+ await db.add_time_series_entry("http://ep", "m", 1, 1)
+ rows = [r async for r in db.get_latest_time_series(limit=10)]
+ assert len(rows) == 2
+ # Newest-first ordering; both timestamps are within the same minute,
+ # so just check totals are present and well-formed
+ for r in rows:
+ assert r["endpoint"] == "http://ep"
+ assert r["model"] == "m"
+ assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"]
+
+ async def test_add_batched_time_series(self, db):
+ await db.update_token_counts("http://ep", "m", 0, 0)
+ now = int(datetime.now(tz=timezone.utc).timestamp())
+ entries = [
+ {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
+ "output_tokens": 2, "total_tokens": 3, "timestamp": now - 60},
+ {"endpoint": "http://ep", "model": "m", "input_tokens": 4,
+ "output_tokens": 5, "total_tokens": 9, "timestamp": now},
+ ]
+ await db.add_batched_time_series(entries)
+ rows = [r async for r in db.get_latest_time_series(limit=10)]
+ assert len(rows) == 2
+ assert rows[0]["timestamp"] >= rows[1]["timestamp"]
+
+ async def test_get_time_series_for_model_filters(self, db):
+ await db.update_token_counts("http://ep", "m1", 0, 0)
+ await db.update_token_counts("http://ep", "m2", 0, 0)
+ now = int(datetime.now(tz=timezone.utc).timestamp())
+ await db.add_batched_time_series([
+ {"endpoint": "http://ep", "model": "m1", "input_tokens": 1,
+ "output_tokens": 1, "total_tokens": 2, "timestamp": now},
+ {"endpoint": "http://ep", "model": "m2", "input_tokens": 9,
+ "output_tokens": 9, "total_tokens": 18, "timestamp": now},
+ ])
+ rows = [r async for r in db.get_time_series_for_model("m1")]
+ assert len(rows) == 1
+ assert rows[0]["total_tokens"] == 2
+
+ async def test_endpoint_distribution_for_model(self, db):
+ await db.update_token_counts("http://a", "m", 0, 0)
+ await db.update_token_counts("http://b", "m", 0, 0)
+ now = int(datetime.now(tz=timezone.utc).timestamp())
+ await db.add_batched_time_series([
+ {"endpoint": "http://a", "model": "m", "input_tokens": 1,
+ "output_tokens": 1, "total_tokens": 2, "timestamp": now},
+ {"endpoint": "http://a", "model": "m", "input_tokens": 1,
+ "output_tokens": 1, "total_tokens": 2, "timestamp": now},
+ {"endpoint": "http://b", "model": "m", "input_tokens": 5,
+ "output_tokens": 5, "total_tokens": 10, "timestamp": now},
+ ])
+ dist = await db.get_endpoint_distribution_for_model("m")
+ assert dist == {"http://a": 4, "http://b": 10}
+
+
+class TestGetTokenCountsForModel:
+ async def test_aggregates_across_endpoints(self, db):
+ await db.update_token_counts("http://a", "m", 1, 2)
+ await db.update_token_counts("http://b", "m", 3, 4)
+ result = await db.get_token_counts_for_model("m")
+ assert result is not None
+ assert result["endpoint"] == "aggregated"
+ assert result["model"] == "m"
+ assert result["input_tokens"] == 4
+ assert result["output_tokens"] == 6
+ assert result["total_tokens"] == 10
+
+    async def test_unknown_model_returns_zero_aggregate(self, db):
+        # SUM(...) WHERE no-match returns one row with NULLs — accept either zeros or None here
+        result = await db.get_token_counts_for_model("nope")
+        assert result is not None
+        assert result["input_tokens"] in (0, None)
+
+
+class TestAggregateTimeSeriesOlderThan:
+    async def test_aggregates_old_entries_by_day(self, db):
+        await db.update_token_counts("http://ep", "m", 0, 0)
+        now = int(datetime.now(tz=timezone.utc).timestamp())
+        old = (now // 86400 - 40) * 86400  # UTC day start 40 days back, so old and old + 60 share a day
+        await db.add_batched_time_series([
+            {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
+             "output_tokens": 1, "total_tokens": 2, "timestamp": old},
+            {"endpoint": "http://ep", "model": "m", "input_tokens": 3,
+             "output_tokens": 3, "total_tokens": 6, "timestamp": old + 60},
+            {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
+             "output_tokens": 99, "total_tokens": 198, "timestamp": now},
+        ])
+        n = await db.aggregate_time_series_older_than(30, trim_old=False)
+        assert n == 1  # one (endpoint, model, day) group rolled up
+
+ async def test_invalid_days_falls_back_to_30(self, db):
+ # Just ensure it doesn't blow up with a bogus value
+ n = await db.aggregate_time_series_older_than(0)
+ assert n == 0
+
+ async def test_trim_old_removes_aggregated_rows(self, db):
+ await db.update_token_counts("http://ep", "m", 0, 0)
+ now = int(datetime.now(tz=timezone.utc).timestamp())
+ old = now - (40 * 86400)
+ await db.add_batched_time_series([
+ {"endpoint": "http://ep", "model": "m", "input_tokens": 1,
+ "output_tokens": 1, "total_tokens": 2, "timestamp": old},
+ {"endpoint": "http://ep", "model": "m", "input_tokens": 99,
+ "output_tokens": 99, "total_tokens": 198, "timestamp": now},
+ ])
+ await db.aggregate_time_series_older_than(30, trim_old=True)
+ remaining = [r async for r in db.get_latest_time_series(limit=10)]
+ # Only the recent (within-cutoff) row should remain
+ assert len(remaining) == 1
+ assert remaining[0]["total_tokens"] == 198
diff --git a/test/test_fetch.py b/test/test_fetch.py
new file mode 100644
index 0000000..6f2ed50
--- /dev/null
+++ b/test/test_fetch.py
@@ -0,0 +1,210 @@
+"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking."""
+import time
+from unittest.mock import patch, MagicMock
+
+import pytest
+from aioresponses import aioresponses
+
+import router
+from conftest import TEST_OLLAMA, TEST_LLAMA
+
+MOCK_OLLAMA_EP = "http://mock-ollama:11434"
+MOCK_LLAMA_EP = "http://mock-llama:8080/v1"
+
+
+def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
+    """Build a mock router config; passing ``None`` selects the default mock endpoint lists."""
+    cfg = MagicMock(max_concurrent_connections=2, router_api_key=None)
+    # Compare against None: an explicit [] must stay empty ([] is falsy, so `or` would undo it).
+    cfg.endpoints = [MOCK_OLLAMA_EP] if ollama_eps is None else ollama_eps
+    cfg.llama_server_endpoints = [MOCK_LLAMA_EP] if llama_eps is None else llama_eps
+    cfg.api_keys = api_keys or {}
+    return cfg
+
+
+@pytest.fixture(autouse=True)
+def clear_caches(aio_session):
+ """aio_session fixture already clears caches and sets up app_state."""
+ yield
+
+
+class TestFetchAvailableModels:
+ async def test_ollama_tags(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/tags",
+ payload={"models": [
+ {"name": "llama3.2:latest"},
+ {"name": "qwen2.5:7b"},
+ ]},
+ )
+ models = await router.fetch.available_models(MOCK_OLLAMA_EP)
+ assert models == {"llama3.2:latest", "qwen2.5:7b"}
+
+ async def test_openai_compatible_models_endpoint(self):
+ cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_LLAMA_EP}/models",
+ payload={"data": [{"id": "unsloth/model:Q8_0"}]},
+ )
+ models = await router.fetch.available_models(MOCK_LLAMA_EP, api_key="tok")
+ assert "unsloth/model:Q8_0" in models
+
+ async def test_caches_successful_result(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/tags",
+ payload={"models": [{"name": "llama3.2:latest"}]},
+ )
+ first = await router.fetch.available_models(MOCK_OLLAMA_EP)
+ second = await router.fetch.available_models(MOCK_OLLAMA_EP)
+ # second call must be served from cache without a second HTTP request
+ assert first == second == {"llama3.2:latest"}
+
+ async def test_returns_empty_on_http_500(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
+ models = await router.fetch.available_models(MOCK_OLLAMA_EP)
+ assert models == set()
+
+ async def test_returns_empty_on_connection_error(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ import aiohttp
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/tags",
+ exception=aiohttp.ClientConnectorError(
+ connection_key=MagicMock(host="mock-ollama", port=11434),
+ os_error=OSError(111, "refused"),
+ ),
+ )
+ models = await router.fetch.available_models(MOCK_OLLAMA_EP)
+ assert models == set()
+
+ async def test_stale_cache_returned_while_refresh_runs(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/tags",
+ payload={"models": [{"name": "llama3.2:latest"}]},
+ )
+ await router.fetch.available_models(MOCK_OLLAMA_EP)
+
+ # Manually age cache into stale-but-valid window (300-600s)
+ async with router._models_cache_lock:
+ models, _ = router._models_cache[MOCK_OLLAMA_EP]
+ router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400)
+
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/tags",
+ payload={"models": [{"name": "llama3.2:latest"}]},
+ )
+ # Should return stale data immediately
+ stale = await router.fetch.available_models(MOCK_OLLAMA_EP)
+ assert "llama3.2:latest" in stale
+
+    async def test_error_cache_short_circuits(self):
+        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+        # Seed error cache with a very recent error
+        async with router._available_error_cache_lock:
+            router._available_error_cache[MOCK_OLLAMA_EP] = time.time()
+
+        with patch.object(router, "config", cfg), aioresponses() as m:
+            models = await router.fetch.available_models(MOCK_OLLAMA_EP)
+        # A swallowed connection error also yields set(), so also check that no request was attempted.
+        assert models == set() and not m.requests
+
+
+class TestFetchLoadedModels:
+ async def test_ollama_ps(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/ps",
+ payload={"models": [{"name": "llama3.2:latest"}]},
+ )
+ models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
+ assert models == {"llama3.2:latest"}
+
+ async def test_llama_server_filters_loaded(self):
+ cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_LLAMA_EP}/models",
+ payload={"data": [
+ {"id": "model-a", "status": {"value": "loaded"}},
+ {"id": "model-b", "status": {"value": "unloaded"}},
+ ]},
+ )
+ models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
+ assert models == {"model-a"}
+
+ async def test_llama_server_no_status_field_always_loaded(self):
+ cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_LLAMA_EP}/models",
+ payload={"data": [{"id": "always-on-model"}]},
+ )
+ models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
+ assert "always-on-model" in models
+
+ async def test_returns_empty_on_error(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
+ models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
+ assert models == set()
+
+ async def test_ext_openai_always_empty(self):
+ ext_ep = "https://api.openai.com/v1"
+ cfg = _make_cfg(ollama_eps=[ext_ep], llama_eps=[])
+ with patch.object(router, "config", cfg):
+ models = await router.fetch.loaded_models(ext_ep)
+ assert models == set()
+
+ async def test_caches_result(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/ps",
+ payload={"models": [{"name": "qwen:7b"}]},
+ )
+ first = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
+ second = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
+ assert first == second
+
+ async def test_records_error_in_loaded_error_cache_on_failure(self):
+ # Regression: issue #83 — /api/ps failures must be recorded so
+ # `choose_endpoint` can exclude unhealthy backends from routing.
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={})
+ await router.fetch.loaded_models(MOCK_OLLAMA_EP)
+ assert MOCK_OLLAMA_EP in router._loaded_error_cache
+
+ async def test_records_error_for_llama_server_on_failure(self):
+ cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP])
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(f"{MOCK_LLAMA_EP}/models", status=502, payload={})
+ await router.fetch.loaded_models(MOCK_LLAMA_EP)
+ assert MOCK_LLAMA_EP in router._loaded_error_cache
+
+ async def test_clears_error_cache_on_subsequent_success(self):
+ cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
+ # Pre-seed an old error so loaded_models() falls through to the
+ # network probe instead of short-circuiting on the error cache.
+ async with router._loaded_error_cache_lock:
+ router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301
+ with patch.object(router, "config", cfg), aioresponses() as m:
+ m.get(
+ f"{MOCK_OLLAMA_EP}/api/ps",
+ payload={"models": [{"name": "qwen:7b"}]},
+ )
+ await router.fetch.loaded_models(MOCK_OLLAMA_EP)
+ assert MOCK_OLLAMA_EP not in router._loaded_error_cache
diff --git a/test/test_openai_proxies.py b/test/test_openai_proxies.py
new file mode 100644
index 0000000..8a56c91
--- /dev/null
+++ b/test/test_openai_proxies.py
@@ -0,0 +1,181 @@
+"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes.
+
+These tests verify that when the LLM cache reports a hit, the route returns
+the cached payload *without* selecting an endpoint or contacting any backend.
+"""
+from unittest.mock import AsyncMock, patch
+
+import orjson
+import pytest
+from fastapi import HTTPException
+
+import router
+
+
+_BYPASS = HTTPException(status_code=599, detail="bypassed")  # raised by the mocked choose_endpoint; tests then expect HTTP 599
+
+
+class _FakeCache:
+    """Test double for cache.LLMCache that only provides get_chat and logs its calls."""
+    def __init__(self, response_bytes: bytes | None):
+        self.calls: list[tuple] = []
+        self._resp = response_bytes
+
+    async def get_chat(self, route, model, messages):
+        self.calls += [(route, model, messages)]
+        return self._resp
+
+
+@pytest.fixture
+def cache_hit_payload():
+    cached_message = {"role": "assistant", "content": "from-cache"}
+    token_usage = {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
+    body = {"id": "cmpl-xyz", "created": 1, "model": "test-model"}
+    body["choices"] = [{"message": cached_message}]
+    body["usage"] = token_usage
+    # Serialized to bytes; the tests hand this to _FakeCache as the cached response.
+    return orjson.dumps(body)
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# /v1/chat/completions
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestOpenAIChatCompletionsCacheHit:
+    async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
+        cache_stub = _FakeCache(cache_hit_payload)
+        unreachable = AsyncMock(side_effect=AssertionError("backend must not be reached"))
+        # Both helpers are bound by name inside router at import time, so
+        # they have to be replaced on the router module itself.
+        with (
+            patch.object(router, "get_llm_cache", return_value=cache_stub),
+            patch.object(router, "choose_endpoint", unreachable),
+        ):
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "test-model",
+                    "messages": [{"role": "user", "content": "ping"}],
+                    "stream": False,
+                    "nomyo": {"cache": True},
+                },
+            )
+        assert response.status_code == 200
+        # Decode the response body and check it is the cached completion
+        cached_json = orjson.loads(response.content)
+        first_choice = cached_json["choices"][0]
+        assert first_choice["message"]["content"] == "from-cache"
+        assert cache_stub.calls and cache_stub.calls[0][0] == "openai_chat"
+
+    async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
+        cache_stub = _FakeCache(cache_hit_payload)
+        unreachable = AsyncMock(side_effect=AssertionError("backend must not be reached"))
+        with (
+            patch.object(router, "get_llm_cache", return_value=cache_stub),
+            patch.object(router, "choose_endpoint", unreachable),
+        ):
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "test-model",
+                    "messages": [{"role": "user", "content": "ping"}],
+                    "stream": True,
+                    "nomyo": {"cache": True},
+                },
+            )
+        assert response.status_code == 200
+        assert response.headers["content-type"].startswith("text/event-stream")
+        sse_text = response.content.decode()
+        # The cached content arrives as a delta in the first SSE frame
+        first_frame = sse_text.partition("\n\n")[0]
+        assert first_frame.startswith("data: ")
+        delta_chunk = orjson.loads(first_frame.removeprefix("data: "))
+        assert delta_chunk["choices"][0]["delta"]["content"] == "from-cache"
+        # ...and the stream is closed with the [DONE] sentinel
+        assert "data: [DONE]" in sse_text
+
+    async def test_cache_disabled_in_payload_bypasses_cache_check(self, client):
+        """When nomyo.cache=False, get_chat is never called even if a cache exists."""
+        cache_stub = _FakeCache(b"")  # a cache object exists, but must never be consulted
+        bypass = AsyncMock(side_effect=_BYPASS)
+        with (
+            patch.object(router, "get_llm_cache", return_value=cache_stub),
+            patch.object(router, "choose_endpoint", bypass),
+        ):
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "m",
+                    "messages": [{"role": "user", "content": "hi"}],
+                    "nomyo": {"cache": False},
+                },
+            )
+        # Reaching endpoint selection proves the cache short-circuit was skipped
+        assert response.status_code == 599
+        assert cache_stub.calls == []
+
+    async def test_no_cache_configured_bypasses_cache_check(self, client):
+        """get_llm_cache() returning None should not break the route."""
+        bypass = AsyncMock(side_effect=_BYPASS)
+        with (
+            patch.object(router, "get_llm_cache", return_value=None),
+            patch.object(router, "choose_endpoint", bypass),
+        ):
+            response = await client.post(
+                "/v1/chat/completions",
+                json={
+                    "model": "m",
+                    "messages": [{"role": "user", "content": "hi"}],
+                    "nomyo": {"cache": True},
+                },
+            )
+        assert response.status_code == 599
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# /v1/completions
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestOpenAICompletionsCacheHit:
+    async def test_nonstream_cache_hit(self, client, cache_hit_payload):
+        cache_stub = _FakeCache(cache_hit_payload)
+        unreachable = AsyncMock(side_effect=AssertionError("backend must not be reached"))
+        with (
+            patch.object(router, "get_llm_cache", return_value=cache_stub),
+            patch.object(router, "choose_endpoint", unreachable),
+        ):
+            response = await client.post(
+                "/v1/completions",
+                json={
+                    "model": "test-model",
+                    "prompt": "Tell me a joke",
+                    "stream": False,
+                    "nomyo": {"cache": True},
+                },
+            )
+        assert response.status_code == 200
+        # Prompt-style lookups use the "openai_completions" cache namespace
+        lookup_route, _, lookup_messages = cache_stub.calls[0]
+        assert lookup_route == "openai_completions"
+        # The prompt is handed to the cache as one user message
+        assert lookup_messages == [{"role": "user", "content": "Tell me a joke"}]
+
+    async def test_stream_cache_hit(self, client, cache_hit_payload):
+        cache_stub = _FakeCache(cache_hit_payload)
+        unreachable = AsyncMock(side_effect=AssertionError("backend must not be reached"))
+        with (
+            patch.object(router, "get_llm_cache", return_value=cache_stub),
+            patch.object(router, "choose_endpoint", unreachable),
+        ):
+            response = await client.post(
+                "/v1/completions",
+                json={
+                    "model": "test-model",
+                    "prompt": "What is 2+2?",
+                    "stream": True,
+                    "nomyo": {"cache": True},
+                },
+            )
+        assert response.status_code == 200
+        assert response.headers["content-type"].startswith("text/event-stream")
+        assert "data: [DONE]" in response.content.decode()
diff --git a/test/test_unit_context.py b/test/test_unit_context.py
new file mode 100644
index 0000000..de2b98a
--- /dev/null
+++ b/test/test_unit_context.py
@@ -0,0 +1,116 @@
+"""Unit tests for context-window trimming logic."""
+import pytest
+import router
+
+
+def _msgs(roles_contents):
+    return [dict(role=role, content=text) for role, text in roles_contents]
+
+
+class TestCountMessageTokens:
+    def test_returns_int(self):
+        token_count = router._count_message_tokens(_msgs([("user", "hello")]))
+        assert isinstance(token_count, int)
+
+    def test_empty_list(self):
+        assert router._count_message_tokens([]) >= 0
+
+    def test_longer_content_more_tokens(self):
+        verbose_count = router._count_message_tokens(_msgs([("user", "a " * 500)]))
+        brief_count = router._count_message_tokens(_msgs([("user", "hi")]))
+        assert verbose_count > brief_count
+
+    def test_list_content(self):
+        # Content given as a list of typed parts instead of a plain string
+        text_part = {"type": "text", "text": "what do you see?"}
+        multipart = [{"role": "user", "content": [text_part]}]
+        tokens = router._count_message_tokens(multipart)
+        assert tokens > 0
+
+    def test_multiple_messages(self):
+        conversation = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")])
+        assert router._count_message_tokens(conversation) > 10
+
+
+class TestTrimMessagesForContext:
+    """Behavioral checks for router._trim_messages_for_context."""
+    def test_short_history_unchanged(self):
+        msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
+        result = router._trim_messages_for_context(msgs, n_ctx=4096)
+        assert result == msgs
+
+    def test_system_messages_always_kept(self):
+        msgs = (
+            _msgs([("system", "you are helpful")])
+            + _msgs([("user", f"msg {i}") for i in range(50)])
+            + _msgs([("user", "final question")])
+        )
+        result = router._trim_messages_for_context(msgs, n_ctx=512)
+        system_msgs = [m for m in result if m["role"] == "system"]
+        assert len(system_msgs) == 1
+        assert system_msgs[0]["content"] == "you are helpful"
+
+    def test_last_user_message_always_kept(self):
+        msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")])
+        result = router._trim_messages_for_context(msgs, n_ctx=256)
+        assert result[-1]["content"] == "very important last question"
+
+    def test_oldest_dropped_first(self):
+        msgs = _msgs([
+            ("user", "oldest msg"),
+            ("assistant", "oldest reply"),
+            ("user", "newer msg"),
+            ("assistant", "newer reply"),
+            ("user", "newest"),
+        ])
+        # Use very small target to force trimming
+        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
+        contents = [m["content"] for m in result]
+        assert "newest" in contents
+        # Oldest-first: an old user turn may only survive if the newer one did too
+        if "oldest msg" in contents:
+            assert "newer msg" in contents
+
+    def test_result_starts_with_user(self):
+        msgs = _msgs([
+            ("assistant", "leftover assistant"),
+            ("user", "question"),
+        ])
+        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
+        if result:
+            assert result[0]["role"] == "user"
+
+    def test_target_tokens_overrides_safety_margin(self):
+        msgs = _msgs([("user", "a " * 200)])
+        result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
+        result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
+        # Both should return at least the last message
+        assert len(result_small) >= 1
+        assert len(result_large) >= 1
+
+
+class TestCalibratedTrimTarget:
+    def test_returns_positive_int(self):
+        history = _msgs([("user", "hello " * 100)])
+        target = router._calibrated_trim_target(history, n_ctx=4096, actual_tokens=3000)
+        assert isinstance(target, int)
+        assert target >= 1
+
+    def test_over_limit_reduces_target(self):
+        history = _msgs([("user", "a " * 500)])
+        # actual_tokens exceeds n_ctx, so tokens have to be shed
+        target = router._calibrated_trim_target(history, n_ctx=2048, actual_tokens=2500)
+        assert target < router._count_message_tokens(history)
+
+    def test_well_within_limit_returns_current(self):
+        history = _msgs([("user", "hi")])
+        # actual_tokens is far below n_ctx, so there is nothing to shed
+        target = router._calibrated_trim_target(history, n_ctx=16384, actual_tokens=50)
+        # Expected to equal the current local token count, floored at 1
+        assert target == max(1, router._count_message_tokens(history))
+
+    def test_minimum_is_one(self):
+        # However much has to be shed, the target must not drop below 1
+        history = _msgs([("user", "hello")])
+        target = router._calibrated_trim_target(history, n_ctx=100, actual_tokens=99999)
+        assert target >= 1
diff --git a/test/test_unit_helpers.py b/test/test_unit_helpers.py
new file mode 100644
index 0000000..d38eb37
--- /dev/null
+++ b/test/test_unit_helpers.py
@@ -0,0 +1,279 @@
+"""Unit tests for pure helper functions in router.py (no network, no app)."""
+import time
+import asyncio
+from unittest.mock import MagicMock, patch
+
+import aiohttp
+import pytest
+
+import router
+
+
+class TestMaskSecrets:
+    def test_masks_openai_key(self):
+        header = "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ"
+        masked = router._mask_secrets(header)
+        assert "sk-***redacted***" in masked
+        assert "sk-abcd1234" not in masked
+
+    def test_masks_api_key_assignment(self):
+        masked = router._mask_secrets("api_key: supersecretvalue123")
+        assert "supersecretvalue123" not in masked
+        assert "***redacted***" in masked
+
+    def test_masks_api_key_with_colon(self):
+        masked = router._mask_secrets("api-key: mykey")
+        assert "mykey" not in masked
+
+    def test_empty_string_returns_empty(self):
+        assert router._mask_secrets("") == ""
+
+    def test_none_returns_none(self):
+        assert router._mask_secrets(None) is None
+
+    def test_no_secrets_unchanged(self):
+        plain = "this is a normal log line"
+        assert router._mask_secrets(plain) == plain
+
+
+class TestIsFresh:
+    def test_fresh_within_ttl(self):
+        ten_seconds_ago = time.time() - 10
+        assert router._is_fresh(ten_seconds_ago, 300) is True
+
+    def test_expired_beyond_ttl(self):
+        long_ago = time.time() - 400
+        assert router._is_fresh(long_ago, 300) is False
+
+    def test_exactly_at_boundary(self):
+        at_ttl_edge = time.time() - 300
+        # The outcome at the exact TTL edge is timing-dependent; only the type is checked
+        verdict = router._is_fresh(at_ttl_edge, 300)
+        assert isinstance(verdict, bool)
+
+    def test_just_cached(self):
+        assert router._is_fresh(time.time(), 1) is True
+
+
+class TestNormalizeLlamaModelName:
+    """Expectations for router._normalize_llama_model_name (prefix and quant-suffix stripping)."""
+    def test_strips_hf_prefix(self):
+        assert router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF") == "gpt-oss-20b-GGUF"
+
+    def test_strips_quant_suffix(self):
+        assert router._normalize_llama_model_name("model:Q8_0") == "model"
+
+    def test_strips_both(self):
+        assert router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16") == "gpt-oss-20b-GGUF"
+
+    def test_no_prefix_no_suffix(self):
+        assert router._normalize_llama_model_name("plain-model") == "plain-model"
+
+    def test_multiple_slashes(self):
+        # With several "/" separators only the final path segment is expected back
+        assert router._normalize_llama_model_name("org/user/model-name:Q4_K_M") == "model-name"
+
+
+class TestExtractLlamaQuant:
+    def test_extracts_quant(self):
+        assert router._extract_llama_quant("unsloth/model:Q8_0") == "Q8_0"  # expects the part after ":"
+
+    def test_no_quant_returns_empty(self):
+        assert router._extract_llama_quant("plain-model") == ""  # no ":" suffix -> empty string
+
+    def test_f16(self):
+        assert router._extract_llama_quant("model:F16") == "F16"  # float-format tag, no HF prefix
+
+    def test_q4_k_m(self):
+        assert router._extract_llama_quant("model:Q4_K_M") == "Q4_K_M"  # underscores must be kept
+
+
+class TestIsUnixSocketEndpoint:
+    def test_sock_endpoint_detected(self):
+        assert router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1") is True  # host ends in ".sock"
+
+    def test_regular_http_not_sock(self):
+        assert router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1") is False  # plain host:port
+
+    def test_ollama_not_sock(self):
+        assert router._is_unix_socket_endpoint("http://localhost:11434") is False  # no /v1, no ".sock"
+
+    def test_dot_sock_in_host_detected(self):
+        assert router._is_unix_socket_endpoint("http://llama.sock/v1") is True  # named host ending in ".sock"
+
+
+class TestGetSocketPath:
+    def test_returns_run_user_path(self):
+        import os
+        socket_path = router._get_socket_path("http://192.168.0.52.sock/v1")
+        # Expected location: the current user's /run/user/<uid> directory
+        assert socket_path == f"/run/user/{os.getuid()}/192.168.0.52.sock"
+
+
+class TestIsBase64:
+    def test_valid_base64(self):
+        import base64
+        encoded = base64.b64encode(b"hello world").decode()
+        assert router.is_base64(encoded) is True
+
+    def test_invalid_base64(self):
+        assert router.is_base64("not-base64!@#$") is False
+
+    def test_empty_string(self):
+        # "" decodes to zero bytes, so it is expected to count as valid base64
+        assert router.is_base64("") is True
+
+    def test_non_string(self):
+        # A non-string input only has to yield a falsy result (None is accepted)
+        assert not router.is_base64(12345)
+
+
+class TestIsLlamaModelLoaded:
+    def test_status_dict_loaded(self):
+        assert router._is_llama_model_loaded({"id": "m", "status": {"value": "loaded"}}) is True  # dict-shaped status
+
+    def test_status_dict_unloaded(self):
+        assert router._is_llama_model_loaded({"id": "m", "status": {"value": "unloaded"}}) is False
+
+    def test_status_string_loaded(self):
+        assert router._is_llama_model_loaded({"id": "m", "status": "loaded"}) is True  # plain-string status
+
+    def test_status_string_unloaded(self):
+        assert router._is_llama_model_loaded({"id": "m", "status": "unloaded"}) is False
+
+    def test_no_status_field_always_loaded(self):
+        # No status field → always available (single-model server)
+        assert router._is_llama_model_loaded({"id": "m"}) is True
+
+    def test_status_none_always_loaded(self):
+        assert router._is_llama_model_loaded({"id": "m", "status": None}) is True  # None is expected to behave like a missing status
+
+
+class TestEp2Base:
+    """Expectations for router.ep2base: the result ends in /v1 exactly once."""
+    def test_adds_v1_to_ollama(self):
+        assert router.ep2base("http://localhost:11434") == "http://localhost:11434/v1"
+
+    def test_keeps_v1_if_present(self):
+        assert router.ep2base("http://host/v1") == "http://host/v1"
+
+    def test_llama_server_endpoint_unchanged(self):
+        assert router.ep2base("http://192.168.0.50:8889/v1") == "http://192.168.0.50:8889/v1"
+
+
+class TestDedupeOnKeys:
+    """Checks for router.dedupe_on_keys (de-duplicating dicts by one or more keys)."""
+    def test_removes_duplicate_by_single_key(self):
+        rows = [{"name": "a", "x": 1}, {"name": "b", "x": 2}, {"name": "a", "x": 3}]
+        unique = router.dedupe_on_keys(rows, ["name"])
+        assert len(unique) == 2
+        assert (unique[0]["name"], unique[1]["name"]) == ("a", "b")
+
+    def test_removes_duplicate_by_two_keys(self):
+        rows = [
+            {"digest": "abc", "name": "m1"},
+            {"digest": "abc", "name": "m1"},
+            {"digest": "def", "name": "m2"},
+        ]
+        unique = router.dedupe_on_keys(rows, ["digest", "name"])
+        assert len(unique) == 2
+
+    def test_empty_list(self):
+        assert router.dedupe_on_keys([], ["name"]) == []
+
+    def test_no_duplicates(self):
+        rows = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
+        assert len(router.dedupe_on_keys(rows, ["name"])) == 3
+
+
+class TestFormatConnectionIssue:
+    def test_connector_error_message(self):
+        conn_key = MagicMock(host="localhost", port=11434)
+        refused = OSError(111, "Connection refused")
+        # Build the aiohttp connector error from a mocked key and an OS-level error
+        err = aiohttp.ClientConnectorError(connection_key=conn_key, os_error=refused)
+        report = router._format_connection_issue("http://localhost:11434", err)
+        assert "localhost" in report
+        assert "Connection refused" in report or "111" in report
+
+    def test_timeout_error_message(self):
+        report = router._format_connection_issue("http://host:1234", asyncio.TimeoutError())
+        assert "Timed out" in report
+        assert "host:1234" in report
+
+    def test_generic_error(self):
+        report = router._format_connection_issue("http://host:1234", ValueError("boom"))
+        assert "host:1234" in report
+        assert "boom" in report
+
+
+class TestIsExtOpenaiEndpoint:
+    """Checks for router.is_ext_openai_endpoint against a patched router.config."""
+    def test_openai_com_is_ext(self):
+        """A /v1 URL absent from both configured lists counts as external."""
+        cfg = MagicMock(endpoints=[], llama_server_endpoints=[])
+        with patch.object(router, "config", cfg):
+            assert router.is_ext_openai_endpoint("https://api.openai.com/v1") is True
+
+    def test_ollama_default_port_not_ext(self):
+        """A configured Ollama endpoint is not external."""
+        cfg = MagicMock(endpoints=["http://host:11434"], llama_server_endpoints=[])
+        with patch.object(router, "config", cfg):
+            assert router.is_ext_openai_endpoint("http://host:11434") is False
+
+    def test_llama_server_not_ext(self):
+        """A configured llama-server endpoint is not external."""
+        cfg = MagicMock(endpoints=[], llama_server_endpoints=["http://host:8080/v1"])
+        with patch.object(router, "config", cfg):
+            assert router.is_ext_openai_endpoint("http://host:8080/v1") is False
+
+    def test_no_v1_not_ext(self):
+        """Without a /v1 suffix an otherwise-external URL is not treated as external."""
+        # Same host and the same empty config as test_openai_com_is_ext, so the
+        # missing /v1 is the only difference (this test used to be a verbatim
+        # copy of test_ollama_default_port_not_ext and covered nothing new).
+        cfg = MagicMock(endpoints=[], llama_server_endpoints=[])
+        with patch.object(router, "config", cfg):
+            assert router.is_ext_openai_endpoint("https://api.openai.com") is False
+
+
+class TestIsOpenaiCompatible:
+    def test_v1_endpoint_compatible(self):
+        """A URL containing /v1 is reported as OpenAI-compatible."""
+        cfg = MagicMock(llama_server_endpoints=[])
+        with patch.object(router, "config", cfg):
+            assert router.is_openai_compatible("http://host/v1") is True
+
+    def test_ollama_not_compatible(self):
+        """A bare Ollama URL with no llama-server entry is not compatible."""
+        cfg = MagicMock(llama_server_endpoints=[])
+        with patch.object(router, "config", cfg):
+            assert router.is_openai_compatible("http://localhost:11434") is False
+
+    def test_llama_server_in_list_compatible(self):
+        """An endpoint listed as a llama-server is compatible even without /v1."""
+        cfg = MagicMock(llama_server_endpoints=["http://host:8080"])
+        with patch.object(router, "config", cfg):
+            assert router.is_openai_compatible("http://host:8080") is True
+
+
+class TestGetTrackingModel:
+    def test_ollama_adds_latest(self):
+        """An untagged model on a non-llama-server endpoint gains ':latest'."""
+        cfg = MagicMock(llama_server_endpoints=[])
+        with patch.object(router, "config", cfg):
+            assert router.get_tracking_model("http://ollama:11434", "llama3.2") == "llama3.2:latest"
+
+    def test_ollama_keeps_existing_tag(self):
+        """A model that already has a tag is returned unchanged."""
+        cfg = MagicMock(llama_server_endpoints=[])
+        with patch.object(router, "config", cfg):
+            assert router.get_tracking_model("http://ollama:11434", "llama3.2:7b") == "llama3.2:7b"
+
+    def test_llama_server_normalizes(self):
+        """On a configured llama-server endpoint the prefix and quant are stripped."""
+        ep = "http://host:8080/v1"
+        cfg = MagicMock(llama_server_endpoints=[ep])
+        with patch.object(router, "config", cfg):
+            tracked = router.get_tracking_model(ep, "unsloth/model:Q8_0")
+        assert tracked == "model"
diff --git a/test/test_unit_rechunk.py b/test/test_unit_rechunk.py
new file mode 100644
index 0000000..e0d01c9
--- /dev/null
+++ b/test/test_unit_rechunk.py
@@ -0,0 +1,173 @@
+"""Unit tests for router.rechunk — OpenAI ↔ Ollama chunk shape conversion."""
+import time
+from types import SimpleNamespace
+
+import ollama
+
+import router
+
+
+def _ns(**attrs):
+    return SimpleNamespace(**attrs)
+
+
+def _stream_chunk(content="hi", role="assistant", finish_reason=None,
+                  usage=None, model="m"):
+    """Fake one streaming OpenAI chunk: a single choice carrying a ``delta``."""
+    delta_fields = dict(content=content, role=role, reasoning=None,
+                        reasoning_content=None, tool_calls=None)
+    only_choice = _ns(delta=_ns(**delta_fields), finish_reason=finish_reason, logprobs=None)
+    return _ns(model=model, choices=[only_choice], usage=usage)
+
+
+def _nonstream_chunk(content="hi", role="assistant", finish_reason="stop",
+                     usage=None, model="m", tool_calls=None):
+    """Fake a non-streaming OpenAI ChatCompletion: a single choice carrying a ``message``."""
+    message_fields = dict(content=content, role=role, reasoning=None,
+                          reasoning_content=None, tool_calls=tool_calls)
+    only_choice = _ns(message=_ns(**message_fields), finish_reason=finish_reason, logprobs=None)
+    return _ns(model=model, choices=[only_choice], usage=usage)
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# openai_chat_completion2ollama
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestChatCompletionToOllama:
+    def test_streaming_content_chunk(self):
+        openai_chunk = _stream_chunk(content="hello", finish_reason=None, usage=None)
+        converted = router.rechunk.openai_chat_completion2ollama(openai_chunk, True, time.perf_counter())
+        assert isinstance(converted, ollama.ChatResponse)
+        assert converted.message.role == "assistant"
+        assert converted.message.content == "hello"
+        assert converted.done is False  # no usage on the chunk, so not finished yet
+        assert converted.model == "m"
+
+    def test_streaming_empty_content_defaults(self):
+        # A delta may carry content=None and role=None; both must get defaults
+        openai_chunk = _stream_chunk(content=None, role=None)
+        converted = router.rechunk.openai_chat_completion2ollama(openai_chunk, True, time.perf_counter())
+        assert converted.message.role == "assistant"  # fallback role
+        assert converted.message.content == ""
+
+    def test_final_usage_only_chunk_marks_done(self):
+        token_usage = _ns(prompt_tokens=10, completion_tokens=5, total_tokens=15)
+        openai_chunk = _ns(model="m", choices=[], usage=token_usage)
+        converted = router.rechunk.openai_chat_completion2ollama(openai_chunk, True, time.perf_counter())
+        assert converted.done is True
+        assert converted.done_reason == "stop"
+        assert converted.prompt_eval_count == 10
+        assert converted.eval_count == 5
+        assert converted.message.content == ""
+
+    def test_nonstreaming_with_content(self):
+        token_usage = _ns(prompt_tokens=2, completion_tokens=3, total_tokens=5)
+        openai_chunk = _nonstream_chunk(content="response text", finish_reason="stop", usage=token_usage)
+        converted = router.rechunk.openai_chat_completion2ollama(openai_chunk, False, time.perf_counter())
+        assert converted.done is True
+        assert converted.message.content == "response text"
+        assert converted.prompt_eval_count == 2
+        assert converted.eval_count == 3
+
+    def test_nonstreaming_tool_calls_converted(self):
+        """Tool calls with JSON string arguments are parsed into dicts."""
+        call = _ns(function=_ns(name="get_weather", arguments='{"city": "Paris"}'))
+        token_usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
+        openai_chunk = _nonstream_chunk(content="", finish_reason="tool_calls", usage=token_usage, tool_calls=[call])
+        converted = router.rechunk.openai_chat_completion2ollama(openai_chunk, False, time.perf_counter())
+        # Exactly one converted tool call is expected
+        assert converted.message.tool_calls is not None
+        assert len(converted.message.tool_calls) == 1
+        # Its JSON-string arguments must have been decoded into a dict
+        only_call = converted.message.tool_calls[0]
+        assert only_call.function.name == "get_weather"
+        assert only_call.function.arguments == {"city": "Paris"}
+
+    def test_nonstreaming_tool_calls_with_invalid_json_fall_back_to_empty(self):
+        call = _ns(function=_ns(name="f", arguments="not-json"))
+        token_usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
+        openai_chunk = _nonstream_chunk(content="", usage=token_usage, tool_calls=[call])
+        converted = router.rechunk.openai_chat_completion2ollama(openai_chunk, False, time.perf_counter())
+        assert converted.message.tool_calls[0].function.arguments == {}
+
+    def test_streaming_tool_calls_in_delta_are_skipped(self):
+        """Streaming mode must not assemble tool calls (caller handles it)."""
+        openai_chunk = _stream_chunk(content="x", finish_reason=None)
+        # NOTE(review): _stream_chunk always sets delta.tool_calls=None, so this
+        # does not yet exercise a delta that really carries tool calls.
+        converted = router.rechunk.openai_chat_completion2ollama(openai_chunk, True, time.perf_counter())
+        assert converted.message.tool_calls is None
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# openai_completion2ollama
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestCompletionToOllama:
+    def test_streaming_text_chunk(self):
+        text_choice = _ns(text="word", finish_reason=None, reasoning=None)
+        openai_chunk = _ns(model="m", choices=[text_choice], usage=None)
+        converted = router.rechunk.openai_completion2ollama(openai_chunk, True, time.perf_counter())
+        assert isinstance(converted, ollama.GenerateResponse)
+        assert converted.response == "word"
+        assert converted.done is False
+
+    def test_final_chunk_with_usage(self):
+        token_usage = _ns(prompt_tokens=4, completion_tokens=6, total_tokens=10)
+        last_choice = _ns(text="end", finish_reason="stop", reasoning=None)
+        openai_chunk = _ns(model="m", choices=[last_choice], usage=token_usage)
+        converted = router.rechunk.openai_completion2ollama(openai_chunk, True, time.perf_counter())
+        assert converted.done is True
+        assert converted.prompt_eval_count == 4
+        assert converted.eval_count == 6
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# embeddings / embed
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestEmbeddingConversions:
+    def test_openai_embeddings2ollama(self):
+        openai_result = _ns(data=[_ns(embedding=[0.1, 0.2, 0.3])])
+        converted = router.rechunk.openai_embeddings2ollama(openai_result)
+        assert isinstance(converted, ollama.EmbeddingsResponse)
+        assert list(converted.embedding) == [0.1, 0.2, 0.3]
+
+    def test_openai_embed2ollama(self):
+        openai_result = _ns(data=[_ns(embedding=[0.5, 0.6])])
+        converted = router.rechunk.openai_embed2ollama(openai_result, "my-embed-model")
+        assert isinstance(converted, ollama.EmbedResponse)
+        assert converted.model == "my-embed-model"
+        assert list(converted.embeddings[0]) == [0.5, 0.6]
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# extract_usage_from_llama_timings
+# ──────────────────────────────────────────────────────────────────────────────
+
+class TestExtractUsageFromLlamaTimings:
+    def test_none_when_no_timings_attr(self):
+        bare = _ns()
+        assert router.rechunk.extract_usage_from_llama_timings(bare) is None
+
+    def test_prompt_plus_cache_sums(self):
+        timed = _ns(timings={"prompt_n": 1, "cache_n": 236, "predicted_n": 35})
+        # Expected prompt count is prompt_n + cache_n (1 + 236)
+        prompt, completion = router.rechunk.extract_usage_from_llama_timings(timed)
+        assert (prompt, completion) == (237, 35)
+
+    def test_missing_keys_default_to_zero(self):
+        timed = _ns(timings={"predicted_n": 12})
+        prompt, completion = router.rechunk.extract_usage_from_llama_timings(timed)
+        assert prompt == 0
+        assert completion == 12
+
+    def test_null_values_treated_as_zero(self):
+        timed = _ns(timings={"prompt_n": None, "cache_n": None, "predicted_n": None})
+        prompt, completion = router.rechunk.extract_usage_from_llama_timings(timed)
+        assert prompt == 0
+        assert completion == 0
+
+    def test_non_dict_timings_returns_none(self):
+        malformed = _ns(timings="not-a-dict")
+        assert router.rechunk.extract_usage_from_llama_timings(malformed) is None
diff --git a/test/test_unit_transforms.py b/test/test_unit_transforms.py
new file mode 100644
index 0000000..51160a0
--- /dev/null
+++ b/test/test_unit_transforms.py
@@ -0,0 +1,200 @@
+"""Unit tests for message transformation functions."""
+from unittest.mock import MagicMock
+
+import pytest
+
+import router
+
+
+class TestStripAssistantPrefill:
+    """Checks for ``router._strip_assistant_prefill``."""
+
+    def test_removes_trailing_assistant(self):
+        """A list ending on an assistant message comes back with only the user message."""
+        user_turn = {"role": "user", "content": "hello"}
+        prefill = {"role": "assistant", "content": "prefill"}
+        stripped = router._strip_assistant_prefill([user_turn, prefill])
+        assert len(stripped) == 1
+        assert stripped[0]["role"] == "user"
+
+    def test_keeps_non_trailing_assistant(self):
+        """An assistant message followed by a user message leaves the length at three."""
+        conversation = [
+            {"role": role, "content": text}
+            for role, text in (("user", "hello"), ("assistant", "response"), ("user", "follow-up"))
+        ]
+        assert len(router._strip_assistant_prefill(conversation)) == 3
+
+    def test_empty_list_unchanged(self):
+        assert router._strip_assistant_prefill([]) == []
+
+    def test_single_user_message_unchanged(self):
+        only_message = [{"role": "user", "content": "hi"}]
+        assert router._strip_assistant_prefill(only_message) == only_message
+
+
+class TestTransformToolCallsToOpenAI:
+    """Checks for ``router.transform_tool_calls_to_openai``."""
+
+    @staticmethod
+    def _assistant_call(name, arguments):
+        """Build an assistant message with one tool call that has no ``id`` or ``type``."""
+        return {"role": "assistant", "tool_calls": [{"function": {"name": name, "arguments": arguments}}]}
+
+    def test_adds_type_function(self):
+        """The tool call comes back with ``type`` set to ``"function"``."""
+        messages = [self._assistant_call("get_weather", {"city": "Berlin"})]
+        converted = router.transform_tool_calls_to_openai(messages)
+        assert converted[0]["tool_calls"][0]["type"] == "function"
+
+    def test_adds_id_when_missing(self):
+        """The tool call comes back carrying an ``id`` key."""
+        messages = [self._assistant_call("fn", {})]
+        converted = router.transform_tool_calls_to_openai(messages)
+        assert "id" in converted[0]["tool_calls"][0]
+
+    def test_converts_dict_arguments_to_string(self):
+        """Dict arguments come back as a string that parses to the same dict."""
+        messages = [self._assistant_call("fn", {"key": "val"})]
+        converted = router.transform_tool_calls_to_openai(messages)
+        serialized = converted[0]["tool_calls"][0]["function"]["arguments"]
+        assert isinstance(serialized, str)
+        import orjson
+        assert orjson.loads(serialized) == {"key": "val"}
+
+    def test_keeps_string_arguments_unchanged(self):
+        """Arguments that are already a string come back equal to the input."""
+        messages = [self._assistant_call("fn", '{"key": "val"}')]
+        converted = router.transform_tool_calls_to_openai(messages)
+        assert converted[0]["tool_calls"][0]["function"]["arguments"] == '{"key": "val"}'
+
+    def test_links_tool_call_id_to_tool_response(self):
+        """The tool message's ``tool_call_id`` equals the preceding call's ``id``."""
+        messages = [
+            self._assistant_call("get_weather", {}),
+            {"role": "tool", "name": "get_weather", "content": "sunny"},
+        ]
+        converted = router.transform_tool_calls_to_openai(messages)
+        call_id = converted[0]["tool_calls"][0]["id"]
+        assert converted[1].get("tool_call_id") == call_id
+
+    def test_non_tool_messages_unchanged(self):
+        """A lone user message compares equal to the input after the call."""
+        messages = [{"role": "user", "content": "hello"}]
+        converted = router.transform_tool_calls_to_openai(messages)
+        assert converted == messages
+
+
+class TestStripImagesFromMessages:
+    """Checks for ``router._strip_images_from_messages``."""
+
+    @staticmethod
+    def _user(content):
+        """Wrap ``content`` in a one-element message list with the ``user`` role."""
+        return [{"role": "user", "content": content}]
+
+    def test_removes_image_url_parts(self):
+        """One text part plus one image part comes back as the bare text string."""
+        parts = [
+            {"type": "text", "text": "what is this?"},
+            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
+        ]
+        stripped = router._strip_images_from_messages(self._user(parts))
+        assert stripped[0]["content"] == "what is this?"
+
+    def test_keeps_text_only_messages(self):
+        stripped = router._strip_images_from_messages(self._user("plain text"))
+        assert stripped[0]["content"] == "plain text"
+
+    def test_multiple_text_parts_kept_as_list(self):
+        """Two text parts plus one image part come back as a two-element list."""
+        parts = [{"type": "text", "text": label} for label in ("part one", "part two")]
+        parts.append({"type": "image_url", "image_url": {"url": "data:..."}})
+        remaining = router._strip_images_from_messages(self._user(parts))[0]["content"]
+        assert isinstance(remaining, list)
+        assert len(remaining) == 2
+
+    def test_all_images_removed_empty_list(self):
+        image_only = [{"type": "image_url", "image_url": {"url": "data:..."}}]
+        stripped = router._strip_images_from_messages(self._user(image_only))
+        # Image-only content becomes empty list
+        assert stripped[0]["content"] == []
+
+
+class TestAccumulateOpenAITcDelta:
+    """Checks for ``router._accumulate_openai_tc_delta``.
+
+    Chunks are stood in for by ``MagicMock`` objects built in ``_make_chunk``.
+    """
+
+    def _make_chunk(self, index, name=None, args_fragment="", tc_id=None):
+        """Return a mocked chunk whose single choice's delta holds one tool call."""
+        # ``MagicMock(name=...)`` names the mock itself rather than creating a
+        # ``.name`` attribute, so the function name is assigned after construction.
+        function = MagicMock(arguments=args_fragment)
+        function.name = name
+        tc = MagicMock(index=index, id=tc_id, function=function)
+        delta = MagicMock(tool_calls=[tc])
+        return MagicMock(choices=[MagicMock(delta=delta)])
+
+    def test_first_delta_creates_entry(self):
+        """The first fragment for an index creates an entry with its name and arguments."""
+        acc = {}
+        chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"')
+        router._accumulate_openai_tc_delta(chunk, acc)
+        assert 0 in acc
+        assert acc[0]["name"] == "my_fn"
+        assert acc[0]["arguments"] == '{"k"'
+
+    def test_subsequent_deltas_concatenate_args(self):
+        """Argument fragments for the same index are appended in arrival order."""
+        acc = {}
+        router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc)
+        router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc)
+        assert acc[0]["arguments"] == '{"k": "v"}'
+
+    def test_multiple_tool_calls_tracked_separately(self):
+        """Two tool calls carried by one delta land under their own indices."""
+        acc = {}
+        chunk = self._make_chunk(0, name="fn1", args_fragment="{}", tc_id="id1")
+        other = self._make_chunk(1, name="fn2", args_fragment="{}", tc_id="id2")
+        # Fold the second tool call into the first chunk so one delta carries both.
+        chunk.choices[0].delta.tool_calls.extend(other.choices[0].delta.tool_calls)
+        router._accumulate_openai_tc_delta(chunk, acc)
+        assert 0 in acc and 1 in acc
+        assert acc[0]["name"] == "fn1"
+        assert acc[1]["name"] == "fn2"
+
+    def test_no_choices_is_noop(self):
+        """A chunk with an empty ``choices`` list leaves the accumulator empty."""
+        acc = {}
+        chunk = MagicMock(choices=[])
+        router._accumulate_openai_tc_delta(chunk, acc)
+        assert acc == {}
+
+
+class TestBuildOllamaToolCalls:
+    """Checks for ``router._build_ollama_tool_calls``."""
+
+    def test_builds_from_accumulator(self):
+        entry = {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'}
+        calls = router._build_ollama_tool_calls({0: entry})
+        assert calls is not None
+        assert len(calls) == 1
+        assert calls[0].function.name == "get_weather"
+        assert calls[0].function.arguments == {"city": "Berlin"}
+
+    def test_invalid_json_args_becomes_empty_dict(self):
+        calls = router._build_ollama_tool_calls({0: {"id": "c1", "name": "fn", "arguments": "not-json"}})
+        assert calls[0].function.arguments == {}
+
+    def test_empty_accumulator_returns_none(self):
+        assert router._build_ollama_tool_calls({}) is None
+
+    def test_preserves_order_by_index(self):
+        acc = {
+            idx: {"id": call_id, "name": name, "arguments": "{}"}
+            for idx, call_id, name in ((1, "c2", "fn2"), (0, "c1", "fn1"))  # key 1 inserted before key 0
+        }
+        calls = router._build_ollama_tool_calls(acc)
+        assert [calls[0].function.name, calls[1].function.name] == ["fn1", "fn2"]