Compare commits

..

No commits in common. "main" and "workflow-tuning" have entirely different histories.

28 changed files with 191 additions and 3682 deletions

View file

@ -86,7 +86,9 @@ jobs:
provenance: false
build-args: |
SEMANTIC_CACHE=true
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}-${{ github.run_id }}
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }}
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache-semantic-${{ matrix.arch }},mode=max
merge:
runs-on: docker-amd64
@ -142,6 +144,6 @@ jobs:
run: |
docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64-${{ github.run_id }} \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64-${{ github.run_id }}
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64 \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64

View file

@ -77,7 +77,7 @@ jobs:
platforms: ${{ matrix.platform }}
push: true
provenance: false
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}-${{ github.run_id }}
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}
merge:
runs-on: docker-amd64
@ -133,6 +133,6 @@ jobs:
run: |
docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64-${{ github.run_id }} \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64-${{ github.run_id }}
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64 \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64

View file

@ -1,39 +0,0 @@
name: PR Tests
on: [pull_request]
jobs:
test:
runs-on: docker-arm64
container:
image: python:3.12-slim
env:
CMAKE_BUILD_PARALLEL_LEVEL: "4"
steps:
- name: Install system deps
run: |
apt-get update
apt-get install -y --no-install-recommends \
git ca-certificates \
build-essential pkg-config
rm -rf /var/lib/apt/lists/*
- name: Checkout
run: |
git config --global --add safe.directory "$PWD"
git clone --depth=1 \
"https://oauth2:${{ github.token }}@bitfreedom.net/code/${{ github.repository }}.git" .
git fetch --depth=1 origin "+${{ github.event.pull_request.head.sha }}:pr"
git checkout pr
- name: Fetch action source
run: |
git clone --depth=1 --branch master \
"https://oauth2:${{ github.token }}@bitfreedom.net/code/nomyo-ai/actions.git" \
./.run-tests
- uses: ./.run-tests/run-tests
with:
setup: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r test/requirements_test.txt
command: pytest test/ -m "not integration" --cov=router --cov=cache --cov=db --cov=enhance --cov-fail-under=45 --cov-report=term-missing --cov-report=xml --junitxml=report.xml
artifacts-path: |
report.xml
coverage.xml

5
.gitignore vendored
View file

@ -66,4 +66,7 @@ config.yaml
# SQLite
*.db*
*settings.json
*settings.json
# Test suite (local only, not committed yet)
test/

View file

@ -1,24 +0,0 @@
{
"version": 1,
"decisions": [],
"suppression_rules": [
{
"by": "rule",
"value": "py.auth.token_override_without_validation",
"state": "suppressed",
"note": "false_positive: token validation handled upstream by middleware"
},
{
"by": "rule",
"value": "state-resource-leak",
"state": "suppressed",
"note": "false_positive: resource lifecycle managed externally"
},
{
"by": "rule",
"value": "py.crypto.sha1",
"state": "suppressed",
"note": "accepted_risk: used for non-security checksum only"
}
]
}

View file

@ -26,26 +26,6 @@ max_concurrent_connections: 2
# When false (default), equally-idle endpoints are chosen at random.
# priority_routing: true
# Conversation affinity (optional, default: false).
# Pins a conversation to the endpoint that served its first turn so the
# llama.cpp / Ollama prompt cache (KV cache) stays warm — first turn pays
# the cold prefill, every follow-up turn reuses the same prefix.
#
# Fingerprint = sha1(model + leading system messages + first user turn).
# Same chat → same fingerprint on every follow-up turn → same pin, TTL
# refreshed on each reuse. Soft preference: if the pinned endpoint no
# longer has the model loaded or has no free slot, the standard algorithm
# takes over (no failure, just a cache miss).
#
# Heads-up: most chat UIs (Open WebUI, LibreChat, …) fire side requests for
# title / tag / follow-up generation. Those have their own first turn and
# therefore their own pin, so a single visible "chat" may show several dots
# in the dashboard's Affinity column. That is correct — each pin matches a
# real warm KV prefix on the backend. See doc/configuration.md for details.
conversation_affinity: true
conversation_affinity_ttl: 300 # seconds of inactivity before a pin expires;
# bumped on every reuse. Matches Ollama's default keep_alive.
# Optional router-level API key that gates router/API/web UI access (leave empty to disable)
nomyo-router-api-key: ""

View file

@ -206,8 +206,6 @@ The `/health` endpoint provides comprehensive health status:
}
```
For Ollama endpoints the probe is a parallel check of `/api/version` (liveness) and `/api/ps` (the route used by `choose_endpoint` when selecting a backend for a request). Reporting `ok` only when both succeed prevents the router from advertising an endpoint as healthy while completion calls dead-end on `/api/ps`. The same dual probe backs `/api/config`, which the dashboard uses to render endpoint health.
## Database Schema
The router uses SQLite for persistent storage:

View file

@ -166,91 +166,6 @@ With this config the primary handles up to 4 concurrent requests before the seco
---
### `conversation_affinity`
**Type**: `bool` (optional)
**Default**: `false`
**Companion setting**: [`conversation_affinity_ttl`](#conversation_affinity_ttl)
**Description**: When enabled, the router prefers to send follow-up requests of the same conversation back to the endpoint that already served the first turn. This keeps the backend's prompt cache (the llama.cpp / Ollama **KV cache**) warm: the first user turn pays the cold prefill cost, every later turn reuses the same prefix and only generates new tokens. It is a **soft preference** — when the previously-chosen endpoint is no longer eligible (model unloaded, no free slot), the router falls back to the standard selection algorithm (`priority_routing` or random).
#### How a conversation is identified
The router does **not** track session IDs or auth tokens. It computes a stable fingerprint per request from:
```
SHA1( model
+ every leading message with role="system"
+ the first message with role="user" )
```
Anything after the first user turn is ignored — those later messages extend the same KV prefix, so they don't change the cache identity.
**What this means in practice**
| You send… | Fingerprint behaves like… |
|---|---|
| Turn 2 of the same chat (history grows but first system+user are unchanged) | **Same** as turn 1 → pin is reused and TTL refreshed |
| Turn 1 of a fresh chat | **New** fingerprint → new pin |
| Same first user prompt but a different model | **New** fingerprint (model is part of the hash) |
| Same chat but the client mutates the system prompt between turns (e.g. injects a fresh timestamp) | **New** fingerprint — the affinity will not stick |
#### TTL and refresh
Every time `choose_endpoint` returns a pinned endpoint, the entry's expiry is bumped to `now + conversation_affinity_ttl`. An idle conversation drops out of the map once that window elapses without traffic. Default 300 s matches Ollama's default `keep_alive` — once the backend has unloaded the model, the KV cache is gone too, so a stale pin would be pointless anyway.
#### Why the dashboard may show more than one dot per visible conversation
The fingerprint is computed per **HTTP request**, not per chat-window. Most chat UIs (Open WebUI in particular) fire several **auxiliary** requests alongside the real conversation:
- *Title generation* — synthetic system prompt + the user message as content
- *Follow-up question suggestion* — synthetic system prompt + the conversation as content
- *Tag generation*, *memory extraction*, *retrieval query rewriting*, etc.
Each of those has its own `(system + first user turn)` and therefore its own fingerprint and its own pin in [the affinity dot matrix](monitoring.md#affinity-stats-conversation-affinity). They all *correctly* refer to a real warm KV-cache prefix on the backend, so the routing they drive is right — they just don't visually map 1:1 to a user-perceived "conversation."
#### Example
```yaml
endpoints:
- http://gpu-primary:11434
- http://gpu-secondary:11434
conversation_affinity: true
conversation_affinity_ttl: 300
```
With this configuration, a chat that starts on `gpu-primary` will keep returning to `gpu-primary` for follow-up turns as long as the model is still loaded there and a slot is free, even if `gpu-secondary` happens to be more idle at that moment. Cold-prefill cost is paid once instead of once per turn.
#### When to enable
- ✅ Interactive chat workloads with long histories — the prefill savings on every follow-up turn are substantial.
- ✅ Multi-endpoint deployments where models are loaded on more than one node.
- ❌ Pure one-shot / single-turn workloads (no KV-cache to keep warm).
- ❌ When you specifically want strict load-balancing parity — affinity intentionally biases against perfect balance.
---
### `conversation_affinity_ttl`
**Type**: `int` (seconds, optional)
**Default**: `300`
**Description**: How long a conversation stays pinned to its endpoint after the last request that touched it. Refreshed on every reuse — so an actively-used conversation keeps its pin indefinitely; an abandoned one expires after `conversation_affinity_ttl` seconds of silence.
**Recommendation**: leave this aligned with the backend's `keep_alive` window. If the model is unloaded by the backend, the KV cache is gone and there is no benefit to keeping the pin.
**Example**:
```yaml
conversation_affinity: true
conversation_affinity_ttl: 600 # ten minutes of inactivity before un-pinning
```
---
### `router_api_key`
**Type**: `str` (optional)

View file

@ -29,10 +29,6 @@ Response:
- `200`: All endpoints healthy
- `503`: One or more endpoints unhealthy
**Probe scope per endpoint**:
- **Ollama endpoints** are probed at both `/api/version` (liveness) and `/api/ps` (model-introspection used by the router). If either fails the endpoint is reported as `error`; the response still includes `version` when the daemon is reachable so operators can tell a partial failure from a full outage. The `detail` field names the failing probe, e.g. `"/api/ps: 502 …"`.
- **OpenAI-compatible / llama-server endpoints** are probed at `/models`.
### Current Usage
```bash
@ -137,8 +133,6 @@ Response:
}
```
Uses the same dual-probe logic as `/health` (Ollama: `/api/version` + `/api/ps`; OpenAI-compatible: `/models`). An endpoint will report `error` whenever either probe fails. The dashboard renders the `detail` field as a tooltip on the status cell.
### Cache Statistics
```bash
@ -172,39 +166,6 @@ curl -X POST http://localhost:12434/api/cache/invalidate
Clears all cached entries and resets hit/miss counters.
### Affinity Stats (Conversation Affinity)
```bash
curl http://localhost:12434/api/affinity_stats
```
Response when [`conversation_affinity`](configuration.md#conversation_affinity) is enabled:
```json
{
"enabled": true,
"ttl": 300,
"entries": [
{ "endpoint": "http://gpu-primary:11434", "model": "llama3.2:latest", "remaining": 287.4 },
{ "endpoint": "http://gpu-primary:11434", "model": "llama3.2:latest", "remaining": 113.0 },
{ "endpoint": "http://gpu-secondary:11434", "model": "qwen2.5-coder:7b", "remaining": 44.8 }
]
}
```
Response when the feature is disabled:
```json
{ "enabled": false, "ttl": 300, "entries": [] }
```
- One element per **live pinned conversation** (no fingerprints or content — just the endpoint/model the pin points to and how many seconds it has left before expiry).
- Aggregation by `(endpoint, model)` is left to the consumer: the dashboard does this client-side.
- The endpoint is gated by the same `nomyo-router-api-key` middleware as the rest of `/api/*`.
The dashboard's **Running Models (PS) → Affinity** column is rendered from this data. The column auto-hides when `enabled: false`. Each row shows one dot per live pin against that `(endpoint, model)` pair; dot opacity = `remaining / ttl` (floor 0.15), so freshly-routed pins are solid and pins close to expiry fade out. A `+N` overflow badge appears once a single (endpoint, model) holds more than 12 active pins; an em-dash (`—`) marks an `(endpoint, model)` with no live pins.
> Multiple dots for what looks like "one chat window" is normal — most chat UIs (Open WebUI, LibreChat, …) fire auxiliary requests (title generation, follow-up suggestions, tag extraction) that have their own first-turn fingerprint and therefore their own pin. See [Conversation Affinity → Why the dashboard may show more than one dot per visible conversation](configuration.md#conversation_affinity) for the details.
### Real-time Usage Stream
```bash

View file

@ -6,7 +6,7 @@ anyio==4.13.0
async-timeout==5.0.1
attrs==26.1.0
certifi==2026.4.22
click==8.4.0
click==8.3.3
distro==1.9.0
exceptiongroup==1.3.1
fastapi==0.136.1
@ -15,11 +15,11 @@ frozenlist==1.8.0
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
idna==3.15
jiter==0.15.0
idna==3.14
jiter==0.14.0
multidict==6.7.1
ollama==0.6.2
openai==2.37.0
openai==1.109.1
orjson>=3.11.5
numpy>=1.26
pillow==12.2.0
@ -32,11 +32,11 @@ PyYAML==6.0.3
sniffio==1.3.1
starlette==0.52.1
truststore==0.10.4
tiktoken==0.13.0
tiktoken==0.12.0
tqdm==4.67.3
typing-inspection==0.4.2
typing_extensions==4.15.0
uvicorn==0.47.0
uvicorn==0.46.0
uvloop
yarl==1.23.0
aiosqlite

500
router.py
View file

@ -2,11 +2,11 @@
title: NOMYO Router - an (O)llama and OpenAI API v1 Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.9
version: 0.7
license: AGPL
"""
# -------------------------------------------------------------
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx, hashlib
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx
try:
import truststore; truststore.inject_into_ssl()
except ImportError:
@ -223,15 +223,6 @@ class Config(BaseSettings):
# When True, config order = priority; routes by utilization ratio + config index (WRR)
priority_routing: bool = Field(default=False)
# Conversation affinity: route the same conversation back to the endpoint that
# previously served it, to keep the llama.cpp / Ollama prompt cache (KV cache) warm.
# Soft preference — falls back to the standard algorithm when the affine endpoint
# is saturated or no longer has the model loaded.
conversation_affinity: bool = Field(default=False)
# TTL (seconds) for affinity entries. Defaults to Ollama's default keep_alive (5 min):
# if the backend has already evicted the model, the KV cache is cold anyway.
conversation_affinity_ttl: int = Field(default=300)
api_keys: Dict[str, str] = Field(default_factory=dict)
# Optional router-level API key used to gate access to this service and dashboard
router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY")
@ -256,8 +247,9 @@ class Config(BaseSettings):
cache_history_weight: float = Field(default=0.3)
class Config:
# YAML loading is handled manually via Config.from_yaml(); env vars use this prefix.
# Load from `config.yaml` first, then from env variables
env_prefix = "NOMYO_ROUTER_"
yaml_file = Path("config.yaml") # relative to cwd
@classmethod
def _expand_env_refs(cls, obj):
@ -444,47 +436,6 @@ token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(
usage_lock = asyncio.Lock() # protects access to usage_counts
token_usage_lock = asyncio.Lock()
# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic).
# Keeps the same conversation pinned to the endpoint that already has its
# KV-cache prefix warm. Model is stored so the dashboard can aggregate live
# entries per (endpoint, model) without recomputing fingerprints.
# Never held together with usage_lock.
_affinity_map: Dict[str, tuple[str, str, float]] = {}
_affinity_lock = asyncio.Lock()
_AFFINITY_MAX_ENTRIES = 10000
def _conversation_fingerprint(model: str, messages: Optional[list],
prompt: Optional[str]) -> Optional[str]:
"""
Stable hash over (model, first system + first user turn). That prefix
determines whether the backend's prompt cache is reusable; later turns
don't influence the routing decision because they extend the same prefix.
Returns None when there is no usable prefix.
"""
parts: list[str] = [model or "_"]
if messages:
for m in messages:
role = m.get("role") if isinstance(m, dict) else None
if role not in ("system", "user"):
continue
content = m.get("content")
if isinstance(content, list): # OpenAI multimodal parts
content = "".join(
p.get("text", "") for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
if not isinstance(content, str):
continue
parts.append(f"{role}:{content}")
if role == "user":
break
elif prompt:
parts.append(f"user:{prompt}")
else:
return None
return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest()
# Database instance
db: "TokenDatabase" = None
@ -1000,7 +951,7 @@ class fetch:
async with client.get(f"{endpoint}/models") as resp:
await _ensure_success(resp)
data = await resp.json()
# Filter for loaded models only
items = data.get("data", [])
models = {
@ -1012,19 +963,11 @@ class fetch:
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
# Probe succeeded — clear any stale error so the endpoint
# becomes routable again.
async with _loaded_error_cache_lock:
_loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/models", e)
print(f"[fetch.loaded_models] {message}")
# Record the failure so `choose_endpoint` can avoid routing
# to an unhealthy backend and repeated probes short-circuit.
async with _loaded_error_cache_lock:
_loaded_error_cache[endpoint] = time.time()
return set()
else:
# Original Ollama /api/ps logic
@ -1039,15 +982,11 @@ class fetch:
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
async with _loaded_error_cache_lock:
_loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/api/ps", e)
print(f"[fetch.loaded_models] {message}")
async with _loaded_error_cache_lock:
_loaded_error_cache[endpoint] = time.time()
return set()
async def _refresh_loaded_models(endpoint: str) -> None:
@ -1430,30 +1369,30 @@ def resize_image_if_needed(image_data):
pass
# Decode the base64 image data
image_bytes = base64.b64decode(image_data)
with Image.open(io.BytesIO(image_bytes)) as image:
if image.mode not in ("RGB", "L"):
image = image.convert("RGB")
image = Image.open(io.BytesIO(image_bytes))
if image.mode not in ("RGB", "L"):
image = image.convert("RGB")
# Get current size
width, height = image.size
# Get current size
width, height = image.size
# Calculate the new dimensions while maintaining aspect ratio
if width > 512 or height > 512:
aspect_ratio = width / height
if aspect_ratio > 1: # Width is larger
new_width = 512
new_height = int(512 / aspect_ratio)
else: # Height is larger
new_height = 512
new_width = int(512 * aspect_ratio)
# Calculate the new dimensions while maintaining aspect ratio
if width > 512 or height > 512:
aspect_ratio = width / height
if aspect_ratio > 1: # Width is larger
new_width = 512
new_height = int(512 / aspect_ratio)
else: # Height is larger
new_height = 512
new_width = int(512 * aspect_ratio)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Encode the resized image back to base64
buffered = io.BytesIO()
image.save(buffered, format="PNG")
resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
return resized_image_data
# Encode the resized image back to base64
buffered = io.BytesIO()
image.save(buffered, format="PNG")
resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
return resized_image_data
except Exception as e:
print(f"Error processing image: {e}")
@ -1799,8 +1738,7 @@ def get_max_connections(ep: str) -> int:
"max_concurrent_connections", config.max_concurrent_connections
)
async def choose_endpoint(model: str, reserve: bool = True,
affinity_key: Optional[str] = None) -> tuple[str, str]:
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpoint-model pair **and**
@ -1810,14 +1748,10 @@ async def choose_endpoint(model: str, reserve: bool = True,
1 Query every endpoint for its advertised models (`/api/tags`).
2 Build a list of endpoints that contain the requested model.
2.5 If conversation affinity is enabled and the caller passes
``affinity_key``, prefer the endpoint that previously served the
same conversation but only when it still has the model loaded
and a free slot. Otherwise fall through to the standard logic.
3 For those endpoints, find those that have the model loaded
(`/api/ps`) *and* still have a free slot.
4 If none are both loaded and free, fall back to any endpoint
from the filtered list that simply has a free slot and randomly
from the filtered list that simply has a free slot and randomly
select one.
5 If all are saturated, pick any endpoint from the filtered list
(the request will queue on that endpoint).
@ -1865,41 +1799,6 @@ async def choose_endpoint(model: str, reserve: bool = True,
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
# 3⃣.5 Exclude endpoints whose loaded-model probe has been failing
# recently. Without this filter, an endpoint where `/api/ps` returns 5xx
# would appear with an empty loaded set but pass through to the
# free-slot fallback (step 4) — sending completion calls to an
# unhealthy backend. See issue #83.
async with _loaded_error_cache_lock:
unhealthy = {
ep for ep, ts in _loaded_error_cache.items()
if _is_fresh(ts, 300)
}
if unhealthy:
filtered = [
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
if ep not in unhealthy
]
if filtered:
candidate_endpoints = [ep for ep, _ in filtered]
loaded_sets = [models for _, models in filtered]
# If *every* candidate is unhealthy we still fall through with the
# original list — refusing to route is worse than retrying a
# possibly-recovered backend.
# Look up a possible affinity hint *before* taking usage_lock. The two
# locks are never held together to avoid lock-ordering issues.
affine_ep: Optional[str] = None
if config.conversation_affinity and affinity_key:
async with _affinity_lock:
entry = _affinity_map.get(affinity_key)
if entry is not None:
ep, _stored_model, expires_at = entry
if expires_at < time.monotonic():
_affinity_map.pop(affinity_key, None)
else:
affine_ep = ep
# Protect all reads/writes of usage_counts with the lock so that selection
# and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock:
@ -1915,75 +1814,59 @@ async def choose_endpoint(model: str, reserve: bool = True,
# Priority map: position in all_endpoints list (lower = higher priority)
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
selected: Optional[str] = None
# 3⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and tracking_usage(ep) < get_max_connections(ep)
]
# 2⃣.5 Conversation affinity preference — only honour the hint when
# the affine endpoint still advertises the model loaded *and* has a
# free slot. Otherwise fall back to the standard algorithm.
if affine_ep:
ep_loaded = {
ep: set(models)
for ep, models in zip(candidate_endpoints, loaded_sets)
}
if (affine_ep in candidate_endpoints
and model in ep_loaded.get(affine_ep, set())
and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
selected = affine_ep
if selected is None:
# 3⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and tracking_usage(ep) < get_max_connections(ep)
if loaded_and_free:
if config.priority_routing:
# WRR: sort by config order first (stable), then by utilization ratio.
# Stable sort preserves priority for equal-ratio endpoints.
loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
loaded_and_free.sort(key=utilization_ratio)
selected = loaded_and_free[0]
else:
# Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage)
# When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
selected = random.choice(loaded_and_free)
else:
selected = loaded_and_free[0]
else:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < get_max_connections(ep)
]
if loaded_and_free:
if endpoints_with_free_slot:
if config.priority_routing:
# WRR: sort by config order first (stable), then by utilization ratio.
# Stable sort preserves priority for equal-ratio endpoints.
loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
loaded_and_free.sort(key=utilization_ratio)
selected = loaded_and_free[0]
endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
endpoints_with_free_slot.sort(key=utilization_ratio)
selected = endpoints_with_free_slot[0]
else:
# Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage)
# When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
selected = random.choice(loaded_and_free)
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
selected = random.choice(endpoints_with_free_slot)
else:
selected = loaded_and_free[0]
else:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < get_max_connections(ep)
]
if endpoints_with_free_slot:
if config.priority_routing:
endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
endpoints_with_free_slot.sort(key=utilization_ratio)
selected = endpoints_with_free_slot[0]
else:
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
selected = random.choice(endpoints_with_free_slot)
else:
selected = endpoints_with_free_slot[0]
else:
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
if config.priority_routing:
selected = min(
candidate_endpoints,
key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
)
else:
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
if config.priority_routing:
selected = min(
candidate_endpoints,
key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
)
else:
selected = min(candidate_endpoints, key=tracking_usage)
selected = min(candidate_endpoints, key=tracking_usage)
tracking_model = get_tracking_model(selected, model)
snapshot = None
@ -1992,15 +1875,6 @@ async def choose_endpoint(model: str, reserve: bool = True,
snapshot = _capture_snapshot()
if snapshot is not None:
await _distribute_snapshot(snapshot)
# Record / refresh affinity *after* releasing usage_lock.
if reserve and config.conversation_affinity and affinity_key:
expires_at = time.monotonic() + config.conversation_affinity_ttl
async with _affinity_lock:
_affinity_map[affinity_key] = (selected, model, expires_at)
if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
now = time.monotonic()
for k in [k for k, v in _affinity_map.items() if v[2] < now]:
_affinity_map.pop(k, None)
return selected, tracking_model
# -------------------------------------------------------------
@ -2051,8 +1925,7 @@ async def proxy(request: Request):
yield _cached
return StreamingResponse(_serve_cached_generate(), media_type="application/json")
_affinity_key = _conversation_fingerprint(model, None, prompt)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
@ -2222,8 +2095,7 @@ async def chat_proxy(request: Request):
opt = True
else:
opt = False
_affinity_key = _conversation_fingerprint(model, messages, None)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
endpoint, tracking_model = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
@ -3138,43 +3010,6 @@ async def ps_details_proxy(request: Request):
return JSONResponse(content={"models": models}, status_code=200)
# -------------------------------------------------------------
# 18b. Conversation-affinity stats feeds the PS-table dot matrix
# -------------------------------------------------------------
@app.get("/api/affinity_stats")
async def affinity_stats(request: Request):
"""
Aggregate live conversation-affinity pins, one entry per pinned conversation.
Each entry exposes only the endpoint, model, and remaining TTL in seconds
no fingerprints or content. When conversation_affinity is disabled the
`entries` list is always empty.
"""
if not config.conversation_affinity:
return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []}
now = time.monotonic()
entries: list[dict] = []
llama_eps = set(config.llama_server_endpoints)
async with _affinity_lock:
for fp, (ep, mdl, expires_at) in list(_affinity_map.items()):
remaining = expires_at - now
if remaining <= 0:
_affinity_map.pop(fp, None)
continue
# Mirror the normalisation used by /api/ps_details so the dashboard
# can join affinity entries to PS rows by (endpoint, model).
display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl
entries.append({
"endpoint": ep,
"model": display_model,
"remaining": round(remaining, 2),
})
return {
"enabled": True,
"ttl": config.conversation_affinity_ttl,
"entries": entries,
}
# -------------------------------------------------------------
# 19. Proxy usage route for monitoring
# -------------------------------------------------------------
@ -3188,103 +3023,44 @@ async def usage_proxy(request: Request):
"token_usage_counts": token_usage_counts}
# -------------------------------------------------------------
# 20. Endpoint health probes (shared by /api/config and /health)
# -------------------------------------------------------------
async def _raw_probe(
    ep: str,
    route: str,
    api_key: Optional[str] = None,
    timeout: Optional[float] = None,
) -> tuple[bool, object]:
    """Issue a single GET against ``ep`` + ``route`` and report whether it worked.

    Unlike `fetch.endpoint_details` (which returns [] on success and on
    failure alike), the outcome is made explicit.

    Args:
        ep: Base URL of the endpoint; a trailing slash is tolerated.
        route: Path to probe; a leading slash is tolerated.
        api_key: When given, sent as a ``Bearer`` Authorization header.
        timeout: Optional total request timeout in seconds.

    Returns:
        ``(True, decoded_json)`` on success, otherwise
        ``(False, formatted_error_message)``. Exceptions are never propagated.
    """
    request_headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
    if api_key is not None:
        request_headers["Authorization"] = "Bearer " + api_key

    url = f"{ep.rstrip('/')}/{route.lstrip('/')}"

    # Only pass a timeout when the caller asked for one, so the session's own
    # default applies otherwise.
    extra_kwargs = {} if timeout is None else {"timeout": aiohttp.ClientTimeout(total=timeout)}

    try:
        session: aiohttp.ClientSession = get_session(ep)
        async with session.get(url, headers=request_headers, **extra_kwargs) as response:
            await _ensure_success(response)
            payload = await response.json()
            return True, payload
    except Exception as exc:
        return False, _format_connection_issue(url, exc)
async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
    """Probe one endpoint and summarise the result as `{status, version?, detail?}`.

    OpenAI-compatible endpoints are checked with a single `/models` request.
    Ollama endpoints get two concurrent probes, `/api/version` and `/api/ps`,
    so that a daemon which is reachable but has a broken model-introspection
    path (issue #83) is reported as `error` rather than `ok`.

    Args:
        ep: Endpoint base URL.
        timeout: Optional per-probe timeout in seconds, forwarded to `_raw_probe`.

    Returns:
        A dict whose ``status`` is ``"ok"`` or ``"error"``; ``version`` and
        ``detail`` are present depending on which probes succeeded.
    """
    if is_openai_compatible(ep):
        reachable, result = await _raw_probe(
            ep, "/models", config.api_keys.get(ep), timeout=timeout,
        )
        if not reachable:
            return {"status": "error", "detail": str(result)}
        return {"status": "ok", "version": "latest"}

    version_probe, ps_probe = await asyncio.gather(
        _raw_probe(ep, "/api/version", timeout=timeout),
        _raw_probe(ep, "/api/ps", timeout=timeout),
    )
    version_ok, version_payload = version_probe
    ps_ok, ps_payload = ps_probe

    version_value = None
    if version_ok and isinstance(version_payload, dict):
        version_value = version_payload.get("version")

    if not version_ok:
        if not ps_ok:
            # Nothing answered: report the version-probe failure.
            return {"status": "error", "detail": str(version_payload)}
        # /api/ps answered but /api/version did not.
        return {
            "status": "error",
            "detail": f"/api/version: {version_payload}",
        }

    if ps_ok:
        return {"status": "ok", "version": version_value}

    # Partial failure — daemon reachable but /api/ps failed. Report as
    # "error" so callers can surface the issue; include `version` so the
    # operator knows the daemon itself is alive.
    return {
        "status": "error",
        "version": version_value,
        "detail": f"/api/ps: {ps_payload}",
    }
# -------------------------------------------------------------
# 20b. Proxy config route for monitoring and frontend usage
# 20. Proxy config route for monitoring and frontend usage
# -------------------------------------------------------------
@app.get("/api/config")
async def config_proxy(request: Request):
"""
Return a simple JSON object that contains the configured
Ollama endpoints and llama_server_endpoints. The frontend uses this
to display which endpoints are being proxied and their health.
Status is "error" when either liveness (/api/version) or routing
health (/api/ps) fails see issue #83.
Ollama endpoints and llama_server_endpoints. The frontend uses this to display
which endpoints are being proxied.
"""
async def check(url: str) -> dict:
return {"url": url, **(await _endpoint_health(url, timeout=5))}
async def check_endpoint(url: str):
client: aiohttp.ClientSession = get_session(url)
headers = None
if "/v1" in url:
headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
target_url = f"{url}/models"
else:
target_url = f"{url}/api/version"
ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
try:
async with client.get(target_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp:
await _ensure_success(resp)
data = await resp.json()
if "/v1" in url:
return {"url": url, "status": "ok", "version": "latest"}
else:
return {"url": url, "status": "ok", "version": data.get("version")}
except Exception as e:
detail = _format_connection_issue(target_url, e)
return {"url": url, "status": "error", "detail": detail}
# Check Ollama endpoints
ollama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.endpoints])
# Check llama-server endpoints
llama_results = []
if config.llama_server_endpoints:
llama_results = await asyncio.gather(
*[check(ep) for ep in config.llama_server_endpoints]
)
llama_results = await asyncio.gather(*[check_endpoint(ep) for ep in config.llama_server_endpoints])
return {
"endpoints": ollama_results,
"llama_server_endpoints": llama_results,
@ -3452,8 +3228,7 @@ async def openai_chat_completions_proxy(request: Request):
return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")
# 2. Endpoint logic
_affinity_key = _conversation_fingerprint(model, messages, None)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
endpoint, tracking_model = await choose_endpoint(model)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Helpers and API call — done in handler scope so try/except works reliably
async def _normalize_images_in_messages(msgs: list) -> list:
@ -3763,8 +3538,7 @@ async def openai_completions_proxy(request: Request):
return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")
# 2. Endpoint logic
_affinity_key = _conversation_fingerprint(model, None, prompt)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
endpoint, tracking_model = await choose_endpoint(model)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
@ -4096,30 +3870,44 @@ async def health_proxy(request: Request):
"""
Healthcheck endpoint for monitoring the proxy.
* Queries each configured endpoint for both liveness and routing health:
Ollama endpoints are probed at `/api/version` AND `/api/ps`,
OpenAI-compatible endpoints at `/models`.
* Queries each configured endpoint for its `/api/version` response.
* Returns a JSON object containing:
- `status`: "ok" if every endpoint replied to every probe, otherwise "error".
- `status`: "ok" if every endpoint replied, otherwise "error".
- `endpoints`: a mapping of endpoint URL `{status, version|detail}`.
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
"""
# Run all health checks in parallel.
# Ollama endpoints expose /api/version (liveness) and /api/ps (routing
# health — required by `choose_endpoint`). OpenAI-compatible endpoints
# (vLLM, llama-server, external) expose /models, which serves both
# purposes. Probing /api/version alone would miss the case where the
# Ollama process is up but /api/ps is failing — see issue #83.
# Ollama endpoints expose /api/version; OpenAI-compatible endpoints (vLLM,
# llama-server, external) expose /models. Using /api/version against an
# OpenAI-compatible endpoint yields a 404 and noisy log output.
all_endpoints = list(config.endpoints)
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
all_endpoints += llama_eps_extra
probe_results = await asyncio.gather(
*(_endpoint_health(ep) for ep in all_endpoints),
)
tasks = []
for ep in all_endpoints:
if is_openai_compatible(ep):
tasks.append(fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True))
else:
tasks.append(fetch.endpoint_details(ep, "/api/version", "version", skip_error_cache=True))
health_summary = dict(zip(all_endpoints, probe_results))
overall_ok = all(entry.get("status") == "ok" for entry in probe_results)
results = await asyncio.gather(*tasks, return_exceptions=True)
health_summary = {}
overall_ok = True
for ep, result in zip(all_endpoints, results):
if isinstance(result, Exception):
# Endpoint did not respond / returned an error
health_summary[ep] = {"status": "error", "detail": str(result)}
overall_ok = False
else:
# Successful response report the reported version (Ollama) or
# indicate the endpoint is reachable (OpenAI-compatible).
if is_openai_compatible(ep):
health_summary[ep] = {"status": "ok"}
else:
health_summary[ep] = {"status": "ok", "version": result}
response_payload = {
"status": "ok" if overall_ok else "error",
@ -4240,16 +4028,6 @@ async def startup_event() -> None:
@app.on_event("shutdown")
async def shutdown_event() -> None:
await close_all_sse_queues()
# Stop background tasks first so they stop touching the DB before we close it.
for t in (token_worker_task, flush_task):
if t is not None:
t.cancel()
try:
await t
except (asyncio.CancelledError, Exception):
pass
await flush_remaining_buffers()
await app_state["session"].close()
@ -4269,11 +4047,7 @@ async def shutdown_event() -> None:
except Exception as e:
print(f"[shutdown] Error closing httpx client {ep}: {e}")
# Close the aiosqlite connection last — its worker thread is non-daemon
# and would otherwise keep the interpreter alive after lifespan completes.
if db is not None:
try:
await db.close()
print("[shutdown] Closed token DB connection.")
except Exception as e:
print(f"[shutdown] Error closing DB: {e}")
if token_worker_task is not None:
token_worker_task.cancel()
if flush_task is not None:
flush_task.cancel()

View file

@ -121,45 +121,6 @@
.ps-subrow + .ps-subrow {
margin-top: 2px;
}
#ps-table .affinity-col,
#ps-table .affinity-cell {
display: none;
}
#ps-table.affinity-on .affinity-col,
#ps-table.affinity-on .affinity-cell {
display: table-cell;
width: 90px;
text-align: center;
padding-left: 6px;
padding-right: 6px;
}
#ps-table.affinity-on .affinity-dots {
max-width: 78px;
}
.affinity-dots {
display: inline-flex;
flex-wrap: wrap;
gap: 3px;
align-items: center;
line-height: 1;
}
.affinity-dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: #2e7d32;
display: inline-block;
transition: opacity 1s linear;
}
.affinity-overflow {
font-size: 10px;
color: #555;
margin-left: 2px;
}
.affinity-empty {
color: #bbb;
font-size: 11px;
}
#ps-table {
width: max-content;
min-width: 100%;
@ -170,13 +131,13 @@
max-width: 300px;
white-space: nowrap;
}
/* Optimize narrow columns (Params / Quant / Ctx) */
/* Optimize narrow columns */
#ps-table th:nth-child(3),
#ps-table td:nth-child(3),
#ps-table th:nth-child(4),
#ps-table td:nth-child(4),
#ps-table th:nth-child(5),
#ps-table td:nth-child(5),
#ps-table th:nth-child(6),
#ps-table td:nth-child(6) {
#ps-table td:nth-child(5) {
width: 80px;
text-align: center;
}
@ -192,10 +153,6 @@
color: #8b0000;
font-weight: bold;
}
.status-error[title] {
cursor: help;
text-decoration: underline dotted;
}
.copy-link,
.delete-link,
.show-link,
@ -438,7 +395,6 @@
<tr>
<th class="model-col">Model</th>
<th>Endpoint</th>
<th class="affinity-col" title="Live conversation-affinity pins (KV-cache warm). One dot per pinned conversation; opacity fades toward TTL expiry.">Affinity</th>
<th>Params</th>
<th>Quant</th>
<th>Ctx</th>
@ -450,7 +406,7 @@
</thead>
<tbody id="ps-body">
<tr>
<td colspan="10" class="loading">Loading…</td>
<td colspan="6" class="loading">Loading…</td>
</tr>
</tbody>
</table>
@ -740,16 +696,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return await resp.json();
}
/**
 * Escape a value for safe interpolation into HTML text or attribute values.
 * `null` and `undefined` become the empty string; everything else is
 * stringified first and then has & < > " ' replaced by character entities.
 */
function escapeHtml(value) {
  if (value === null || value === undefined) return "";
  const entities = {
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    '"': "&quot;",
    "'": "&#39;",
  };
  // Single pass: each special character is mapped exactly once, so the
  // ampersands introduced by the entities themselves are never re-escaped.
  return String(value).replace(/[&<>"']/g, (ch) => entities[ch]);
}
// Flip the "dark-mode" class on the document's root element.
function toggleDarkMode() {
document.documentElement.classList.toggle("dark-mode");
}
@ -766,24 +712,40 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
// Build HTML for both endpoints and llama_server_endpoints
let html = "";
const renderRow = (e) => {
const statusClass =
e.status === "ok" ? "status-ok" : "status-error";
const version = e.version || "N/A";
const titleAttr = e.detail
? ` title="${escapeHtml(e.detail)}"`
: "";
return `
// Add Ollama endpoints
html += data.endpoints
.map((e) => {
const statusClass =
e.status === "ok"
? "status-ok"
: "status-error";
const version = e.version || "N/A";
return `
<tr>
<td class="endpoint">${escapeHtml(e.url)}</td>
<td class="status ${statusClass}"${titleAttr}>${escapeHtml(e.status)}</td>
<td class="version">${escapeHtml(version)}</td>
<td class="endpoint">${e.url}</td>
<td class="status ${statusClass}">${e.status}</td>
<td class="version">${version}</td>
</tr>`;
};
html += data.endpoints.map(renderRow).join("");
})
.join("");
// Add llama-server endpoints
if (data.llama_server_endpoints && data.llama_server_endpoints.length > 0) {
html += data.llama_server_endpoints.map(renderRow).join("");
html += data.llama_server_endpoints
.map((e) => {
const statusClass =
e.status === "ok"
? "status-ok"
: "status-error";
const version = e.version || "N/A";
return `
<tr>
<td class="endpoint">${e.url}</td>
<td class="status ${statusClass}">${e.status}</td>
<td class="version">${version}</td>
</tr>`;
})
.join("");
}
body.innerHTML = html;
@ -970,14 +932,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return items.map((item) => `<div class="ps-subrow">${item || ""}</div>`).join("");
};
const escapeAttr = (s) => String(s).replace(/&/g, "&amp;").replace(/"/g, "&quot;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
const renderAffinitySlots = (endpoints, modelName) => {
if (!endpoints.length) return "";
return endpoints
.map((ep) => `<div class="ps-subrow"><span class="affinity-dots" data-endpoint="${escapeAttr(ep)}" data-model="${escapeAttr(modelName)}"></span></div>`)
.join("");
};
body.innerHTML = Array.from(grouped.entries())
.map(([modelName, modelInstances]) => {
const existingRow = psRows.get(modelName);
@ -1001,7 +955,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return `<tr data-model="${modelName}" data-endpoints="${endpointsData}">
<td class="model"><span style="color:${getColor(modelName)}">${modelName}</span> <a href="#" class="stats-link" data-model="${modelName}">stats</a></td>
<td>${renderInstanceList(endpoints)}</td>
<td class="affinity-cell">${renderAffinitySlots(endpoints, modelName)}</td>
<td>${params}</td>
<td>${quant}</td>
<td>${ctx}</td>
@ -1019,83 +972,11 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
const model = row.dataset.model;
if (model) psRows.set(model, row);
});
renderAffinityDots();
} catch (e) {
console.error(e);
}
}
/* ---------- Conversation-affinity dots ---------- */
const AFFINITY_MAX_DOTS = 12;
let affinityIndex = new Map(); // `${endpoint}|${model}` -> array of {expiresAt}
let affinityTtl = 300;
let affinityEnabled = false;
/**
 * Fetch /api/affinity_stats and rebuild the client-side pin index.
 *
 * The server reports each pin's remaining lifetime; it is converted here to
 * an absolute expiry (seconds since epoch on the browser clock) so the dots
 * can fade between polls without another request. On any failure the
 * affinity column is switched off and the index is emptied.
 */
async function loadAffinity() {
  try {
    const stats = await fetchJSON("/api/affinity_stats");
    affinityEnabled = Boolean(stats.enabled);
    affinityTtl = Number(stats.ttl) || 300;

    const nowSec = Date.now() / 1000;
    const pinsByKey = new Map();
    for (const entry of stats.entries || []) {
      const key = `${entry.endpoint}|${entry.model}`;
      let bucket = pinsByKey.get(key);
      if (!bucket) {
        bucket = [];
        pinsByKey.set(key, bucket);
      }
      bucket.push({ expiresAt: nowSec + Number(entry.remaining) });
    }

    affinityIndex = pinsByKey;
    applyAffinityColumnVisibility();
    renderAffinityDots();
  } catch (err) {
    // Endpoint may 404 on older deployments — silently degrade.
    affinityEnabled = false;
    affinityIndex = new Map();
    applyAffinityColumnVisibility();
    renderAffinityDots();
  }
}
/**
 * Show or hide the affinity column by toggling the "affinity-on" class on
 * the #ps-table element according to `affinityEnabled`. Does nothing when
 * the table is not in the document.
 */
function applyAffinityColumnVisibility() {
  const psTable = document.getElementById("ps-table");
  if (psTable) {
    psTable.classList.toggle("affinity-on", affinityEnabled);
  }
}
/**
 * Repaint every `.affinity-dots` placeholder from `affinityIndex`.
 *
 * Each placeholder is keyed by its data-endpoint / data-model attributes.
 * Pins that have expired on the browser clock are pruned from the index,
 * at most AFFINITY_MAX_DOTS dots are drawn (freshest first), and any
 * remainder is summarised as a "+N" overflow label. Dot opacity tracks the
 * fraction of the TTL that is left, clamped to the range 0.15–1.
 */
function renderAffinityDots() {
  const slots = document.querySelectorAll(".affinity-dots");
  if (slots.length === 0) return;

  const nowSec = Date.now() / 1000;
  for (const slot of slots) {
    const key = `${slot.dataset.endpoint}|${slot.dataset.model}`;
    const stored = affinityIndex.get(key) || [];
    const live = stored.filter((pin) => pin.expiresAt > nowSec);

    // Write the pruned list back only when something actually expired.
    if (live.length !== stored.length) {
      if (live.length > 0) {
        affinityIndex.set(key, live);
      } else {
        affinityIndex.delete(key);
      }
    }

    if (live.length === 0) {
      slot.innerHTML = affinityEnabled
        ? `<span class="affinity-empty"></span>`
        : "";
      continue;
    }

    // Sort freshest first so visible dots are the most "recent".
    live.sort((a, b) => b.expiresAt - a.expiresAt);
    const shown = live.slice(0, AFFINITY_MAX_DOTS);
    const hiddenCount = live.length - shown.length;

    let markup = "";
    for (const pin of shown) {
      const remaining = Math.max(0, pin.expiresAt - nowSec);
      const opacity = Math.max(0.15, Math.min(1, remaining / affinityTtl));
      const secs = Math.round(remaining);
      markup += `<span class="affinity-dot" style="opacity:${opacity.toFixed(2)}" title="pin expires in ${secs}s"></span>`;
    }
    if (hiddenCount > 0) {
      markup += `<span class="affinity-overflow">+${hiddenCount}</span>`;
    }
    slot.innerHTML = markup;
  }
}
/* ---------- Usage Chart (stacked-percentage) ---------- */
function getColor(seed) {
const h = Math.abs(hashString(seed) % 360);
@ -1292,13 +1173,10 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
loadEndpoints();
loadTags();
loadPS();
loadAffinity();
loadUsage();
initHeaderChart();
setInterval(tickTpsChart, 1000);
setInterval(loadPS, 60_000);
setInterval(loadAffinity, 15_000);
setInterval(renderAffinityDots, 2_000);
setInterval(loadEndpoints, 300_000);
/* show logic */

View file

@ -1,13 +0,0 @@
endpoints:
- http://192.168.0.51:12434
llama_server_endpoints:
- http://192.168.0.51:12434/v1
max_concurrent_connections: 2
api_keys:
"http://192.168.0.51:12434": "ollama"
"http://192.168.0.51:12434/v1": "llama"
cache_enabled: false

View file

@ -1,233 +0,0 @@
"""
Test configuration for nomyo-router.
Run from project root:
pytest test/ -v
pytest test/ -m "not integration" # skip real-server tests
pytest test/ -m integration -v # only real-server tests
Environment variables:
NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.51:12434)
NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.51:12434/v1)
NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset)
NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset)
"""
import asyncio
import os
import ssl
import sys
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import aiohttp
import httpx
import pytest
_TEST_DIR = Path(__file__).parent
# Must be set before importing router so module-level Config.from_yaml + Config field
# defaults pick these up. db_path is intentionally absent from config_test.yaml so the
# env-var default wins — keeps tests portable across CI runners (Linux/macOS/Windows).
os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml"))
os.environ.setdefault(
"NOMYO_ROUTER_DB_PATH",
str(Path(tempfile.gettempdir()) / "nomyo_router_test_tokens.db"),
)
sys.path.insert(0, str(_TEST_DIR.parent))
import router # noqa: E402
TEST_OLLAMA = os.getenv("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434")
TEST_LLAMA = os.getenv("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1")
def pytest_configure(config):
    """Register the custom ``integration`` marker with pytest.

    Registering the marker here keeps ``-m integration`` /
    ``-m "not integration"`` selections free of unknown-marker warnings.
    The address in the description is the suite's default backend
    (192.168.0.51:12434); it previously named a different host than the
    defaults used for ``TEST_OLLAMA`` / ``TEST_LLAMA``.

    Args:
        config: The pytest config object passed to this hook.
    """
    config.addinivalue_line(
        "markers",
        "integration: tests that require a real backend at 192.168.0.51:12434",
    )
# ── Config mocks ─────────────────────────────────────────────────────────────
@pytest.fixture
def mock_config():
    """Mocked router config with one Ollama and one llama-server endpoint.

    Both endpoints come from TEST_OLLAMA / TEST_LLAMA and each has an API
    key; router authentication and caching are switched off.
    """
    return MagicMock(
        endpoints=[TEST_OLLAMA],
        llama_server_endpoints=[TEST_LLAMA],
        api_keys={TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"},
        max_concurrent_connections=2,
        router_api_key=None,
        cache_enabled=False,
    )
@pytest.fixture
def mock_config_no_llama():
    """Mocked router config exposing the Ollama endpoint only.

    ``llama_server_endpoints`` is an empty list; router authentication and
    caching are switched off.
    """
    return MagicMock(
        endpoints=[TEST_OLLAMA],
        llama_server_endpoints=[],
        api_keys={TEST_OLLAMA: "ollama"},
        max_concurrent_connections=2,
        router_api_key=None,
        cache_enabled=False,
    )
@pytest.fixture
def mock_config_with_key():
    """Mocked router config with ``router_api_key`` set (enables auth middleware).

    Only the Ollama endpoint is configured and no backend API keys are
    provided; caching is switched off.
    """
    return MagicMock(
        endpoints=[TEST_OLLAMA],
        llama_server_endpoints=[],
        api_keys={},
        max_concurrent_connections=2,
        router_api_key="test-secret-key",
        cache_enabled=False,
    )
# ── aiohttp session (used by fetch tests + choose_endpoint tests) ─────────────
@pytest.fixture
async def aio_session():
    """Install a real aiohttp session in ``router.app_state`` for one test.

    The session is real so that aioresponses can intercept its requests.
    All router-level model/error caches and in-flight bookkeeping are
    emptied first to prevent state bleeding between tests. On teardown the
    session is closed and the ``app_state`` slot is reset to ``None``.
    """
    connector = aiohttp.TCPConnector(ssl=ssl.create_default_context())
    session = aiohttp.ClientSession(connector=connector)
    router.app_state["session"] = session

    # Clear caches to prevent test bleed
    for cache in (
        router._models_cache,
        router._loaded_models_cache,
        router._available_error_cache,
        router._loaded_error_cache,
        router._inflight_available_models,
        router._inflight_loaded_models,
        router._bg_refresh_available,
        router._bg_refresh_loaded,
    ):
        cache.clear()

    yield session

    await session.close()
    router.app_state["session"] = None
# ── Validation-only HTTP client (no real backend needed) ──────────────────────
@pytest.fixture
async def client(mock_config, tmp_path):
    """In-process httpx client for validation tests — no real backend calls made.

    For the duration of the test the router gets a fresh aiohttp session, a
    throwaway token database under ``tmp_path`` and ``mock_config`` patched
    in as ``router.config``. The previous session and database are put back
    on teardown.
    """
    from db import TokenDatabase

    connector = aiohttp.TCPConnector(ssl=ssl.create_default_context())
    session = aiohttp.ClientSession(connector=connector)

    token_db = TokenDatabase(str(tmp_path / "test.db"))
    await token_db.init_db()

    # Remember what was installed so teardown can restore it.
    previous_session = router.app_state.get("session")
    previous_db = router.db
    router.app_state["session"] = session
    router.db = token_db

    with patch.object(router, "config", mock_config):
        asgi = httpx.ASGITransport(app=router.app)
        async with httpx.AsyncClient(
            transport=asgi, base_url="http://test", timeout=10.0
        ) as http_client:
            yield http_client

    await session.close()
    router.app_state["session"] = previous_session
    router.db = previous_db
@pytest.fixture
async def client_auth(mock_config_with_key, tmp_path):
    """In-process httpx client with router_api_key configured (auth middleware tests).

    Same wiring as a validation client: a fresh aiohttp session, a throwaway
    token database under ``tmp_path`` and ``mock_config_with_key`` patched in
    as ``router.config``. The previous session and database are put back on
    teardown.
    """
    from db import TokenDatabase

    connector = aiohttp.TCPConnector(ssl=ssl.create_default_context())
    session = aiohttp.ClientSession(connector=connector)

    token_db = TokenDatabase(str(tmp_path / "test_auth.db"))
    await token_db.init_db()

    # Remember what was installed so teardown can restore it.
    previous_session = router.app_state.get("session")
    previous_db = router.db
    router.app_state["session"] = session
    router.db = token_db

    with patch.object(router, "config", mock_config_with_key):
        asgi = httpx.ASGITransport(app=router.app)
        async with httpx.AsyncClient(
            transport=asgi, base_url="http://test", timeout=10.0
        ) as http_client:
            yield http_client

    await session.close()
    router.app_state["session"] = previous_session
    router.db = previous_db
# ── Integration client (full startup with real backend) ──────────────────────
@pytest.fixture(scope="module")
async def integration_client():
    """Module-scoped httpx client backed by a fully started router app.

    Runs ``router.startup_event()`` before the first test of the module and
    ``router.shutdown_event()`` after the last one. Requests go through
    ``httpx.ASGITransport`` with a 60-second timeout.
    """
    await router.startup_event()

    asgi = httpx.ASGITransport(app=router.app)
    generous_timeout = httpx.Timeout(60.0)
    async with httpx.AsyncClient(
        transport=asgi,
        base_url="http://test",
        timeout=generous_timeout,
    ) as http_client:
        yield http_client

    await router.shutdown_event()
# ── Model discovery fixtures ──────────────────────────────────────────────────
@pytest.fixture(scope="module")
async def chat_model(integration_client):
    """Pick the chat/generation model the integration tests should use.

    Resolution order:
      1. ``NOMYO_TEST_MODEL_CHAT`` when set and non-empty.
      2. The first model from ``/api/tags`` whose name contains a small-size
         tag (0.5b, 1b, 3b, 1.5b, 2b), to keep tests fast.
      3. The first model listed by ``/api/tags``.

    The test is skipped when ``/api/tags`` does not answer 200 or lists no
    models.
    """
    pinned = os.getenv("NOMYO_TEST_MODEL_CHAT")
    if pinned:
        return pinned

    response = await integration_client.get("/api/tags")
    if response.status_code != 200:
        pytest.skip("Cannot reach test server")

    available = response.json().get("models", [])

    # Prefer small models for faster tests
    small_tags = ("0.5b", "1b", "3b", "1.5b", "2b")
    for entry in available:
        candidate = entry.get("name", "")
        lowered = candidate.lower()
        if any(tag in lowered for tag in small_tags):
            return candidate

    if available:
        return available[0]["name"]
    pytest.skip("No chat models available on test server")
@pytest.fixture(scope="module")
async def embed_model(integration_client):
    """Pick the embedding model the integration tests should use.

    ``NOMYO_TEST_EMBED_MODEL`` wins when set and non-empty. Otherwise the
    first model from ``/api/tags`` whose name contains one of the known
    embedding hints (embed, nomic, minilm, bge, e5) is returned. The test is
    skipped when ``/api/tags`` does not answer 200 or no name matches.
    """
    pinned = os.getenv("NOMYO_TEST_EMBED_MODEL")
    if pinned:
        return pinned

    response = await integration_client.get("/api/tags")
    if response.status_code != 200:
        pytest.skip("Cannot reach test server")

    embedding_hints = ("embed", "nomic", "minilm", "bge", "e5")
    for entry in response.json().get("models", []):
        candidate = entry.get("name", "")
        lowered = candidate.lower()
        if any(hint in lowered for hint in embedding_hints):
            return candidate

    pytest.skip("No embedding model available on test server")

View file

@ -1,7 +0,0 @@
[pytest]
asyncio_mode = auto
markers =
integration: tests that require a real backend at 192.168.0.51:12434
testpaths = .
filterwarnings =
ignore::pytest.PytestUnhandledThreadExceptionWarning

View file

@ -1,4 +0,0 @@
pytest>=8.0
pytest-asyncio>=0.24
pytest-cov>=5.0
aioresponses>=0.7

View file

@ -1,60 +0,0 @@
# Testing nomyo-router
## Setup
Install test dependencies (from the project root):
```bash
pip install -r test/requirements_test.txt
```
## Running tests
All commands run from the `test/` directory:
```bash
cd test
```
**All non-integration tests** (no backend required):
```bash
pytest -m "not integration" -v
```
**Integration tests only** (requires backend at `192.168.0.51:12434`):
```bash
pytest -m integration -v
```
**Everything:**
```bash
pytest -v
```
## Test structure
| File | What it covers | Backend needed |
|---|---|---|
| `test_unit_helpers.py` | Pure helper functions (`_mask_secrets`, `_is_fresh`, `ep2base`, etc.) | No |
| `test_unit_transforms.py` | Message transform functions (tool calls, image stripping, etc.) | No |
| `test_unit_context.py` | Context window trimming logic | No |
| `test_fetch.py` | `fetch.available_models` / `fetch.loaded_models` with mocked HTTP | No |
| `test_choose_endpoint.py` | `choose_endpoint` routing logic with mocked fetch layer | No |
| `test_api_validation.py` | HTTP 400/401/403 validation and auth middleware (in-process app) | No |
| `test_api_integration.py` | Full request/response against a real Ollama/llama-server backend | **Yes** |
## Integration test backend
Integration tests start the router in-process via `startup_event()` and route traffic
through `httpx.ASGITransport` — no separately running router instance is needed.
They do require a reachable Ollama or llama-server backend. Override the defaults via
environment variables:
```bash
export NOMYO_TEST_OLLAMA=http://192.168.0.51:12434
export NOMYO_TEST_EMBED_MODEL=nomic-embed-text # optional, auto-discovered otherwise
export NOMYO_TEST_MODEL_CHAT=llama3.2 # optional, auto-discovered otherwise
```
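If the backend is unreachable, tests that depend on automatic model discovery (the `chat_model` and `embed_model` fixtures) are skipped. Integration tests that do not use a discovered model, such as the discovery-route checks, are not skipped and may fail instead.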

View file

@ -1,304 +0,0 @@
"""
Integration tests against the real backend at 192.168.0.51:12434.
Run with:
pytest test/test_api_integration.py -v -m integration
All tests in this file are marked @pytest.mark.integration.
They require the test server to be reachable and to have at least one
chat model and one embedding model available.
Env vars to pin specific models:
NOMYO_TEST_MODEL_CHAT e.g. qwen2.5:1.5b
NOMYO_TEST_EMBED_MODEL e.g. nomic-embed-text:latest
"""
import json
import pytest
pytestmark = pytest.mark.integration
# ── Health / discovery routes ─────────────────────────────────────────────────
class TestDiscoveryRoutes:
    """Smoke tests for the router's read-only discovery and monitoring routes."""

    async def test_version(self, integration_client):
        """/api/version answers 200 with a string ``version``."""
        response = await integration_client.get("/api/version")
        assert response.status_code == 200
        body = response.json()
        assert "version" in body
        assert isinstance(body["version"], str)

    async def test_tags_returns_models(self, integration_client):
        """/api/tags answers 200 with a non-empty ``models`` list."""
        response = await integration_client.get("/api/tags")
        assert response.status_code == 200
        body = response.json()
        assert "models" in body
        assert isinstance(body["models"], list)
        assert len(body["models"]) > 0

    async def test_ps_returns_list(self, integration_client):
        """/api/ps answers 200 with a ``models`` list (possibly empty)."""
        response = await integration_client.get("/api/ps")
        assert response.status_code == 200
        body = response.json()
        assert "models" in body
        assert isinstance(body["models"], list)

    async def test_v1_models_returns_data(self, integration_client):
        """/v1/models answers 200 with a ``data`` list."""
        response = await integration_client.get("/v1/models")
        assert response.status_code == 200
        body = response.json()
        assert "data" in body
        assert isinstance(body["data"], list)

    async def test_usage_returns_counts(self, integration_client):
        """/api/usage exposes both connection and token usage counters."""
        response = await integration_client.get("/api/usage")
        assert response.status_code == 200
        body = response.json()
        assert "usage_counts" in body
        assert "token_usage_counts" in body

    async def test_config_returns_endpoints(self, integration_client):
        """/api/config answers 200 and lists ``endpoints``."""
        response = await integration_client.get("/api/config")
        assert response.status_code == 200
        body = response.json()
        assert "endpoints" in body

    async def test_hostname(self, integration_client):
        """/api/hostname answers 200 with a ``hostname`` key."""
        response = await integration_client.get("/api/hostname")
        assert response.status_code == 200
        assert "hostname" in response.json()

    async def test_health(self, integration_client):
        """/health answers 200 or 503 and always reports status plus endpoints."""
        response = await integration_client.get("/health")
        assert response.status_code in (200, 503)
        body = response.json()
        assert body["status"] in ("ok", "error")
        assert "endpoints" in body

    async def test_cache_stats(self, integration_client):
        """/api/cache/stats answers 200 and reports whether caching is enabled."""
        response = await integration_client.get("/api/cache/stats")
        assert response.status_code == 200
        body = response.json()
        assert "enabled" in body
# ── /api/chat ─────────────────────────────────────────────────────────────────
class TestApiChat:
    """Integration tests for the Ollama-style ``/api/chat`` route."""

    async def test_non_streaming(self, integration_client, chat_model):
        """A non-streaming chat returns a single JSON object with message content."""
        payload = {
            "model": chat_model,
            "stream": False,
            "messages": [{"role": "user", "content": "Reply with exactly: OK"}],
            "options": {"num_predict": 10},
        }
        response = await integration_client.post("/api/chat", json=payload)
        assert response.status_code == 200
        body = response.json()
        assert "message" in body
        assert "content" in body["message"]

    async def test_streaming_ndjson(self, integration_client, chat_model):
        """A streaming chat returns newline-delimited JSON chunks naming the model."""
        payload = {
            "model": chat_model,
            "stream": True,
            "messages": [{"role": "user", "content": "Say hi"}],
            "options": {"num_predict": 5},
        }
        response = await integration_client.post("/api/chat", json=payload)
        assert response.status_code == 200
        chunks = [
            raw for raw in response.text.strip().split("\n") if raw.strip()
        ]
        assert len(chunks) >= 1
        for raw in chunks:
            parsed = json.loads(raw)
            assert "model" in parsed

    async def test_non_streaming_has_token_counts(self, integration_client, chat_model):
        """The final non-streaming response is marked done."""
        payload = {
            "model": chat_model,
            "stream": False,
            "messages": [{"role": "user", "content": "Count to 3"}],
            "options": {"num_predict": 20},
        }
        response = await integration_client.post("/api/chat", json=payload)
        assert response.status_code == 200
        body = response.json()
        assert body.get("done") is True
        # Token counts should be present in the final chunk.
        # NOTE: with a default of 0 this also passes when the key is absent.
        assert body.get("prompt_eval_count", 0) >= 0

    async def test_system_message_honoured(self, integration_client, chat_model):
        """A request with a system message yields non-empty string content."""
        conversation = [
            {"role": "system", "content": "You are a helpful assistant. Always reply with exactly: PONG"},
            {"role": "user", "content": "PING"},
        ]
        payload = {
            "model": chat_model,
            "stream": False,
            "messages": conversation,
            "options": {"num_predict": 10},
        }
        response = await integration_client.post("/api/chat", json=payload)
        assert response.status_code == 200
        reply = response.json()["message"]["content"]
        assert isinstance(reply, str)
        assert len(reply) > 0
# ── /api/generate ─────────────────────────────────────────────────────────────
class TestApiGenerate:
    """Integration tests for the Ollama-style ``/api/generate`` route."""

    async def test_non_streaming(self, integration_client, chat_model):
        """A non-streaming generate call returns JSON with a ``response`` field."""
        payload = {
            "model": chat_model,
            "prompt": "Complete: The sky is",
            "stream": False,
            "options": {"num_predict": 5},
        }
        response = await integration_client.post("/api/generate", json=payload)
        assert response.status_code == 200
        body = response.json()
        assert "response" in body

    async def test_streaming(self, integration_client, chat_model):
        """A streaming generate call yields at least one non-blank line."""
        payload = {
            "model": chat_model,
            "prompt": "One plus one equals",
            "stream": True,
            "options": {"num_predict": 5},
        }
        response = await integration_client.post("/api/generate", json=payload)
        assert response.status_code == 200
        chunks = [
            raw for raw in response.text.strip().split("\n") if raw.strip()
        ]
        assert len(chunks) >= 1
# ── /api/embed ────────────────────────────────────────────────────────────────
class TestApiEmbed:
    """Integration tests for the Ollama-style ``/api/embed`` route."""

    async def test_embed_single_string(self, integration_client, embed_model):
        """A single string input yields exactly one non-empty embedding vector."""
        payload = {"model": embed_model, "input": "The quick brown fox"}
        response = await integration_client.post("/api/embed", json=payload)
        assert response.status_code == 200
        body = response.json()
        assert "embeddings" in body
        vectors = body["embeddings"]
        assert isinstance(vectors, list)
        assert len(vectors) == 1
        assert len(vectors[0]) > 0

    async def test_embed_multiple_inputs(self, integration_client, embed_model):
        """A list of two inputs yields two embedding vectors."""
        payload = {"model": embed_model, "input": ["sentence one", "sentence two"]}
        response = await integration_client.post("/api/embed", json=payload)
        assert response.status_code == 200
        body = response.json()
        assert "embeddings" in body
        assert len(body["embeddings"]) == 2
# ── /v1/chat/completions ──────────────────────────────────────────────────────
class TestOpenAIChatCompletions:
    """Integration checks for the OpenAI-compatible ``/v1/chat/completions`` route."""

    @staticmethod
    def _without_latest(model_name):
        """Return ``model_name`` with every ``:latest`` occurrence removed."""
        # str.replace leaves the string untouched when the marker is absent.
        return model_name.replace(":latest", "")

    async def test_non_streaming(self, integration_client, chat_model):
        """A non-streaming completion returns at least one choice with a message."""
        payload = {
            "model": self._without_latest(chat_model),
            "messages": [{"role": "user", "content": "Reply OK"}],
            "max_tokens": 10,
            "stream": False,
        }
        reply = await integration_client.post("/v1/chat/completions", json=payload)
        assert reply.status_code == 200
        body = reply.json()
        assert "choices" in body
        assert len(body["choices"]) > 0
        assert "message" in body["choices"][0]

    async def test_streaming_sse(self, integration_client, chat_model):
        """A streaming completion answers in server-sent-event form."""
        payload = {
            "model": self._without_latest(chat_model),
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 5,
            "stream": True,
        }
        reply = await integration_client.post("/v1/chat/completions", json=payload)
        assert reply.status_code == 200
        # Either an SSE data frame or the terminating sentinel must be present.
        assert "data:" in reply.text or "[DONE]" in reply.text

    async def test_non_streaming_has_usage(self, integration_client, chat_model):
        """When a usage object is returned, its prompt token count is non-negative."""
        payload = {
            "model": self._without_latest(chat_model),
            "messages": [{"role": "user", "content": "Say yes"}],
            "max_tokens": 5,
            "stream": False,
        }
        reply = await integration_client.post("/v1/chat/completions", json=payload)
        assert reply.status_code == 200
        body = reply.json()
        # Usage is treated as optional: an absent or empty value is accepted.
        usage_reported = "usage" in body and body["usage"]
        if usage_reported:
            assert body["usage"].get("prompt_tokens", 0) >= 0
# ── /v1/embeddings ────────────────────────────────────────────────────────────
class TestOpenAIEmbeddings:
    """Integration checks for the OpenAI-compatible ``/v1/embeddings`` route."""

    async def test_single_input(self, integration_client, embed_model):
        """A single input string yields a non-empty embedding list."""
        # Removing ":latest" is a no-op when the marker is not present.
        bare_model = embed_model.replace(":latest", "")
        reply = await integration_client.post(
            "/v1/embeddings",
            json={"model": bare_model, "input": "Test sentence"},
        )
        assert reply.status_code == 200
        body = reply.json()
        assert "data" in body
        assert len(body["data"]) > 0
        vector = body["data"][0].get("embedding")
        assert isinstance(vector, list)
        assert len(vector) > 0
# ── Token counts (database-backed) ───────────────────────────────────────────
class TestTokenCounts:
    """Integration check for the ``/api/token_counts`` endpoint."""

    async def test_token_counts_endpoint(self, integration_client):
        """The endpoint answers 200 and reports both a total and a breakdown."""
        reply = await integration_client.get("/api/token_counts")
        assert reply.status_code == 200
        body = reply.json()
        for required_key in ("total_tokens", "breakdown"):
            assert required_key in body
# ── ps_details (extended ps) ─────────────────────────────────────────────────
class TestPsDetails:
    """Integration check for the ``/api/ps_details`` endpoint."""

    async def test_ps_details_returns_models(self, integration_client):
        """The endpoint answers 200 with a ``models`` list."""
        reply = await integration_client.get("/api/ps_details")
        assert reply.status_code == 200
        body = reply.json()
        assert "models" in body
        assert isinstance(body["models"], list)

View file

@ -1,230 +0,0 @@
"""
HTTP-level validation and auth middleware tests.
These tests use an in-process httpx client and never reach a real backend:
all requests are rejected at the validation or auth layer before any
endpoint-selection or upstream HTTP calls occur.
"""
import pytest
class TestChatValidation:
    """Malformed ``/api/chat`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """No ``model`` field: 400, and the error detail mentions the model."""
        payload = {"messages": [{"role": "user", "content": "hello"}]}
        reply = await client.post("/api/chat", json=payload)
        assert reply.status_code == 400
        detail = reply.json()["detail"]
        assert "model" in detail.lower()

    async def test_missing_messages_returns_400(self, client):
        """No ``messages`` field: 400."""
        reply = await client.post("/api/chat", json={"model": "llama3.2"})
        assert reply.status_code == 400

    async def test_invalid_json_returns_400(self, client):
        """A body that is not JSON, sent with a JSON content type: 400."""
        reply = await client.post(
            "/api/chat",
            content=b"not-json",
            headers={"Content-Type": "application/json"},
        )
        assert reply.status_code == 400

    async def test_messages_not_list_returns_400(self, client):
        """``messages`` given as a string instead of a list: 400."""
        payload = {"model": "m", "messages": "not-a-list"}
        reply = await client.post("/api/chat", json=payload)
        assert reply.status_code == 400

    async def test_options_not_dict_returns_400(self, client):
        """``options`` given as a string instead of a mapping: 400."""
        payload = {
            "model": "m",
            "messages": [{"role": "user", "content": "hi"}],
            "options": "bad",
        }
        reply = await client.post("/api/chat", json=payload)
        assert reply.status_code == 400
class TestGenerateValidation:
    """Malformed ``/api/generate`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """No ``model`` field: 400, and the error detail mentions the model."""
        reply = await client.post("/api/generate", json={"prompt": "hello"})
        assert reply.status_code == 400
        detail = reply.json()["detail"]
        assert "model" in detail.lower()

    async def test_missing_prompt_returns_400(self, client):
        """No ``prompt`` field: 400, and the error detail mentions the prompt."""
        reply = await client.post("/api/generate", json={"model": "m"})
        assert reply.status_code == 400
        detail = reply.json()["detail"]
        assert "prompt" in detail.lower()

    async def test_invalid_json_returns_400(self, client):
        """A truncated JSON body: 400."""
        reply = await client.post(
            "/api/generate",
            content=b"{bad-json",
            headers={"Content-Type": "application/json"},
        )
        assert reply.status_code == 400
class TestEmbedValidation:
    """Malformed ``/api/embed`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """No ``model`` field: 400."""
        payload = {"input": "hello"}
        reply = await client.post("/api/embed", json=payload)
        assert reply.status_code == 400

    async def test_missing_input_returns_400(self, client):
        """No ``input`` field: 400."""
        payload = {"model": "nomic-embed-text"}
        reply = await client.post("/api/embed", json=payload)
        assert reply.status_code == 400
class TestEmbeddingsValidation:
    """Malformed ``/api/embeddings`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """No ``model`` field: 400."""
        payload = {"prompt": "hello"}
        reply = await client.post("/api/embeddings", json=payload)
        assert reply.status_code == 400

    async def test_missing_prompt_returns_400(self, client):
        """No ``prompt`` field: 400."""
        payload = {"model": "nomic-embed-text"}
        reply = await client.post("/api/embeddings", json=payload)
        assert reply.status_code == 400
class TestOpenAIChatValidation:
    """Malformed ``/v1/chat/completions`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """No ``model`` field: 400."""
        payload = {"messages": [{"role": "user", "content": "hello"}]}
        reply = await client.post("/v1/chat/completions", json=payload)
        assert reply.status_code == 400

    async def test_missing_messages_returns_400(self, client):
        """No ``messages`` field: 400."""
        payload = {"model": "gpt-4o"}
        reply = await client.post("/v1/chat/completions", json=payload)
        assert reply.status_code == 400

    async def test_invalid_json_returns_400(self, client):
        """A syntactically invalid JSON body: 400."""
        reply = await client.post(
            "/v1/chat/completions",
            content=b"}{",
            headers={"Content-Type": "application/json"},
        )
        assert reply.status_code == 400

    async def test_svg_image_rejected(self, client):
        """An SVG data URL inside an ``image_url`` part: 400 with 'svg' in the detail."""
        content_parts = [
            {"type": "text", "text": "describe"},
            {"type": "image_url", "image_url": {"url": "data:image/svg+xml;base64,abc"}},
        ]
        payload = {
            "model": "vision-model",
            "messages": [{"role": "user", "content": content_parts}],
        }
        reply = await client.post("/v1/chat/completions", json=payload)
        assert reply.status_code == 400
        detail = reply.json()["detail"]
        assert "svg" in detail.lower()
class TestOpenAICompletionsValidation:
    """Malformed ``/v1/completions`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """No ``model`` field: 400."""
        payload = {"prompt": "hello"}
        reply = await client.post("/v1/completions", json=payload)
        assert reply.status_code == 400

    async def test_missing_prompt_returns_400(self, client):
        """No ``prompt`` field: 400."""
        payload = {"model": "m"}
        reply = await client.post("/v1/completions", json=payload)
        assert reply.status_code == 400
class TestRerankValidation:
    """Malformed ``/v1/rerank`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """No ``model`` field: 400."""
        payload = {"query": "search query", "documents": ["doc1"]}
        reply = await client.post("/v1/rerank", json=payload)
        assert reply.status_code == 400

    async def test_missing_query_returns_400(self, client):
        """No ``query`` field: 400."""
        payload = {"model": "reranker", "documents": ["doc1"]}
        reply = await client.post("/v1/rerank", json=payload)
        assert reply.status_code == 400

    async def test_empty_documents_returns_400(self, client):
        """An empty ``documents`` list: 400."""
        payload = {"model": "reranker", "query": "search", "documents": []}
        reply = await client.post("/v1/rerank", json=payload)
        assert reply.status_code == 400
class TestShowValidation:
    """Malformed ``/api/show`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """An empty JSON object (no model named): 400."""
        empty_payload = {}
        reply = await client.post("/api/show", json=empty_payload)
        assert reply.status_code == 400
class TestCopyValidation:
    """Malformed ``/api/copy`` requests must be rejected with HTTP 400."""

    async def test_missing_source_returns_400(self, client):
        """Only ``destination`` supplied: 400."""
        payload = {"destination": "dst"}
        reply = await client.post("/api/copy", json=payload)
        assert reply.status_code == 400

    async def test_missing_destination_returns_400(self, client):
        """Only ``source`` supplied: 400."""
        payload = {"source": "src"}
        reply = await client.post("/api/copy", json=payload)
        assert reply.status_code == 400
class TestDeleteValidation:
    """Malformed ``/api/delete`` requests must be rejected with HTTP 400."""

    async def test_missing_model_returns_400(self, client):
        """A DELETE whose JSON body is an empty object: 400."""
        import json as _json

        # The body is serialised by hand and sent as raw bytes via client.request.
        raw_body = _json.dumps({}).encode()
        reply = await client.request(
            "DELETE",
            "/api/delete",
            content=raw_body,
            headers={"Content-Type": "application/json"},
        )
        assert reply.status_code == 400
class TestAuthMiddleware:
    """Behaviour of the API-key check, exercised through the ``client_auth`` fixture."""

    async def test_no_key_returns_401(self, client_auth):
        """A request without any key is answered 401 with a 'Missing' detail."""
        payload = {"model": "m", "messages": [{"role": "user", "content": "hi"}]}
        reply = await client_auth.post("/api/chat", json=payload)
        assert reply.status_code == 401
        assert "Missing" in reply.json()["detail"]

    async def test_invalid_key_returns_403(self, client_auth):
        """A request with a wrong bearer key is answered 403 with an 'Invalid' detail."""
        payload = {"model": "m", "messages": [{"role": "user", "content": "hi"}]}
        reply = await client_auth.post(
            "/api/chat",
            headers={"Authorization": "Bearer wrong-key"},
            json=payload,
        )
        assert reply.status_code == 403
        assert "Invalid" in reply.json()["detail"]

    async def test_valid_key_passes_middleware(self, client_auth):
        """A request carrying the expected bearer key reaches the endpoint."""
        # /api/usage reads in-memory counters only, so no backend call is needed.
        bearer = {"Authorization": "Bearer test-secret-key"}
        reply = await client_auth.get("/api/usage", headers=bearer)
        assert reply.status_code == 200

    async def test_key_via_query_param(self, client_auth):
        """The key is also accepted through the ``api_key`` query parameter."""
        reply = await client_auth.get("/api/usage?api_key=test-secret-key")
        assert reply.status_code == 200

    async def test_options_bypasses_auth(self, client_auth):
        """OPTIONS requests are not rejected by the auth layer."""
        reply = await client_auth.options("/api/chat")
        assert reply.status_code not in (401, 403)

    async def test_root_path_bypasses_auth(self, client_auth):
        """The root path is not rejected by the auth layer."""
        reply = await client_auth.get("/")
        assert reply.status_code not in (401, 403)

    async def test_favicon_bypasses_auth(self, client_auth):
        """The favicon path is not rejected by the auth layer."""
        reply = await client_auth.get("/favicon.ico")
        # Any status other than 401/403 is fine here (a 404 is possible in tests).
        assert reply.status_code not in (401, 403)

View file

@ -1,333 +0,0 @@
"""Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed)."""
import tempfile
from pathlib import Path
from types import SimpleNamespace
import orjson
import pytest
import cache as cache_mod
from cache import (
LLMCache,
_bm25_weighted_text,
get_llm_cache,
init_llm_cache,
openai_nonstream_to_sse,
)
_CACHE_DB_PATH = str(Path(tempfile.gettempdir()) / "nomyo_test_cache.db")
def _exact_cfg(backend: str = "memory") -> SimpleNamespace:
"""Config for exact-match mode — similarity=1.0 avoids embedding deps."""
return SimpleNamespace(
cache_enabled=True,
cache_backend=backend,
cache_similarity=1.0,
cache_history_weight=0.3,
cache_ttl=300,
cache_db_path=_CACHE_DB_PATH,
cache_redis_url="redis://localhost:6379",
)
# ──────────────────────────────────────────────────────────────────────────────
# Pure helpers
# ──────────────────────────────────────────────────────────────────────────────
class TestBM25WeightedText:
    """Checks for the ``_bm25_weighted_text`` helper."""

    def test_empty_history(self):
        """An empty history produces an empty string."""
        assert _bm25_weighted_text([]) == ""

    def test_history_without_content(self):
        """Messages that carry no ``content`` key produce an empty string."""
        contentless = [{"role": "user"}, {"role": "assistant"}]
        assert _bm25_weighted_text(contentless) == ""

    def test_repeats_high_idf_terms(self):
        """A domain term is kept while a two-letter stopword is not emitted."""
        conversation = [
            {"role": "user", "content": "Tell me about quantum entanglement"},
            {"role": "assistant", "content": "Quantum entanglement is a phenomenon"},
            {"role": "user", "content": "How does entanglement work?"},
        ]
        weighted = _bm25_weighted_text(conversation)
        # The rare/domain term must survive the weighting.
        assert "entanglement" in weighted
        # Short stopwords (<= 2 characters) are expected to be dropped.
        assert "is" not in weighted.split()
# ──────────────────────────────────────────────────────────────────────────────
# openai_nonstream_to_sse
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAINonstreamToSSE:
    """Checks for ``openai_nonstream_to_sse``."""

    def test_valid_chat_completion(self):
        """A valid completion becomes SSE frames terminated by the DONE marker."""
        completion = {
            "id": "x1",
            "created": 123,
            "model": "gpt-4o",
            "choices": [{"message": {"role": "assistant", "content": "hello"}}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
        }
        sse_text = openai_nonstream_to_sse(orjson.dumps(completion), "gpt-4o").decode()
        assert sse_text.startswith("data: ")
        assert sse_text.endswith("data: [DONE]\n\n")

        # The first frame carries the original content and the usage totals.
        first_frame, _, _ = sse_text.partition("\n\n")
        chunk = orjson.loads(first_frame[len("data: "):])
        assert chunk["choices"][0]["delta"]["content"] == "hello"
        assert chunk["usage"]["total_tokens"] == 3

    def test_corrupt_bytes_return_done_only(self):
        """Unparseable input yields nothing but the DONE marker."""
        sse_bytes = openai_nonstream_to_sse(b"not-json", "m")
        assert sse_bytes == b"data: [DONE]\n\n"
# ──────────────────────────────────────────────────────────────────────────────
# LLMCache internal helpers
# ──────────────────────────────────────────────────────────────────────────────
class TestLLMCacheParsing:
    """Checks for ``LLMCache._namespace`` and ``LLMCache._parse_messages``."""

    def test_namespace_is_stable_and_isolated(self):
        """Same inputs give the same namespace; a changed prompt or route does not."""
        cache = LLMCache(_exact_cfg())
        reference = cache._namespace("chat", "m1", "system A")
        repeated = cache._namespace("chat", "m1", "system A")
        assert reference == repeated
        assert cache._namespace("chat", "m1", "system B") != reference
        assert cache._namespace("generate", "m1", "system A") != reference
        assert len(reference) == 16

    def test_parse_messages_flat_strings(self):
        """Plain string contents are split into system, history and last user turn."""
        cache = LLMCache(_exact_cfg())
        conversation = [
            {"role": "system", "content": "be helpful"},
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
            {"role": "user", "content": "what is 2+2?"},
        ]
        system_text, history, last_user = cache._parse_messages(conversation)
        assert system_text == "be helpful"
        assert last_user == "what is 2+2?"
        assert history == [
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        ]

    def test_parse_messages_multimodal_content(self):
        """For list-style content, the text part becomes the last user turn."""
        cache = LLMCache(_exact_cfg())
        user_parts = [
            {"type": "text", "text": "describe"},
            {"type": "image_url", "image_url": {"url": "data:..."}},
        ]
        conversation = [
            {"role": "system", "content": "sys"},
            {"role": "user", "content": user_parts},
        ]
        system_text, _history, last_user = cache._parse_messages(conversation)
        assert system_text == "sys"
        assert last_user == "describe"

    def test_parse_messages_no_user_message(self):
        """Without a user message, the last turn is empty and history is empty."""
        cache = LLMCache(_exact_cfg())
        conversation = [{"role": "system", "content": "sys only"}]
        system_text, history, last_user = cache._parse_messages(conversation)
        assert system_text == "sys only"
        assert last_user == ""
        assert history == []
class TestPersonalTokenExtraction:
    """Checks for ``LLMCache._extract_personal_tokens``."""

    def test_email_extracted(self):
        """An e-mail address in the text is returned as a token."""
        cache = LLMCache(_exact_cfg())
        tokens = cache._extract_personal_tokens("Reach me at alice@example.com please")
        assert "alice@example.com" in tokens

    def test_numeric_id_after_keyword(self):
        """A number following an id keyword is returned as a token."""
        cache = LLMCache(_exact_cfg())
        tokens = cache._extract_personal_tokens("User id: 123456")
        assert "123456" in tokens

    def test_identity_tag_names_extracted(self):
        """Names in identity-tagged text are returned lowercased, minus stopwords."""
        cache = LLMCache(_exact_cfg())
        tagged_text = "[Tags: identity] User's name is Andreas Schwibbe"
        tokens = cache._extract_personal_tokens(tagged_text)
        # Both parts of the name are expected, lowercased.
        for expected in ("andreas", "schwibbe"):
            assert expected in tokens
        # "name" is listed in _IDENTITY_STOPWORDS; "user" must be absent as well.
        assert "name" not in tokens
        assert "user" not in tokens

    def test_empty_system_returns_empty_set(self):
        """Empty input gives an empty frozenset."""
        cache = LLMCache(_exact_cfg())
        assert cache._extract_personal_tokens("") == frozenset()
class TestResponseIsPersonalized:
    """Checks for ``LLMCache._response_is_personalized``."""

    def _resp(self, content: str) -> bytes:
        """Serialise ``content`` as an OpenAI-style single-choice response body."""
        body = {"choices": [{"message": {"content": content}}]}
        return orjson.dumps(body)

    def test_email_in_response_is_personalized(self):
        """An e-mail address in the reply marks it as personalized."""
        cache = LLMCache(_exact_cfg())
        reply = self._resp("contact bob@x.com")
        assert cache._response_is_personalized(reply, "")

    def test_uuid_in_response_is_personalized(self):
        """A UUID in the reply marks it as personalized."""
        cache = LLMCache(_exact_cfg())
        identifier = "550e8400-e29b-41d4-a716-446655440000"
        reply = self._resp(f"id={identifier}")
        assert cache._response_is_personalized(reply, "")

    def test_long_numeric_id_in_response_is_personalized(self):
        """An eight-digit number in the reply marks it as personalized."""
        cache = LLMCache(_exact_cfg())
        reply = self._resp("account 12345678")
        assert cache._response_is_personalized(reply, "")

    def test_identity_token_from_system_echoed_in_response(self):
        """A name from an identity-tagged system prompt echoed in the reply counts."""
        cache = LLMCache(_exact_cfg())
        system_prompt = "[Tags: identity] Andreas works here"
        reply = self._resp("Yes, Andreas is logged in")
        assert cache._response_is_personalized(reply, system_prompt)

    def test_generic_response_not_personalized(self):
        """A generic factual reply is not flagged."""
        cache = LLMCache(_exact_cfg())
        reply = self._resp("The capital of France is Paris.")
        assert not cache._response_is_personalized(reply, "be helpful")

    def test_ollama_message_format_parsed(self):
        """The Ollama-style ``message.content`` layout is inspected too."""
        cache = LLMCache(_exact_cfg())
        reply = orjson.dumps({"message": {"content": "alice@example.com"}})
        assert cache._response_is_personalized(reply, "")

    def test_unparseable_body_with_bytes_is_conservative(self):
        """A body that cannot be parsed is treated as personalized."""
        cache = LLMCache(_exact_cfg())
        # Unparseable input returns True, erring on the side of privacy.
        assert cache._response_is_personalized(b"binary-junk", "")

    def test_empty_response_not_personalized(self):
        """An empty body is not flagged."""
        cache = LLMCache(_exact_cfg())
        assert not cache._response_is_personalized(b"", "anything")
# ──────────────────────────────────────────────────────────────────────────────
# End-to-end exact-match cache with the memory backend
# ──────────────────────────────────────────────────────────────────────────────
@pytest.fixture
async def memcache():
    """Provide an initialised ``LLMCache`` using the in-memory backend.

    The configuration comes from ``_exact_cfg("memory")``, so no external
    dependencies are required.
    """
    cache = LLMCache(_exact_cfg("memory"))
    # init() is awaited before the instance is handed to the test.
    await cache.init()
    return cache
class TestExactMatchCache:
    """End-to-end exact-match behaviour through the ``memcache`` fixture."""

    async def test_miss_then_set_then_hit(self, memcache):
        """A lookup misses before storing and returns the stored bytes afterwards."""
        conversation = [
            {"role": "system", "content": "be helpful"},
            {"role": "user", "content": "what is 2+2?"},
        ]
        stored = orjson.dumps({"choices": [{"message": {"content": "4"}}]})

        assert await memcache.get_chat("chat", "m1", conversation) is None
        await memcache.set_chat("chat", "m1", conversation, stored)
        fetched = await memcache.get_chat("chat", "m1", conversation)
        assert fetched == stored

    async def test_namespace_isolation_by_system(self, memcache):
        """An entry stored under one system prompt is not served under another."""
        stored = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
        question = {"role": "user", "content": "same question"}
        under_a = [{"role": "system", "content": "system A"}, question]
        under_b = [{"role": "system", "content": "system B"}, question]

        await memcache.set_chat("chat", "m", under_a, stored)
        # Same question with a different system prompt lands in another namespace.
        assert await memcache.get_chat("chat", "m", under_b) is None

    async def test_namespace_isolation_by_route(self, memcache):
        """An entry stored for one route is not served for another route."""
        stored = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
        conversation = [{"role": "user", "content": "ping"}]
        await memcache.set_chat("chat", "m", conversation, stored)
        assert await memcache.get_chat("openai_chat", "m", conversation) is None

    async def test_no_user_message_is_noop(self, memcache):
        """Without a user message, both lookup and store quietly do nothing."""
        conversation = [{"role": "system", "content": "sys only"}]
        stored = orjson.dumps({"choices": [{"message": {"content": "x"}}]})
        assert await memcache.get_chat("chat", "m", conversation) is None
        await memcache.set_chat("chat", "m", conversation, stored)
        assert await memcache.get_chat("chat", "m", conversation) is None

    async def test_personalized_response_generic_system_not_stored(self, memcache):
        """A personalized reply under a generic system prompt is never stored."""
        conversation = [
            {"role": "system", "content": "be helpful"},  # generic prompt
            {"role": "user", "content": "give me an email"},
        ]
        # The reply contains an e-mail address. Storing it in the generic
        # namespace could leak it to other users, so it must not be stored.
        leaky = orjson.dumps({"choices": [{"message": {"content": "bob@x.com"}}]})
        await memcache.set_chat("chat", "m", conversation, leaky)
        assert await memcache.get_chat("chat", "m", conversation) is None

    async def test_personalized_response_user_specific_system_stored(self, memcache):
        """A personalized reply under a user-specific system prompt is stored."""
        conversation = [
            {"role": "system", "content": "User id: 998877 prefers concise answers"},
            {"role": "user", "content": "what is my id?"},
        ]
        stored = orjson.dumps({"choices": [{"message": {"content": "Your id is 998877"}}]})
        await memcache.set_chat("chat", "m", conversation, stored)
        # The namespace is specific to this user, so an exact-match hit is fine.
        fetched = await memcache.get_chat("chat", "m", conversation)
        assert fetched == stored

    async def test_generate_convenience_wrappers(self, memcache):
        """``set_generate`` followed by ``get_generate`` round-trips the bytes."""
        stored = orjson.dumps({"response": "blue"})
        prompt = "what color is the sky?"
        await memcache.set_generate("m", prompt, "", stored)
        assert await memcache.get_generate("m", prompt) == stored
class TestStatsAndClear:
    """Checks for ``LLMCache.stats`` and ``LLMCache.clear``."""

    async def test_stats_tracks_hits_and_misses(self, memcache):
        """One miss followed by one hit gives a 0.5 hit rate on the memory backend."""
        conversation = [{"role": "user", "content": "hello"}]
        stored = orjson.dumps({"choices": [{"message": {"content": "hi"}}]})

        await memcache.get_chat("chat", "m", conversation)  # counted as a miss
        await memcache.set_chat("chat", "m", conversation, stored)
        await memcache.get_chat("chat", "m", conversation)  # counted as a hit

        snapshot = memcache.stats()
        assert snapshot["hits"] == 1
        assert snapshot["misses"] == 1
        assert snapshot["hit_rate"] == 0.5
        assert snapshot["semantic"] is False
        assert snapshot["backend"] == "memory"

    async def test_clear_resets_counters_and_storage(self, memcache):
        """``clear`` zeroes the counters and removes stored entries."""
        conversation = [{"role": "user", "content": "hi"}]
        stored = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
        await memcache.set_chat("chat", "m", conversation, stored)
        await memcache.get_chat("chat", "m", conversation)

        await memcache.clear()

        snapshot = memcache.stats()
        assert snapshot["hits"] == 0
        assert snapshot["misses"] == 0
        assert await memcache.get_chat("chat", "m", conversation) is None
# ──────────────────────────────────────────────────────────────────────────────
# Module-level helpers
# ──────────────────────────────────────────────────────────────────────────────
class TestInitLLMCache:
    """Checks for the module-level ``init_llm_cache`` / ``get_llm_cache`` pair."""

    async def test_disabled_returns_none(self):
        """With ``cache_enabled`` switched off, initialisation returns None."""
        cfg = _exact_cfg()
        cfg.cache_enabled = False
        assert await init_llm_cache(cfg) is None

    async def test_enabled_returns_initialized_cache(self):
        """With caching enabled, the returned cache is what ``get_llm_cache`` serves."""
        cfg = _exact_cfg()
        try:
            created = await init_llm_cache(cfg)
            assert created is not None
            assert get_llm_cache() is created
        finally:
            # Reset the module-level singleton so later tests start clean.
            cache_mod._cache = None

View file

@ -1,399 +0,0 @@
"""Tests for choose_endpoint routing logic with mocked fetch calls."""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import router
EP1 = "http://ep1:11434"
EP2 = "http://ep2:11434"
EP3 = "http://ep3:11434"
LLAMA_EP = "http://llama:8080/v1"
def _make_cfg(endpoints, llama_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
cfg = MagicMock()
cfg.endpoints = endpoints
cfg.llama_server_endpoints = llama_eps or []
cfg.api_keys = {}
cfg.max_concurrent_connections = max_conn
cfg.endpoint_config = endpoint_config or {}
cfg.priority_routing = priority_routing
cfg.router_api_key = None
return cfg
@pytest.fixture(autouse=True)
def reset_usage():
    """Empty the router's usage counters and loaded-error cache around each test.

    Both containers are cleared before the test runs and again afterwards, so
    state cannot bleed from one test into the next.
    """

    def _wipe():
        router.usage_counts.clear()
        router._loaded_error_cache.clear()

    _wipe()
    yield
    _wipe()
class TestChooseEndpointBasic:
    """Core selection behaviour of ``router.choose_endpoint`` with mocked fetches."""

    async def test_selects_single_candidate(self):
        """The only endpoint is chosen and the model name is the tracking key."""
        cfg = _make_cfg([EP1])
        advertised = AsyncMock(return_value={"llama3.2:latest"})
        resident = AsyncMock(return_value={"llama3.2:latest"})
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", advertised),
            patch.object(router.fetch, "loaded_models", resident),
        ):
            chosen, tracking = await router.choose_endpoint("llama3.2:latest")
        assert chosen == EP1
        assert tracking == "llama3.2:latest"

    async def test_raises_when_no_endpoint_has_model(self):
        """A RuntimeError is raised when no endpoint advertises the model."""
        cfg = _make_cfg([EP1, EP2])
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(return_value=set())),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
            pytest.raises(RuntimeError, match="advertise the model"),
        ):
            await router.choose_endpoint("unknown-model:latest")

    async def test_prefers_loaded_endpoint(self):
        """Among endpoints advertising the model, the one with it loaded wins."""
        cfg = _make_cfg([EP1, EP2])

        async def advertised(endpoint, *_):
            return {"llama3.2:latest"}

        async def resident(endpoint):
            # Only EP2 reports the model as loaded.
            if endpoint == EP2:
                return {"llama3.2:latest"}
            return set()

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", side_effect=resident),
        ):
            chosen, _ = await router.choose_endpoint("llama3.2:latest")
        assert chosen == EP2

    async def test_falls_back_to_free_slot(self):
        """With the model loaded nowhere, one of the advertising endpoints is used."""
        cfg = _make_cfg([EP1, EP2])

        async def advertised(endpoint, *_):
            return {"llama3.2:latest"}

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
        ):
            chosen, _ = await router.choose_endpoint("llama3.2:latest")
        assert chosen in (EP1, EP2)

    async def test_saturated_picks_least_busy(self):
        """When every endpoint exceeds its limit, the least busy one is chosen."""
        cfg = _make_cfg([EP1, EP2], max_conn=1)

        async def advertised(endpoint, *_):
            return {"llama3.2:latest"}

        # Both endpoints are at or above the limit of 1: EP1 holds 2, EP2 holds 1.
        router.usage_counts[EP1]["llama3.2:latest"] = 2
        router.usage_counts[EP2]["llama3.2:latest"] = 1

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
        ):
            chosen, _ = await router.choose_endpoint("llama3.2:latest")
        # EP2 carries the smaller load.
        assert chosen == EP2

    async def test_excludes_endpoint_with_recent_loaded_error(self):
        """Regression for issue #83: a recently failing endpoint is skipped.

        When /api/ps fails for EP1 while EP1 still advertises the model via
        /api/tags, routing must not fall back to EP1 merely because it has a
        free slot.
        """
        cfg = _make_cfg([EP1, EP2])

        async def advertised(endpoint, *_):
            return {"llama3.2:latest"}

        # EP1's /api/ps probe failed just now. EP2 is healthy but does not have
        # the model loaded. Without the health filter, the free-slot fallback
        # (step 4 in choose_endpoint) would pick EP1.
        router._loaded_error_cache[EP1] = time.time()

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
        ):
            chosen, _ = await router.choose_endpoint("llama3.2:latest")
        assert chosen == EP2

    async def test_stale_loaded_error_does_not_exclude(self):
        """An error older than the 300s window no longer excludes the endpoint."""
        cfg = _make_cfg([EP1])
        # 301 seconds old, i.e. just outside the exclusion window.
        router._loaded_error_cache[EP1] = time.time() - 301
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"m:latest"})),
        ):
            chosen, _ = await router.choose_endpoint("m:latest")
        assert chosen == EP1

    async def test_all_unhealthy_still_routes(self):
        """If every candidate has a fresh error, one is still tried.

        The endpoint may have recovered since the cache entry was written, so
        routing to it is preferred over refusing to route at all.
        """
        cfg = _make_cfg([EP1])
        router._loaded_error_cache[EP1] = time.time()
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
        ):
            chosen, _ = await router.choose_endpoint("m:latest")
        assert chosen == EP1

    async def test_reserve_increments_usage(self):
        """Choosing with ``reserve=True`` bumps the usage counter to 1."""
        cfg = _make_cfg([EP1])
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(return_value={"model:latest"})),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"model:latest"})),
        ):
            chosen, tracking = await router.choose_endpoint("model:latest", reserve=True)
        assert router.usage_counts[chosen][tracking] == 1
class TestChooseEndpointModelNaming:
    """How ``choose_endpoint`` matches model names with and without ``:latest``."""

    async def test_strips_latest_for_openai_endpoints(self):
        """A ``:latest`` request matches a llama-server model listed without the tag."""
        cfg = _make_cfg(endpoints=[], llama_eps=[LLAMA_EP])

        async def advertised(endpoint, *_):
            # The llama-server endpoint advertises the bare name, without :latest.
            return {"gpt-4o"}

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"gpt-4o"})),
        ):
            chosen, _ = await router.choose_endpoint("gpt-4o:latest")
        assert chosen == LLAMA_EP

    async def test_adds_latest_for_ollama_when_bare_name(self):
        """A bare model name matches an endpoint advertising the ``:latest`` form."""
        cfg = _make_cfg([EP1])

        async def advertised(endpoint, *_):
            return {"llama3.2:latest"}

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
        ):
            chosen, _ = await router.choose_endpoint("llama3.2")
        assert chosen == EP1
class TestChooseEndpointLoadBalancing:
    """Load-distribution behaviour of ``choose_endpoint``."""

    async def test_random_selection_among_idle(self):
        """Repeated picks among idle endpoints do not always land on the same one."""
        cfg = _make_cfg([EP1, EP2, EP3])
        seen = set()

        async def advertised(endpoint, *_):
            return {"model:latest"}

        async def resident(endpoint):
            return {"model:latest"}

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", side_effect=resident),
        ):
            for _ in range(20):
                # Every draw starts from an all-idle state.
                router.usage_counts.clear()
                chosen, _ = await router.choose_endpoint("model:latest", reserve=False)
                seen.add(chosen)

        # Over 20 draws from 3 idle endpoints, more than one must have been picked.
        assert len(seen) > 1

    async def test_sort_by_load_ascending(self):
        """The endpoint with fewer active connections is selected."""
        cfg = _make_cfg([EP1, EP2])
        router.usage_counts[EP1]["model:latest"] = 1
        router.usage_counts[EP2]["model:latest"] = 0

        async def advertised(endpoint, *_):
            return {"model:latest"}

        async def resident(endpoint):
            return {"model:latest"}

        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", side_effect=advertised),
            patch.object(router.fetch, "loaded_models", side_effect=resident),
        ):
            chosen, _ = await router.choose_endpoint("model:latest", reserve=False)
        # EP2 has the lower connection count.
        assert chosen == EP2
# ---------------------------------------------------------------------------
# get_max_connections unit tests
# ---------------------------------------------------------------------------
class TestGetMaxConnections:
    """Unit tests for ``router.get_max_connections``."""

    def test_returns_global_default_when_no_override(self):
        """Without per-endpoint settings, every endpoint gets the global limit."""
        cfg = _make_cfg([EP1, EP2], max_conn=3)
        with patch.object(router, "config", cfg):
            for endpoint in (EP1, EP2):
                assert router.get_max_connections(endpoint) == 3

    def test_returns_per_endpoint_override(self):
        """A per-endpoint limit wins; other endpoints keep the global limit."""
        overrides = {EP1: {"max_concurrent_connections": 5}}
        cfg = _make_cfg([EP1, EP2], max_conn=2, endpoint_config=overrides)
        with patch.object(router, "config", cfg):
            assert router.get_max_connections(EP1) == 5
            # EP2 has no override and falls back to the global value.
            assert router.get_max_connections(EP2) == 2

    def test_unrecognised_endpoint_falls_back_to_global(self):
        """An endpoint absent from ``endpoint_config`` gets the global limit."""
        overrides = {EP2: {"max_concurrent_connections": 1}}
        cfg = _make_cfg([EP1], max_conn=4, endpoint_config=overrides)
        with patch.object(router, "config", cfg):
            assert router.get_max_connections(EP3) == 4
# ---------------------------------------------------------------------------
# Priority / WRR routing tests
# ---------------------------------------------------------------------------
MODEL = "model:latest"
def _all_loaded(ep):
"""Side-effect: every endpoint advertises and has MODEL loaded."""
return {MODEL}
class TestPriorityRouting:
    """Routing with ``priority_routing=True`` (WRR plus config-order tie-breaking)."""

    async def test_idle_picks_first_in_config_order(self):
        """With every endpoint idle, the first configured endpoint is chosen."""
        cfg = _make_cfg([EP1, EP2, EP3], priority_routing=True)
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
            patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
        ):
            chosen, _ = await router.choose_endpoint(MODEL, reserve=False)
        assert chosen == EP1

    async def test_lower_utilization_preferred_over_priority(self):
        """A less utilised endpoint beats a busier one listed earlier."""
        cfg = _make_cfg([EP1, EP2], priority_routing=True)
        # EP1 (listed first) is at 1/2 = 0.5; EP2 (listed second) is idle at 0/2 = 0.0.
        router.usage_counts[EP1][MODEL] = 1
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
            patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
        ):
            chosen, _ = await router.choose_endpoint(MODEL, reserve=False)
        assert chosen == EP2

    async def test_wrr_distribution_matches_expected_sequence(self):
        """Five reserved requests follow the expected WRR order.

        Capacities mirror the issue example: EP1 max=2, EP2 max=2, EP3 max=1.
        The expected routing order is EP1, EP2, EP3, EP1, EP2.
        """
        cfg = _make_cfg(
            [EP1, EP2, EP3],
            max_conn=2,
            endpoint_config={EP3: {"max_concurrent_connections": 1}},
            priority_routing=True,
        )
        expected = [EP1, EP2, EP3, EP1, EP2]
        observed = []
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
            patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
        ):
            # One reserving call per expected entry; reservations accumulate.
            for _ in range(len(expected)):
                chosen, _tracking = await router.choose_endpoint(MODEL, reserve=True)
                observed.append(chosen)
        assert observed == expected

    async def test_saturated_picks_lowest_ratio_then_priority(self):
        """With all endpoints saturated, the lowest utilisation ratio wins."""
        cfg = _make_cfg(
            [EP1, EP2, EP3],
            max_conn=1,
            endpoint_config={EP3: {"max_concurrent_connections": 2}},
            priority_routing=True,
        )
        # Ratios: EP1 = 1/1 = 1.0, EP2 = 1/1 = 1.0, EP3 = 1/2 = 0.5, so EP3 wins.
        for endpoint in (EP1, EP2, EP3):
            router.usage_counts[endpoint][MODEL] = 1
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
            patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
        ):
            chosen, _ = await router.choose_endpoint(MODEL, reserve=False)
        assert chosen == EP3

    async def test_saturated_ties_broken_by_priority(self):
        """With all endpoints saturated at equal ratio, config order decides."""
        cfg = _make_cfg([EP1, EP2, EP3], max_conn=1, priority_routing=True)
        for endpoint in (EP1, EP2, EP3):
            router.usage_counts[endpoint][MODEL] = 1
        with (
            patch.object(router, "config", cfg),
            patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
            patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
        ):
            chosen, _ = await router.choose_endpoint(MODEL, reserve=False)
        assert chosen == EP1
class TestPriorityRoutingDisabled:
    """Checks selection behaviour when priority_routing=False."""

    async def test_idle_endpoints_are_randomised(self):
        """Without priority routing, all-idle selection must eventually pick each endpoint."""
        routing_cfg = _make_cfg([EP1, EP2, EP3], priority_routing=False)
        seen = set()
        available = AsyncMock(side_effect=_all_loaded)
        loaded = AsyncMock(side_effect=_all_loaded)
        with (
            patch.object(router, "config", routing_cfg),
            patch.object(router.fetch, "available_models", available),
            patch.object(router.fetch, "loaded_models", loaded),
        ):
            for _draw in range(30):
                # Reset usage before every draw so all endpoints stay equally idle.
                router.usage_counts.clear()
                chosen, _ = await router.choose_endpoint(MODEL, reserve=False)
                seen.add(chosen)
        # With 30 draws from 3 equally-idle endpoints, all three must appear
        assert seen == {EP1, EP2, EP3}

View file

@ -1,197 +0,0 @@
"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
from datetime import datetime, timezone
import pytest
from db import TokenDatabase
@pytest.fixture
async def db(tmp_path):
    """Yield an initialised TokenDatabase stored under pytest's tmp_path; close it afterwards."""
    db_file = tmp_path / "tokens.db"
    token_db = TokenDatabase(str(db_file))
    await token_db.init_db()
    yield token_db
    await token_db.close()
class TestInit:
    """Tests for TokenDatabase.init_db."""

    async def test_init_creates_tables(self, db):
        """A second init_db call succeeds and the database accepts an insert and a read."""
        # Re-init must be idempotent
        await db.init_db()
        # Insert + read confirms tables exist
        await db.update_token_counts("http://ep", "m", 1, 2)
        stored = []
        async for row in db.load_token_counts():
            stored.append(row)
        assert len(stored) == 1

    async def test_creates_parent_directory(self, tmp_path):
        """init_db on a path inside missing directories leaves the parent directory existing."""
        db_path = tmp_path / "nested" / "subdir" / "x.db"
        token_db = TokenDatabase(str(db_path))
        await token_db.init_db()
        try:
            assert db_path.parent.exists()
        finally:
            # Always close, even when the assertion fails.
            await token_db.close()
class TestUpdateTokenCounts:
    """Tests for TokenDatabase.update_token_counts."""

    async def test_insert_then_update_aggregates(self, db):
        """Two updates for the same (endpoint, model) pair accumulate into one row."""
        for tokens_in, tokens_out in ((10, 20), (5, 7)):
            await db.update_token_counts("http://ep", "m1", tokens_in, tokens_out)
        stored = [row async for row in db.load_token_counts()]
        assert len(stored) == 1
        only = stored[0]
        assert only["endpoint"] == "http://ep"
        assert only["model"] == "m1"
        assert only["input_tokens"] == 15
        assert only["output_tokens"] == 27
        assert only["total_tokens"] == 42

    async def test_independent_endpoint_model_pairs(self, db):
        """Distinct (endpoint, model) pairs are stored as separate rows."""
        updates = (
            ("http://ep1", "m1", 1, 1),
            ("http://ep1", "m2", 2, 2),
            ("http://ep2", "m1", 3, 3),
        )
        for endpoint, model, tokens_in, tokens_out in updates:
            await db.update_token_counts(endpoint, model, tokens_in, tokens_out)
        stored = [row async for row in db.load_token_counts()]
        assert len(stored) == 3
        totals = {}
        for row in stored:
            totals[(row["endpoint"], row["model"])] = row["total_tokens"]
        assert totals == {
            ("http://ep1", "m1"): 2,
            ("http://ep1", "m2"): 4,
            ("http://ep2", "m1"): 6,
        }
class TestBatchedCounts:
    """Tests for TokenDatabase.update_batched_counts."""

    async def test_update_batched_counts(self, db):
        """A nested endpoint -> model -> (input, output) batch is written as per-pair totals."""
        batch = {
            "http://a": {"m": (4, 6)},
            "http://b": {"m": (1, 1), "n": (10, 0)},
        }
        await db.update_batched_counts(batch)
        totals = {}
        async for row in db.load_token_counts():
            totals[(row["endpoint"], row["model"])] = row["total_tokens"]
        assert totals == {
            ("http://a", "m"): 10,
            ("http://b", "m"): 2,
            ("http://b", "n"): 10,
        }

    async def test_empty_batch_is_noop(self, db):
        """An empty batch leaves the table empty."""
        await db.update_batched_counts({})
        stored = [row async for row in db.load_token_counts()]
        assert stored == []
class TestTimeSeries:
    """Tests for the time-series write and query methods of TokenDatabase."""

    @staticmethod
    def _entry(endpoint, model, input_tokens, output_tokens, timestamp):
        """Build one batched time-series entry; total_tokens is input + output."""
        return {
            "endpoint": endpoint,
            "model": model,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "timestamp": timestamp,
        }

    async def test_add_time_series_entry(self, db):
        """Two single-entry inserts are both returned by get_latest_time_series."""
        # The aggregate FK requires the (endpoint,model) row to exist first
        await db.update_token_counts("http://ep", "m", 0, 0)
        for tokens_in, tokens_out in ((3, 4), (1, 1)):
            await db.add_time_series_entry("http://ep", "m", tokens_in, tokens_out)
        latest = [row async for row in db.get_latest_time_series(limit=10)]
        assert len(latest) == 2
        # Newest-first ordering; both timestamps are within the same minute,
        # so just check totals are present and well-formed
        for row in latest:
            assert row["endpoint"] == "http://ep"
            assert row["model"] == "m"
            assert row["total_tokens"] == row["input_tokens"] + row["output_tokens"]

    async def test_add_batched_time_series(self, db):
        """Batched entries are stored and come back newest-first."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        batch = [
            self._entry("http://ep", "m", 1, 2, now - 60),
            self._entry("http://ep", "m", 4, 5, now),
        ]
        await db.add_batched_time_series(batch)
        latest = [row async for row in db.get_latest_time_series(limit=10)]
        assert len(latest) == 2
        assert latest[0]["timestamp"] >= latest[1]["timestamp"]

    async def test_get_time_series_for_model_filters(self, db):
        """get_time_series_for_model returns only the requested model's rows."""
        for model in ("m1", "m2"):
            await db.update_token_counts("http://ep", model, 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        await db.add_batched_time_series([
            self._entry("http://ep", "m1", 1, 1, now),
            self._entry("http://ep", "m2", 9, 9, now),
        ])
        matching = [row async for row in db.get_time_series_for_model("m1")]
        assert len(matching) == 1
        assert matching[0]["total_tokens"] == 2

    async def test_endpoint_distribution_for_model(self, db):
        """Per-endpoint totals for a model are summed across its time-series rows."""
        for endpoint in ("http://a", "http://b"):
            await db.update_token_counts(endpoint, "m", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        await db.add_batched_time_series([
            self._entry("http://a", "m", 1, 1, now),
            self._entry("http://a", "m", 1, 1, now),
            self._entry("http://b", "m", 5, 5, now),
        ])
        distribution = await db.get_endpoint_distribution_for_model("m")
        assert distribution == {"http://a": 4, "http://b": 10}
class TestGetTokenCountsForModel:
    """Tests for TokenDatabase.get_token_counts_for_model."""

    async def test_aggregates_across_endpoints(self, db):
        """Counts for one model on two endpoints are summed into one 'aggregated' result."""
        for endpoint, tokens_in, tokens_out in (("http://a", 1, 2), ("http://b", 3, 4)):
            await db.update_token_counts(endpoint, "m", tokens_in, tokens_out)
        summary = await db.get_token_counts_for_model("m")
        assert summary is not None
        assert summary["endpoint"] == "aggregated"
        assert summary["model"] == "m"
        assert summary["input_tokens"] == 4
        assert summary["output_tokens"] == 6
        assert summary["total_tokens"] == 10

    async def test_unknown_model_returns_zero_aggregate(self, db):
        """An unknown model still yields a result whose input count is 0 or None."""
        # SUM(...) WHERE no-match returns one row with NULLs — exposed as zeros
        summary = await db.get_token_counts_for_model("nope")
        assert summary is not None
        assert summary["input_tokens"] in (0, None)
class TestAggregateTimeSeriesOlderThan:
    """Tests for TokenDatabase.aggregate_time_series_older_than."""

    @staticmethod
    def _entry(input_tokens, output_tokens, timestamp):
        """Build a batched entry for ("http://ep", "m"); total_tokens is input + output."""
        return {
            "endpoint": "http://ep",
            "model": "m",
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": input_tokens + output_tokens,
            "timestamp": timestamp,
        }

    async def test_aggregates_old_entries_by_day(self, db):
        """Two old entries one minute apart roll up into one group; the recent one does not count."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        old = now - (40 * 86400)  # 40 days ago
        await db.add_batched_time_series([
            self._entry(1, 1, old),
            self._entry(3, 3, old + 60),
            self._entry(99, 99, now),
        ])
        rolled_up = await db.aggregate_time_series_older_than(30, trim_old=False)
        assert rolled_up == 1  # one (endpoint, model, day) group rolled up

    async def test_invalid_days_falls_back_to_30(self, db):
        """A days value of 0 on an empty database does not raise and reports 0."""
        # Just ensure it doesn't blow up with a bogus value
        rolled_up = await db.aggregate_time_series_older_than(0)
        assert rolled_up == 0

    async def test_trim_old_removes_aggregated_rows(self, db):
        """With trim_old=True only the recent entry remains in the time series."""
        await db.update_token_counts("http://ep", "m", 0, 0)
        now = int(datetime.now(tz=timezone.utc).timestamp())
        old = now - (40 * 86400)
        await db.add_batched_time_series([
            self._entry(1, 1, old),
            self._entry(99, 99, now),
        ])
        await db.aggregate_time_series_older_than(30, trim_old=True)
        survivors = [row async for row in db.get_latest_time_series(limit=10)]
        # Only the recent (within-cutoff) row should remain
        assert len(survivors) == 1
        assert survivors[0]["total_tokens"] == 198

View file

@ -1,210 +0,0 @@
"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking."""
import time
from unittest.mock import patch, MagicMock
import pytest
from aioresponses import aioresponses
import router
from conftest import TEST_OLLAMA, TEST_LLAMA
# Base URL of the mocked Ollama endpoint; tests append /api/tags and /api/ps to it.
MOCK_OLLAMA_EP = "http://mock-ollama:11434"
# Base URL (already ending in /v1) of the mocked llama-server endpoint; tests append /models.
MOCK_LLAMA_EP = "http://mock-llama:8080/v1"
def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
cfg = MagicMock()
cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP]
cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP]
cfg.api_keys = api_keys or {}
cfg.max_concurrent_connections = 2
cfg.router_api_key = None
return cfg
@pytest.fixture(autouse=True)
def clear_caches(aio_session):
    """Autouse wrapper that requests aio_session for every test in this module.

    Per the original note, aio_session already clears caches and sets up
    app_state, so nothing else happens here.
    """
    yield None
class TestFetchAvailableModels:
    """Tests for router.fetch.available_models against mocked HTTP endpoints."""

    async def test_ollama_tags(self):
        """Model names from /api/tags are returned as a set."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        tags_payload = {"models": [
            {"name": "llama3.2:latest"},
            {"name": "qwen2.5:7b"},
        ]}
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_OLLAMA_EP}/api/tags", payload=tags_payload)
            found = await router.fetch.available_models(MOCK_OLLAMA_EP)
        assert found == {"llama3.2:latest", "qwen2.5:7b"}

    async def test_openai_compatible_models_endpoint(self):
        """Model ids from an OpenAI-style /models listing are included in the result."""
        cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
        listing = {"data": [{"id": "unsloth/model:Q8_0"}]}
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_LLAMA_EP}/models", payload=listing)
            found = await router.fetch.available_models(MOCK_LLAMA_EP, api_key="tok")
        assert "unsloth/model:Q8_0" in found

    async def test_caches_successful_result(self):
        """A second lookup returns the same set although only one response is registered."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(
                f"{MOCK_OLLAMA_EP}/api/tags",
                payload={"models": [{"name": "llama3.2:latest"}]},
            )
            first = await router.fetch.available_models(MOCK_OLLAMA_EP)
            second = await router.fetch.available_models(MOCK_OLLAMA_EP)
        # second call must be served from cache without a second HTTP request
        assert first == second == {"llama3.2:latest"}

    async def test_returns_empty_on_http_500(self):
        """An HTTP 500 from /api/tags yields an empty set."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
            found = await router.fetch.available_models(MOCK_OLLAMA_EP)
        assert found == set()

    async def test_returns_empty_on_connection_error(self):
        """A connection failure yields an empty set."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        import aiohttp
        refused = aiohttp.ClientConnectorError(
            connection_key=MagicMock(host="mock-ollama", port=11434),
            os_error=OSError(111, "refused"),
        )
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_OLLAMA_EP}/api/tags", exception=refused)
            found = await router.fetch.available_models(MOCK_OLLAMA_EP)
        assert found == set()

    async def test_stale_cache_returned_while_refresh_runs(self):
        """An aged cache entry is still returned by the next lookup."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        tags_url = f"{MOCK_OLLAMA_EP}/api/tags"
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(tags_url, payload={"models": [{"name": "llama3.2:latest"}]})
            await router.fetch.available_models(MOCK_OLLAMA_EP)
        # Manually age cache into stale-but-valid window (300-600s)
        async with router._models_cache_lock:
            cached_models, _ = router._models_cache[MOCK_OLLAMA_EP]
            router._models_cache[MOCK_OLLAMA_EP] = (cached_models, time.time() - 400)
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(tags_url, payload={"models": [{"name": "llama3.2:latest"}]})
            # Should return stale data immediately
            stale = await router.fetch.available_models(MOCK_OLLAMA_EP)
        assert "llama3.2:latest" in stale

    async def test_error_cache_short_circuits(self):
        """A fresh error-cache entry makes the lookup return an empty set."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        # Seed error cache with a very recent error
        async with router._available_error_cache_lock:
            router._available_error_cache[MOCK_OLLAMA_EP] = time.time()
        with patch.object(router, "config", cfg), aioresponses():
            # No HTTP mock registered — if a call happens it will raise
            found = await router.fetch.available_models(MOCK_OLLAMA_EP)
        assert found == set()
class TestFetchLoadedModels:
    """Tests for router.fetch.loaded_models against mocked HTTP endpoints."""

    async def test_ollama_ps(self):
        """Model names from /api/ps are returned as a set."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(
                f"{MOCK_OLLAMA_EP}/api/ps",
                payload={"models": [{"name": "llama3.2:latest"}]},
            )
            running = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
        assert running == {"llama3.2:latest"}

    async def test_llama_server_filters_loaded(self):
        """Only entries whose status value is "loaded" are reported."""
        cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
        listing = {"data": [
            {"id": "model-a", "status": {"value": "loaded"}},
            {"id": "model-b", "status": {"value": "unloaded"}},
        ]}
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_LLAMA_EP}/models", payload=listing)
            running = await router.fetch.loaded_models(MOCK_LLAMA_EP)
        assert running == {"model-a"}

    async def test_llama_server_no_status_field_always_loaded(self):
        """An entry without a status field is reported as loaded."""
        cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(
                f"{MOCK_LLAMA_EP}/models",
                payload={"data": [{"id": "always-on-model"}]},
            )
            running = await router.fetch.loaded_models(MOCK_LLAMA_EP)
        assert "always-on-model" in running

    async def test_returns_empty_on_error(self):
        """An HTTP 503 from /api/ps yields an empty set."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
            running = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
        assert running == set()

    async def test_ext_openai_always_empty(self):
        """For https://api.openai.com/v1 the result is empty; no HTTP mock is installed."""
        ext_ep = "https://api.openai.com/v1"
        cfg = _make_cfg(ollama_eps=[ext_ep], llama_eps=[])
        with patch.object(router, "config", cfg):
            running = await router.fetch.loaded_models(ext_ep)
        assert running == set()

    async def test_caches_result(self):
        """Two consecutive lookups return equal results with one registered response."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(
                f"{MOCK_OLLAMA_EP}/api/ps",
                payload={"models": [{"name": "qwen:7b"}]},
            )
            first = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
            second = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
        assert first == second

    async def test_records_error_in_loaded_error_cache_on_failure(self):
        """A failed /api/ps probe leaves the endpoint in router._loaded_error_cache."""
        # Regression: issue #83 — /api/ps failures must be recorded so
        # `choose_endpoint` can exclude unhealthy backends from routing.
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={})
            await router.fetch.loaded_models(MOCK_OLLAMA_EP)
        assert MOCK_OLLAMA_EP in router._loaded_error_cache

    async def test_records_error_for_llama_server_on_failure(self):
        """A failed /models probe leaves the llama-server endpoint in the error cache."""
        cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP])
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(f"{MOCK_LLAMA_EP}/models", status=502, payload={})
            await router.fetch.loaded_models(MOCK_LLAMA_EP)
        assert MOCK_LLAMA_EP in router._loaded_error_cache

    async def test_clears_error_cache_on_subsequent_success(self):
        """A successful probe removes a previously recorded error for the endpoint."""
        cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
        # Pre-seed an old error so loaded_models() falls through to the
        # network probe instead of short-circuiting on the error cache.
        stale_error_time = time.time() - 301
        async with router._loaded_error_cache_lock:
            router._loaded_error_cache[MOCK_OLLAMA_EP] = stale_error_time
        with patch.object(router, "config", cfg), aioresponses() as mocked:
            mocked.get(
                f"{MOCK_OLLAMA_EP}/api/ps",
                payload={"models": [{"name": "qwen:7b"}]},
            )
            await router.fetch.loaded_models(MOCK_OLLAMA_EP)
        assert MOCK_OLLAMA_EP not in router._loaded_error_cache

View file

@ -1,181 +0,0 @@
"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes.
These tests verify that when the LLM cache reports a hit, the route returns
the cached payload *without* selecting an endpoint or contacting any backend.
"""
from unittest.mock import AsyncMock, patch
import orjson
import pytest
from fastapi import HTTPException
import router
# Sentinel error used as the side effect of the patched choose_endpoint in the
# bypass tests; they assert on its 599 status to show endpoint selection ran.
_BYPASS = HTTPException(status_code=599, detail="bypassed")
class _FakeCache:
"""Minimal stand-in for cache.LLMCache.get_chat."""
def __init__(self, response_bytes: bytes | None):
self._resp = response_bytes
self.calls: list[tuple] = []
async def get_chat(self, route, model, messages):
self.calls.append((route, model, messages))
return self._resp
@pytest.fixture
def cache_hit_payload():
    """Serialised chat-completion body whose assistant message content is "from-cache"."""
    cached_message = {"role": "assistant", "content": "from-cache"}
    token_usage = {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}
    completion = {
        "id": "cmpl-xyz",
        "created": 1,
        "model": "test-model",
        "choices": [{"message": cached_message}],
        "usage": token_usage,
    }
    return orjson.dumps(completion)
# ──────────────────────────────────────────────────────────────────────────────
# /v1/chat/completions
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAIChatCompletionsCacheHit:
    """Cache short-circuit behaviour of POST /v1/chat/completions."""

    async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
        """A non-streaming cache hit returns the cached JSON without endpoint selection."""
        fake = _FakeCache(cache_hit_payload)
        request_body = {
            "model": "test-model",
            "messages": [{"role": "user", "content": "ping"}],
            "stream": False,
            "nomyo": {"cache": True},
        }
        must_not_route = AsyncMock(side_effect=AssertionError("backend must not be reached"))
        # Patch the route's references to both helpers — they're imported by name
        # into router's namespace at module load time.
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint", must_not_route),
        ):
            resp = await client.post("/v1/chat/completions", json=request_body)
        assert resp.status_code == 200
        # Body is streamed; collect it
        parsed = orjson.loads(resp.content)
        assert parsed["choices"][0]["message"]["content"] == "from-cache"
        assert fake.calls and fake.calls[0][0] == "openai_chat"

    async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
        """A streaming cache hit is delivered as SSE frames terminated by [DONE]."""
        fake = _FakeCache(cache_hit_payload)
        request_body = {
            "model": "test-model",
            "messages": [{"role": "user", "content": "ping"}],
            "stream": True,
            "nomyo": {"cache": True},
        }
        must_not_route = AsyncMock(side_effect=AssertionError("backend must not be reached"))
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint", must_not_route),
        ):
            resp = await client.post("/v1/chat/completions", json=request_body)
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        text = resp.content.decode()
        # First SSE frame contains the cached content as a delta
        frames = text.split("\n\n")
        first_frame = frames[0]
        assert first_frame.startswith("data: ")
        chunk = orjson.loads(first_frame[len("data: "):])
        assert chunk["choices"][0]["delta"]["content"] == "from-cache"
        # Stream is terminated with [DONE]
        assert "data: [DONE]" in text

    async def test_cache_disabled_in_payload_bypasses_cache_check(self, client):
        """When nomyo.cache=False, get_chat is never called even if a cache exists."""
        fake = _FakeCache(b"")  # has a response, but should never be consulted
        request_body = {
            "model": "m",
            "messages": [{"role": "user", "content": "hi"}],
            "nomyo": {"cache": False},
        }
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint", AsyncMock(side_effect=_BYPASS)),
        ):
            resp = await client.post("/v1/chat/completions", json=request_body)
        # Got past the cache short-circuit → endpoint selection invoked
        assert resp.status_code == 599
        assert fake.calls == []

    async def test_no_cache_configured_bypasses_cache_check(self, client):
        """get_llm_cache() returning None should not break the route."""
        request_body = {
            "model": "m",
            "messages": [{"role": "user", "content": "hi"}],
            "nomyo": {"cache": True},
        }
        with (
            patch.object(router, "get_llm_cache", return_value=None),
            patch.object(router, "choose_endpoint", AsyncMock(side_effect=_BYPASS)),
        ):
            resp = await client.post("/v1/chat/completions", json=request_body)
        assert resp.status_code == 599
# ──────────────────────────────────────────────────────────────────────────────
# /v1/completions
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAICompletionsCacheHit:
    """Cache short-circuit behaviour of POST /v1/completions."""

    async def test_nonstream_cache_hit(self, client, cache_hit_payload):
        """A non-streaming hit succeeds and the lookup uses the prompt as one user message."""
        fake = _FakeCache(cache_hit_payload)
        request_body = {
            "model": "test-model",
            "prompt": "Tell me a joke",
            "stream": False,
            "nomyo": {"cache": True},
        }
        must_not_route = AsyncMock(side_effect=AssertionError("backend must not be reached"))
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint", must_not_route),
        ):
            resp = await client.post("/v1/completions", json=request_body)
        assert resp.status_code == 200
        first_lookup = fake.calls[0]
        # Prompt-style cache lookup is namespaced under "openai_completions"
        assert first_lookup[0] == "openai_completions"
        # Cache lookup receives the prompt as a single user message
        cached_msgs = first_lookup[2]
        assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}]

    async def test_stream_cache_hit(self, client, cache_hit_payload):
        """A streaming hit is served as an event stream terminated by [DONE]."""
        fake = _FakeCache(cache_hit_payload)
        request_body = {
            "model": "test-model",
            "prompt": "What is 2+2?",
            "stream": True,
            "nomyo": {"cache": True},
        }
        must_not_route = AsyncMock(side_effect=AssertionError("backend must not be reached"))
        with (
            patch.object(router, "get_llm_cache", return_value=fake),
            patch.object(router, "choose_endpoint", must_not_route),
        ):
            resp = await client.post("/v1/completions", json=request_body)
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")
        assert "data: [DONE]" in resp.content.decode()

View file

@ -1,116 +0,0 @@
"""Unit tests for context-window trimming logic."""
import pytest
import router
def _msgs(roles_contents):
return [{"role": r, "content": c} for r, c in roles_contents]
class TestCountMessageTokens:
    """Tests for router._count_message_tokens."""

    def test_returns_int(self):
        """The count for a single message is an int."""
        count = router._count_message_tokens(_msgs([("user", "hello")]))
        assert isinstance(count, int)

    def test_empty_list(self):
        """An empty history yields a non-negative count."""
        count = router._count_message_tokens([])
        assert count >= 0

    def test_longer_content_more_tokens(self):
        """Much longer content produces a strictly larger count."""
        short_count = router._count_message_tokens(_msgs([("user", "hi")]))
        long_count = router._count_message_tokens(_msgs([("user", "a " * 500)]))
        assert long_count > short_count

    def test_list_content(self):
        """Content given as a list of typed parts yields a positive count."""
        text_part = {"type": "text", "text": "what do you see?"}
        msgs = [{"role": "user", "content": [text_part]}]
        assert router._count_message_tokens(msgs) > 0

    def test_multiple_messages(self):
        """Three short messages together count for more than 10 tokens."""
        history = _msgs([
            ("system", "you are helpful"),
            ("user", "hello"),
            ("assistant", "hi!"),
        ])
        assert router._count_message_tokens(history) > 10
class TestTrimMessagesForContext:
    """Tests for router._trim_messages_for_context."""

    def test_short_history_unchanged(self):
        """A history well inside the context window is returned unchanged."""
        msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
        result = router._trim_messages_for_context(msgs, n_ctx=4096)
        assert result == msgs

    def test_system_messages_always_kept(self):
        """The system message survives trimming of a long history."""
        msgs = (
            _msgs([("system", "you are helpful")])
            + _msgs([("user", f"msg {i}") for i in range(50)])
            + _msgs([("user", "final question")])
        )
        result = router._trim_messages_for_context(msgs, n_ctx=512)
        system_msgs = [m for m in result if m["role"] == "system"]
        assert len(system_msgs) == 1
        assert system_msgs[0]["content"] == "you are helpful"

    def test_last_user_message_always_kept(self):
        """The final user message is still last after trimming."""
        msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")])
        result = router._trim_messages_for_context(msgs, n_ctx=256)
        assert result[-1]["content"] == "very important last question"

    def test_oldest_dropped_first(self):
        """Whatever is dropped must be older than everything that is kept."""
        msgs = _msgs([
            ("user", "oldest msg"),
            ("assistant", "oldest reply"),
            ("user", "newer msg"),
            ("assistant", "newer reply"),
            ("user", "newest"),
        ])
        # Use very small target to force trimming
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
        contents = [m["content"] for m in result]
        original = [m["content"] for m in msgs]
        assert "newest" in contents
        # The history holds no system messages, so oldest-first trimming must
        # leave exactly the trailing run of it: an older message may not
        # survive while a newer one is gone.
        assert contents == original[len(original) - len(contents):]

    def test_result_starts_with_user(self):
        """A non-empty result begins with a user message."""
        msgs = _msgs([
            ("assistant", "leftover assistant"),
            ("user", "question"),
        ])
        result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
        if result:
            assert result[0]["role"] == "user"

    def test_target_tokens_overrides_safety_margin(self):
        """An explicit target_tokens, small or large, still leaves at least one message."""
        msgs = _msgs([("user", "a " * 200)])
        result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
        result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
        # Both should return at least the last message
        assert len(result_small) >= 1
        assert len(result_large) >= 1
class TestCalibratedTrimTarget:
    """Tests for router._calibrated_trim_target."""

    def test_returns_positive_int(self):
        """The target is an int of at least 1."""
        history = [{"role": "user", "content": "hello " * 100}]
        target = router._calibrated_trim_target(history, n_ctx=4096, actual_tokens=3000)
        assert isinstance(target, int)
        assert target >= 1

    def test_over_limit_reduces_target(self):
        """When actual_tokens exceeds n_ctx the target is below the current count."""
        history = [{"role": "user", "content": "a " * 500}]
        # actual_tokens > n_ctx means we need to shed more
        target = router._calibrated_trim_target(history, n_ctx=2048, actual_tokens=2500)
        current = router._count_message_tokens(history)
        assert target < current

    def test_well_within_limit_returns_current(self):
        """Far below the limit, the target equals the current count (floored at 1)."""
        history = [{"role": "user", "content": "hi"}]
        # actual_tokens << n_ctx means nothing to shed
        target = router._calibrated_trim_target(history, n_ctx=16384, actual_tokens=50)
        # Should return cur_tiktoken since to_shed == 0
        expected = max(1, router._count_message_tokens(history))
        assert target == expected

    def test_minimum_is_one(self):
        """Even a huge overshoot never pushes the target below 1."""
        # Even if we need to shed everything, result is at least 1
        history = [{"role": "user", "content": "hello"}]
        target = router._calibrated_trim_target(history, n_ctx=100, actual_tokens=99999)
        assert target >= 1

View file

@ -1,279 +0,0 @@
"""Unit tests for pure helper functions in router.py (no network, no app)."""
import time
import asyncio
from unittest.mock import MagicMock, patch
import aiohttp
import pytest
import router
class TestMaskSecrets:
    """Tests for router._mask_secrets."""

    def test_masks_openai_key(self):
        """An sk- style key is replaced by the redaction marker."""
        line = "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ"
        masked = router._mask_secrets(line)
        assert "sk-***redacted***" in masked
        assert "sk-abcd1234" not in masked

    def test_masks_api_key_assignment(self):
        """The value after "api_key:" is redacted."""
        masked = router._mask_secrets("api_key: supersecretvalue123")
        assert "supersecretvalue123" not in masked
        assert "***redacted***" in masked

    def test_masks_api_key_with_colon(self):
        """The value after "api-key:" does not survive masking."""
        masked = router._mask_secrets("api-key: mykey")
        assert "mykey" not in masked

    def test_empty_string_returns_empty(self):
        """An empty string comes back as an empty string."""
        masked = router._mask_secrets("")
        assert masked == ""

    def test_none_returns_none(self):
        """None is passed through as None."""
        masked = router._mask_secrets(None)
        assert masked is None

    def test_no_secrets_unchanged(self):
        """Text without secrets is returned as-is."""
        line = "this is a normal log line"
        masked = router._mask_secrets(line)
        assert masked == line
class TestIsFresh:
    """Tests for router._is_fresh(cached_at, ttl)."""

    def test_fresh_within_ttl(self):
        """An entry cached 10s ago is fresh under a 300s TTL."""
        ten_seconds_ago = time.time() - 10
        fresh = router._is_fresh(ten_seconds_ago, 300)
        assert fresh is True

    def test_expired_beyond_ttl(self):
        """An entry cached 400s ago is not fresh under a 300s TTL."""
        long_ago = time.time() - 400
        fresh = router._is_fresh(long_ago, 300)
        assert fresh is False

    def test_exactly_at_boundary(self):
        """At the TTL boundary the call still returns a bool."""
        at_boundary = time.time() - 300
        # May be True or False depending on timing, just verify it runs
        fresh = router._is_fresh(at_boundary, 300)
        assert isinstance(fresh, bool)

    def test_just_cached(self):
        """An entry cached right now is fresh under a 1s TTL."""
        fresh = router._is_fresh(time.time(), 1)
        assert fresh is True
class TestNormalizeLlamaModelName:
    """Tests for router._normalize_llama_model_name."""

    def test_strips_hf_prefix(self):
        """The "org/" prefix is removed."""
        normalized = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF")
        assert normalized == "gpt-oss-20b-GGUF"

    def test_strips_quant_suffix(self):
        """The ":quant" suffix is removed."""
        normalized = router._normalize_llama_model_name("model:Q8_0")
        assert normalized == "model"

    def test_strips_both(self):
        """Prefix and suffix are removed together."""
        normalized = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16")
        assert normalized == "gpt-oss-20b-GGUF"

    def test_no_prefix_no_suffix(self):
        """A plain name is returned unchanged."""
        normalized = router._normalize_llama_model_name("plain-model")
        assert normalized == "plain-model"

    def test_multiple_slashes(self):
        """With several slashes only the final path segment is kept."""
        normalized = router._normalize_llama_model_name("org/user/model-name:Q4_K_M")
        assert normalized == "model-name"
class TestExtractLlamaQuant:
    """Tests for router._extract_llama_quant."""

    def test_extracts_quant(self):
        """The part after the colon is returned."""
        quant = router._extract_llama_quant("unsloth/model:Q8_0")
        assert quant == "Q8_0"

    def test_no_quant_returns_empty(self):
        """A name without a colon yields an empty string."""
        quant = router._extract_llama_quant("plain-model")
        assert quant == ""

    def test_f16(self):
        """The F16 suffix is extracted."""
        quant = router._extract_llama_quant("model:F16")
        assert quant == "F16"

    def test_q4_k_m(self):
        """A suffix containing underscores is extracted whole."""
        quant = router._extract_llama_quant("model:Q4_K_M")
        assert quant == "Q4_K_M"
class TestIsUnixSocketEndpoint:
    """Tests for router._is_unix_socket_endpoint."""

    def test_sock_endpoint_detected(self):
        """A host ending in .sock is detected."""
        detected = router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1")
        assert detected is True

    def test_regular_http_not_sock(self):
        """A host:port /v1 URL is not a socket endpoint."""
        detected = router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1")
        assert detected is False

    def test_ollama_not_sock(self):
        """A localhost:11434 URL is not a socket endpoint."""
        detected = router._is_unix_socket_endpoint("http://localhost:11434")
        assert detected is False

    def test_dot_sock_in_host_detected(self):
        """A named host ending in .sock is detected."""
        detected = router._is_unix_socket_endpoint("http://llama.sock/v1")
        assert detected is True
class TestGetSocketPath:
    """Tests for router._get_socket_path."""

    def test_returns_run_user_path(self):
        """The socket path is /run/user/<uid>/<host>.sock for the current uid."""
        import os
        socket_path = router._get_socket_path("http://192.168.0.52.sock/v1")
        uid = os.getuid()
        expected = f"/run/user/{uid}/192.168.0.52.sock"
        assert socket_path == expected
class TestIsBase64:
    """Tests for router.is_base64."""

    def test_valid_base64(self):
        """A freshly encoded string is recognised."""
        import base64
        encoded = base64.b64encode(b"hello world").decode()
        verdict = router.is_base64(encoded)
        assert verdict is True

    def test_invalid_base64(self):
        """A string with characters outside the alphabet is rejected."""
        verdict = router.is_base64("not-base64!@#$")
        assert verdict is False

    def test_empty_string(self):
        """The empty string is accepted."""
        # Empty string is valid base64 (decodes to empty bytes)
        verdict = router.is_base64("")
        assert verdict is True

    def test_non_string(self):
        """A non-string argument gives a falsy result."""
        # Non-strings fall through without returning True (returns None)
        verdict = router.is_base64(12345)
        assert not verdict
class TestIsLlamaModelLoaded:
    """Tests for router._is_llama_model_loaded."""

    def test_status_dict_loaded(self):
        """A status dict with value "loaded" is loaded."""
        entry = {"id": "m", "status": {"value": "loaded"}}
        assert router._is_llama_model_loaded(entry) is True

    def test_status_dict_unloaded(self):
        """A status dict with value "unloaded" is not loaded."""
        entry = {"id": "m", "status": {"value": "unloaded"}}
        assert router._is_llama_model_loaded(entry) is False

    def test_status_string_loaded(self):
        """A plain "loaded" status string is loaded."""
        entry = {"id": "m", "status": "loaded"}
        assert router._is_llama_model_loaded(entry) is True

    def test_status_string_unloaded(self):
        """A plain "unloaded" status string is not loaded."""
        entry = {"id": "m", "status": "unloaded"}
        assert router._is_llama_model_loaded(entry) is False

    def test_no_status_field_always_loaded(self):
        """An entry without a status key is loaded."""
        # No status field → always available (single-model server)
        entry = {"id": "m"}
        assert router._is_llama_model_loaded(entry) is True

    def test_status_none_always_loaded(self):
        """An entry whose status is None is loaded."""
        entry = {"id": "m", "status": None}
        assert router._is_llama_model_loaded(entry) is True
class TestEp2Base:
    """Tests for router.ep2base."""

    def test_adds_v1_to_ollama(self):
        """/v1 is appended to an endpoint that lacks it."""
        base = router.ep2base("http://localhost:11434")
        assert base == "http://localhost:11434/v1"

    def test_keeps_v1_if_present(self):
        """An endpoint already ending in /v1 is returned unchanged."""
        base = router.ep2base("http://host/v1")
        assert base == "http://host/v1"

    def test_llama_server_endpoint_unchanged(self):
        """A host:port endpoint ending in /v1 is returned unchanged."""
        ep = "http://192.168.0.50:8889/v1"
        base = router.ep2base(ep)
        assert base == ep
class TestDedupeOnKeys:
    """Tests for router.dedupe_on_keys."""

    def test_removes_duplicate_by_single_key(self):
        """Later items repeating a key value are dropped; first-seen order is kept."""
        items = [
            {"name": "a", "x": 1},
            {"name": "b", "x": 2},
            {"name": "a", "x": 3},
        ]
        unique = router.dedupe_on_keys(items, ["name"])
        assert len(unique) == 2
        assert unique[0]["name"] == "a"
        assert unique[1]["name"] == "b"

    def test_removes_duplicate_by_two_keys(self):
        """Items are duplicates when both listed keys match."""
        items = [
            {"digest": "abc", "name": "m1"},
            {"digest": "abc", "name": "m1"},
            {"digest": "def", "name": "m2"},
        ]
        unique = router.dedupe_on_keys(items, ["digest", "name"])
        assert len(unique) == 2

    def test_empty_list(self):
        """An empty input gives an empty list."""
        unique = router.dedupe_on_keys([], ["name"])
        assert unique == []

    def test_no_duplicates(self):
        """Input without duplicates keeps every item."""
        items = [{"name": letter} for letter in ("a", "b", "c")]
        unique = router.dedupe_on_keys(items, ["name"])
        assert len(unique) == 3
class TestFormatConnectionIssue:
    """router._format_connection_issue should produce a message naming the failure."""

    def test_connector_error_message(self):
        conn_key = MagicMock(host="localhost", port=11434)
        refused = OSError(111, "Connection refused")
        err = aiohttp.ClientConnectorError(connection_key=conn_key, os_error=refused)
        text = router._format_connection_issue("http://localhost:11434", err)
        assert "localhost" in text
        # Either the OS error text or its errno is acceptable.
        assert "Connection refused" in text or "111" in text

    def test_timeout_error_message(self):
        timeout = asyncio.TimeoutError()
        text = router._format_connection_issue("http://host:1234", timeout)
        assert "Timed out" in text
        assert "host:1234" in text

    def test_generic_error(self):
        failure = ValueError("boom")
        text = router._format_connection_issue("http://host:1234", failure)
        assert "host:1234" in text
        assert "boom" in text
class TestIsExtOpenaiEndpoint:
    """router.is_ext_openai_endpoint, evaluated against a patched router.config."""

    @staticmethod
    def _is_ext(endpoint, endpoints, llama_server_endpoints):
        """Run is_ext_openai_endpoint with router.config replaced by a mock."""
        fake_cfg = MagicMock()
        fake_cfg.endpoints = endpoints
        fake_cfg.llama_server_endpoints = llama_server_endpoints
        with patch.object(router, "config", fake_cfg):
            return router.is_ext_openai_endpoint(endpoint)

    def test_openai_com_is_ext(self):
        verdict = self._is_ext("https://api.openai.com/v1", [], [])
        assert verdict is True

    def test_ollama_default_port_not_ext(self):
        verdict = self._is_ext("http://host:11434", ["http://host:11434"], [])
        assert verdict is False

    def test_llama_server_not_ext(self):
        verdict = self._is_ext("http://host:8080/v1", [], ["http://host:8080/v1"])
        assert verdict is False

    def test_no_v1_not_ext(self):
        # NOTE(review): same inputs as test_ollama_default_port_not_ext, so this
        # adds no coverage of its own — consider a distinct endpoint without /v1.
        verdict = self._is_ext("http://host:11434", ["http://host:11434"], [])
        assert verdict is False
class TestIsOpenaiCompatible:
    """router.is_openai_compatible, evaluated against a patched router.config."""

    @staticmethod
    def _compatible(endpoint, llama_server_endpoints):
        """Run is_openai_compatible with only llama_server_endpoints configured."""
        fake_cfg = MagicMock()
        fake_cfg.llama_server_endpoints = llama_server_endpoints
        with patch.object(router, "config", fake_cfg):
            return router.is_openai_compatible(endpoint)

    def test_v1_endpoint_compatible(self):
        assert self._compatible("http://host/v1", []) is True

    def test_ollama_not_compatible(self):
        assert self._compatible("http://localhost:11434", []) is False

    def test_llama_server_in_list_compatible(self):
        # Listed llama-server endpoints count as compatible even without /v1.
        assert self._compatible("http://host:8080", ["http://host:8080"]) is True
class TestGetTrackingModel:
    """router.get_tracking_model should return the expected tracking name per endpoint type."""

    @staticmethod
    def _tracking(endpoint, model, llama_server_endpoints):
        """Call get_tracking_model with router.config replaced by a mock."""
        fake_cfg = MagicMock()
        fake_cfg.llama_server_endpoints = llama_server_endpoints
        with patch.object(router, "config", fake_cfg):
            return router.get_tracking_model(endpoint, model)

    def test_ollama_adds_latest(self):
        tracked = self._tracking("http://ollama:11434", "llama3.2", [])
        assert tracked == "llama3.2:latest"

    def test_ollama_keeps_existing_tag(self):
        tracked = self._tracking("http://ollama:11434", "llama3.2:7b", [])
        assert tracked == "llama3.2:7b"

    def test_llama_server_normalizes(self):
        endpoint = "http://host:8080/v1"
        tracked = self._tracking(endpoint, "unsloth/model:Q8_0", [endpoint])
        # Expected result drops both the "unsloth/" prefix and the ":Q8_0" suffix.
        assert tracked == "model"

View file

@ -1,173 +0,0 @@
"""Unit tests for router.rechunk — OpenAI ↔ Ollama chunk shape conversion."""
import time
from types import SimpleNamespace
import ollama
import router
def _ns(**kw):
return SimpleNamespace(**kw)
def _stream_chunk(content="hi", role="assistant", finish_reason=None,
                  usage=None, model="m"):
    """Return a namespace shaped like one streaming OpenAI chat chunk.

    The chunk has a single choice whose ``delta`` carries ``content`` and
    ``role``; reasoning fields and ``tool_calls`` are always ``None``.
    """
    return _ns(
        model=model,
        choices=[
            _ns(
                delta=_ns(
                    content=content,
                    role=role,
                    reasoning=None,
                    reasoning_content=None,
                    tool_calls=None,
                ),
                finish_reason=finish_reason,
                logprobs=None,
            )
        ],
        usage=usage,
    )
def _nonstream_chunk(content="hi", role="assistant", finish_reason="stop",
                     usage=None, model="m", tool_calls=None):
    """Return a namespace shaped like a non-streaming OpenAI ChatCompletion.

    The completion has a single choice whose ``message`` carries ``content``,
    ``role`` and ``tool_calls``; reasoning fields are always ``None``.
    """
    return _ns(
        model=model,
        choices=[
            _ns(
                message=_ns(
                    content=content,
                    role=role,
                    reasoning=None,
                    reasoning_content=None,
                    tool_calls=tool_calls,
                ),
                finish_reason=finish_reason,
                logprobs=None,
            )
        ],
        usage=usage,
    )
# ──────────────────────────────────────────────────────────────────────────────
# openai_chat_completion2ollama
# ──────────────────────────────────────────────────────────────────────────────
class TestChatCompletionToOllama:
    """Tests for router.rechunk.openai_chat_completion2ollama.

    Each test builds a fake OpenAI chunk from SimpleNamespace objects and
    checks the fields of the resulting ollama.ChatResponse.
    """

    def test_streaming_content_chunk(self):
        chunk = _stream_chunk(content="hello", finish_reason=None, usage=None)
        out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
        assert isinstance(out, ollama.ChatResponse)
        assert out.message.role == "assistant"
        assert out.message.content == "hello"
        assert out.done is False  # usage is None → not done yet
        assert out.model == "m"

    def test_streaming_empty_content_defaults(self):
        # Some chunks have content=None — should coerce to empty string
        chunk = _stream_chunk(content=None, role=None)
        out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
        assert out.message.role == "assistant"  # role defaulted
        assert out.message.content == ""

    def test_final_usage_only_chunk_marks_done(self):
        # A trailing chunk with no choices but with usage marks the stream done.
        usage = _ns(prompt_tokens=10, completion_tokens=5, total_tokens=15)
        chunk = _ns(model="m", choices=[], usage=usage)
        out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
        assert out.done is True
        assert out.done_reason == "stop"
        assert out.prompt_eval_count == 10
        assert out.eval_count == 5
        assert out.message.content == ""

    def test_nonstreaming_with_content(self):
        usage = _ns(prompt_tokens=2, completion_tokens=3, total_tokens=5)
        chunk = _nonstream_chunk(content="response text", finish_reason="stop", usage=usage)
        out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
        assert out.done is True
        assert out.message.content == "response text"
        assert out.prompt_eval_count == 2
        assert out.eval_count == 3

    def test_nonstreaming_tool_calls_converted(self):
        """Tool calls with JSON string arguments are parsed into dicts."""
        tc = _ns(function=_ns(name="get_weather", arguments='{"city": "Paris"}'))
        usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
        chunk = _nonstream_chunk(
            content="", finish_reason="tool_calls", usage=usage, tool_calls=[tc]
        )
        out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
        assert out.message.tool_calls is not None
        assert len(out.message.tool_calls) == 1
        first = out.message.tool_calls[0]
        assert first.function.name == "get_weather"
        assert first.function.arguments == {"city": "Paris"}

    def test_nonstreaming_tool_calls_with_invalid_json_fall_back_to_empty(self):
        """Unparseable argument strings become an empty dict."""
        tc = _ns(function=_ns(name="f", arguments="not-json"))
        usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
        chunk = _nonstream_chunk(content="", usage=usage, tool_calls=[tc])
        out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
        assert out.message.tool_calls[0].function.arguments == {}

    def test_streaming_tool_calls_in_delta_are_skipped(self):
        """Streaming mode must not assemble tool calls (caller handles it)."""
        chunk = _stream_chunk(content="x", finish_reason=None)
        # _stream_chunk always builds the delta with tool_calls=None, so a
        # tool-call fragment is injected here; without it the assertion below
        # would hold trivially and the test could never fail.
        chunk.choices[0].delta.tool_calls = [
            _ns(index=0, id="call_1", function=_ns(name="f", arguments="{}"))
        ]
        out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
        assert out.message.tool_calls is None
# ──────────────────────────────────────────────────────────────────────────────
# openai_completion2ollama
# ──────────────────────────────────────────────────────────────────────────────
class TestCompletionToOllama:
    """Tests for router.rechunk.openai_completion2ollama (text completions)."""

    def test_streaming_text_chunk(self):
        piece = _ns(text="word", finish_reason=None, reasoning=None)
        fake_chunk = _ns(model="m", choices=[piece], usage=None)
        started = time.perf_counter()
        converted = router.rechunk.openai_completion2ollama(fake_chunk, True, started)
        assert isinstance(converted, ollama.GenerateResponse)
        assert converted.response == "word"
        # No usage on the chunk, so the response is not marked done.
        assert converted.done is False

    def test_final_chunk_with_usage(self):
        token_usage = _ns(prompt_tokens=4, completion_tokens=6, total_tokens=10)
        piece = _ns(text="end", finish_reason="stop", reasoning=None)
        fake_chunk = _ns(model="m", choices=[piece], usage=token_usage)
        started = time.perf_counter()
        converted = router.rechunk.openai_completion2ollama(fake_chunk, True, started)
        assert converted.done is True
        assert converted.prompt_eval_count == 4
        assert converted.eval_count == 6
# ──────────────────────────────────────────────────────────────────────────────
# embeddings / embed
# ──────────────────────────────────────────────────────────────────────────────
class TestEmbeddingConversions:
    """Tests for the OpenAI → Ollama embedding response converters."""

    def test_openai_embeddings2ollama(self):
        vector = [0.1, 0.2, 0.3]
        payload = _ns(data=[_ns(embedding=vector)])
        converted = router.rechunk.openai_embeddings2ollama(payload)
        assert isinstance(converted, ollama.EmbeddingsResponse)
        assert list(converted.embedding) == [0.1, 0.2, 0.3]

    def test_openai_embed2ollama(self):
        vector = [0.5, 0.6]
        payload = _ns(data=[_ns(embedding=vector)])
        converted = router.rechunk.openai_embed2ollama(payload, "my-embed-model")
        assert isinstance(converted, ollama.EmbedResponse)
        # The model name passed in is carried onto the response.
        assert converted.model == "my-embed-model"
        assert list(converted.embeddings[0]) == [0.5, 0.6]
# ──────────────────────────────────────────────────────────────────────────────
# extract_usage_from_llama_timings
# ──────────────────────────────────────────────────────────────────────────────
class TestExtractUsageFromLlamaTimings:
    """Tests for router.rechunk.extract_usage_from_llama_timings.

    The function is expected to return ``None`` when no usable ``timings``
    dict is present, otherwise a ``(prompt, completion)`` pair.
    """

    def test_none_when_no_timings_attr(self):
        bare = _ns()
        usage = router.rechunk.extract_usage_from_llama_timings(bare)
        assert usage is None

    def test_prompt_plus_cache_sums(self):
        # Prompt count is prompt_n + cache_n (1 + 236).
        carrier = _ns(timings={"prompt_n": 1, "cache_n": 236, "predicted_n": 35})
        prompt_tokens, completion_tokens = router.rechunk.extract_usage_from_llama_timings(carrier)
        assert prompt_tokens == 237
        assert completion_tokens == 35

    def test_missing_keys_default_to_zero(self):
        carrier = _ns(timings={"predicted_n": 12})
        prompt_tokens, completion_tokens = router.rechunk.extract_usage_from_llama_timings(carrier)
        assert prompt_tokens == 0
        assert completion_tokens == 12

    def test_null_values_treated_as_zero(self):
        all_null = {"prompt_n": None, "cache_n": None, "predicted_n": None}
        carrier = _ns(timings=all_null)
        prompt_tokens, completion_tokens = router.rechunk.extract_usage_from_llama_timings(carrier)
        assert prompt_tokens == 0
        assert completion_tokens == 0

    def test_non_dict_timings_returns_none(self):
        carrier = _ns(timings="not-a-dict")
        usage = router.rechunk.extract_usage_from_llama_timings(carrier)
        assert usage is None

View file

@ -1,200 +0,0 @@
"""Unit tests for message transformation functions."""
from unittest.mock import MagicMock
import pytest
import router
class TestStripAssistantPrefill:
    """router._strip_assistant_prefill should drop only a trailing assistant message."""

    def test_removes_trailing_assistant(self):
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "prefill"},
        ]
        stripped = router._strip_assistant_prefill(conversation)
        assert len(stripped) == 1
        assert stripped[0]["role"] == "user"

    def test_keeps_non_trailing_assistant(self):
        # The assistant turn is followed by a user turn, so nothing is removed.
        conversation = [
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "response"},
            {"role": "user", "content": "follow-up"},
        ]
        stripped = router._strip_assistant_prefill(conversation)
        assert len(stripped) == 3

    def test_empty_list_unchanged(self):
        stripped = router._strip_assistant_prefill([])
        assert stripped == []

    def test_single_user_message_unchanged(self):
        conversation = [{"role": "user", "content": "hi"}]
        stripped = router._strip_assistant_prefill(conversation)
        assert stripped == conversation
class TestTransformToolCallsToOpenAI:
    """router.transform_tool_calls_to_openai should emit OpenAI-shaped tool calls."""

    @staticmethod
    def _assistant_with_call(name, arguments):
        """Build an assistant message carrying one tool call without id or type."""
        return {
            "role": "assistant",
            "tool_calls": [{"function": {"name": name, "arguments": arguments}}],
        }

    def test_adds_type_function(self):
        history = [self._assistant_with_call("get_weather", {"city": "Berlin"})]
        converted = router.transform_tool_calls_to_openai(history)
        call = converted[0]["tool_calls"][0]
        assert call["type"] == "function"

    def test_adds_id_when_missing(self):
        history = [self._assistant_with_call("fn", {})]
        converted = router.transform_tool_calls_to_openai(history)
        assert "id" in converted[0]["tool_calls"][0]

    def test_converts_dict_arguments_to_string(self):
        history = [self._assistant_with_call("fn", {"key": "val"})]
        converted = router.transform_tool_calls_to_openai(history)
        serialized = converted[0]["tool_calls"][0]["function"]["arguments"]
        assert isinstance(serialized, str)
        # Round-trip the JSON string to confirm it still encodes the same dict.
        import orjson
        assert orjson.loads(serialized) == {"key": "val"}

    def test_keeps_string_arguments_unchanged(self):
        history = [self._assistant_with_call("fn", '{"key": "val"}')]
        converted = router.transform_tool_calls_to_openai(history)
        serialized = converted[0]["tool_calls"][0]["function"]["arguments"]
        assert serialized == '{"key": "val"}'

    def test_links_tool_call_id_to_tool_response(self):
        # The following tool message must reference the id given to the call.
        history = [
            self._assistant_with_call("get_weather", {}),
            {"role": "tool", "name": "get_weather", "content": "sunny"},
        ]
        converted = router.transform_tool_calls_to_openai(history)
        call_id = converted[0]["tool_calls"][0]["id"]
        assert converted[1].get("tool_call_id") == call_id

    def test_non_tool_messages_unchanged(self):
        history = [{"role": "user", "content": "hello"}]
        converted = router.transform_tool_calls_to_openai(history)
        assert converted == history
class TestStripImagesFromMessages:
    """router._strip_images_from_messages should remove image_url parts from content."""

    def test_removes_image_url_parts(self):
        parts = [
            {"type": "text", "text": "what is this?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
        ]
        cleaned = router._strip_images_from_messages([{"role": "user", "content": parts}])
        # A single remaining text part is expected as a plain string.
        assert cleaned[0]["content"] == "what is this?"

    def test_keeps_text_only_messages(self):
        plain = [{"role": "user", "content": "plain text"}]
        cleaned = router._strip_images_from_messages(plain)
        assert cleaned[0]["content"] == "plain text"

    def test_multiple_text_parts_kept_as_list(self):
        parts = [
            {"type": "text", "text": "part one"},
            {"type": "text", "text": "part two"},
            {"type": "image_url", "image_url": {"url": "data:..."}},
        ]
        cleaned = router._strip_images_from_messages([{"role": "user", "content": parts}])
        remaining = cleaned[0]["content"]
        assert isinstance(remaining, list)
        assert len(remaining) == 2

    def test_all_images_removed_empty_list(self):
        parts = [{"type": "image_url", "image_url": {"url": "data:..."}}]
        cleaned = router._strip_images_from_messages([{"role": "user", "content": parts}])
        # Image-only content becomes an empty list.
        remaining = cleaned[0]["content"]
        assert remaining == []
class TestAccumulateOpenAITcDelta:
    """Tests for router._accumulate_openai_tc_delta.

    The function receives a streaming chunk plus an accumulator dict and is
    expected to record tool-call fragments in the dict, keyed by tool-call index.
    """

    @staticmethod
    def _make_tool_call(index, name=None, args_fragment="", tc_id=None):
        """Build one mock tool-call delta with real ``name``/``arguments`` values."""
        tc = MagicMock()
        tc.index = index
        tc.id = tc_id
        tc.function = MagicMock()
        # Assign after construction: MagicMock(name=...) names the mock itself
        # (used in its repr) and does NOT create a ``.name`` attribute.
        tc.function.name = name
        tc.function.arguments = args_fragment
        return tc

    def _make_chunk(self, index, name=None, args_fragment="", tc_id=None):
        """Build a mock chunk whose first choice's delta carries one tool call."""
        delta = MagicMock()
        delta.tool_calls = [self._make_tool_call(index, name, args_fragment, tc_id)]
        chunk = MagicMock()
        chunk.choices = [MagicMock(delta=delta)]
        return chunk

    def test_first_delta_creates_entry(self):
        acc = {}
        chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"')
        router._accumulate_openai_tc_delta(chunk, acc)
        assert 0 in acc
        assert acc[0]["name"] == "my_fn"
        assert acc[0]["arguments"] == '{"k"'

    def test_subsequent_deltas_concatenate_args(self):
        acc = {}
        router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc)
        router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc)
        assert acc[0]["arguments"] == '{"k": "v"}'

    def test_multiple_tool_calls_tracked_separately(self):
        acc = {}
        # Two tool calls delivered in the same delta, distinguished by index.
        tool_calls = [
            self._make_tool_call(0, name="fn1", args_fragment="{}", tc_id="id1"),
            self._make_tool_call(1, name="fn2", args_fragment="{}", tc_id="id2"),
        ]
        chunk = MagicMock()
        chunk.choices = [MagicMock(delta=MagicMock(tool_calls=tool_calls))]
        router._accumulate_openai_tc_delta(chunk, acc)
        assert 0 in acc and 1 in acc
        # Each index must keep its own function name.
        assert acc[0]["name"] == "fn1"
        assert acc[1]["name"] == "fn2"

    def test_no_choices_is_noop(self):
        acc = {}
        chunk = MagicMock(choices=[])
        router._accumulate_openai_tc_delta(chunk, acc)
        assert acc == {}
class TestBuildOllamaToolCalls:
    """router._build_ollama_tool_calls turns an index-keyed accumulator into tool calls."""

    def test_builds_from_accumulator(self):
        pending = {
            0: {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'},
        }
        built = router._build_ollama_tool_calls(pending)
        assert built is not None
        assert len(built) == 1
        only_call = built[0]
        assert only_call.function.name == "get_weather"
        # The JSON argument string is parsed into a dict.
        assert only_call.function.arguments == {"city": "Berlin"}

    def test_invalid_json_args_becomes_empty_dict(self):
        pending = {0: {"id": "c1", "name": "fn", "arguments": "not-json"}}
        built = router._build_ollama_tool_calls(pending)
        assert built[0].function.arguments == {}

    def test_empty_accumulator_returns_none(self):
        built = router._build_ollama_tool_calls({})
        assert built is None

    def test_preserves_order_by_index(self):
        # Keys are inserted out of order on purpose; output must follow the index.
        pending = {
            1: {"id": "c2", "name": "fn2", "arguments": "{}"},
            0: {"id": "c1", "name": "fn1", "arguments": "{}"},
        }
        built = router._build_ollama_tool_calls(pending)
        assert built[0].function.name == "fn1"
        assert built[1].function.name == "fn2"