Compare commits

..

No commits in common. "main" and "v0.9.0" have entirely different histories.
main ... v0.9.0

62 changed files with 3978 additions and 111553 deletions

View file

@ -86,7 +86,7 @@ jobs:
provenance: false provenance: false
build-args: | build-args: |
SEMANTIC_CACHE=true SEMANTIC_CACHE=true
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}-${{ github.run_id }} tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-${{ matrix.arch }}
merge: merge:
runs-on: docker-amd64 runs-on: docker-amd64
@ -142,6 +142,6 @@ jobs:
run: | run: |
docker buildx imagetools create \ docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64-${{ github.run_id }} \ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-amd64 \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64-${{ github.run_id }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:semantic-platform-arm64

View file

@ -77,7 +77,7 @@ jobs:
platforms: ${{ matrix.platform }} platforms: ${{ matrix.platform }}
push: true push: true
provenance: false provenance: false
tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}-${{ github.run_id }} tags: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-${{ matrix.arch }}
merge: merge:
runs-on: docker-amd64 runs-on: docker-amd64
@ -133,6 +133,6 @@ jobs:
run: | run: |
docker buildx imagetools create \ docker buildx imagetools create \
$(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64-${{ github.run_id }} \ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-amd64 \
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64-${{ github.run_id }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:platform-arm64

View file

@ -11,7 +11,7 @@ jobs:
opencode: opencode:
if: | if: |
contains(github.event.comment.body, '/oc') || contains(github.event.comment.body, '/oc') ||
contains(github.event.review.body, '/oc') contains(github.event.comment.body, '/opencode')
runs-on: docker-amd64 runs-on: docker-amd64
container: container:
image: node:lts-bookworm image: node:lts-bookworm
@ -54,7 +54,9 @@ jobs:
uses: ./.opencode-action uses: ./.opencode-action
with: with:
nomyo_api_key: ${{ secrets.NOMYO_API_KEY }} nomyo_api_key: ${{ secrets.NOMYO_API_KEY }}
model: nomyo/unsloth/Qwen3.6-35B-A3B-MTP-GGUF:Q4_K_XL model: nomyo/unsloth/Qwen3.6-35B-A3B-GGUF:UD-Q4_K_M
forgejo_api_url: https://bitfreedom.net/code/ forgejo_api_url: https://bitfreedom.net/code/
forgejo_token: ${{ secrets.FORGEJO_TOKEN }} forgejo_token: ${{ secrets.FORGEJO_TOKEN }}
forgejo_push_token: ${{ secrets.FORGEJO_PUSH_TOKEN }} forgejo_push_token: ${{ secrets.FORGEJO_PUSH_TOKEN }}

View file

@ -1,39 +0,0 @@
name: PR Tests
on: [pull_request]
jobs:
test:
runs-on: docker-arm64
container:
image: python:3.12-slim
env:
CMAKE_BUILD_PARALLEL_LEVEL: "4"
steps:
- name: Install system deps
run: |
apt-get update
apt-get install -y --no-install-recommends \
git ca-certificates \
build-essential pkg-config
rm -rf /var/lib/apt/lists/*
- name: Checkout
run: |
git config --global --add safe.directory "$PWD"
git clone --depth=1 \
"https://oauth2:${{ github.token }}@bitfreedom.net/code/${{ github.repository }}.git" .
git fetch --depth=1 origin "+${{ github.event.pull_request.head.sha }}:pr"
git checkout pr
- name: Fetch action source
run: |
git clone --depth=1 --branch master \
"https://oauth2:${{ github.token }}@bitfreedom.net/code/nomyo-ai/actions.git" \
./.run-tests
- uses: ./.run-tests/run-tests
with:
setup: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r test/requirements_test.txt
command: pytest test/ -m "not integration" --cov=router --cov=cache --cov=db --cov=enhance --cov-fail-under=45 --cov-report=term-missing --cov-report=xml --junitxml=report.xml
artifacts-path: |
report.xml
coverage.xml

5
.gitignore vendored
View file

@ -66,4 +66,7 @@ config.yaml
# SQLite # SQLite
*.db* *.db*
*settings.json *settings.json
# Test suite (local only, not committed yet)
test/

View file

@ -132,41 +132,6 @@ This way the Ollama backend servers are utilized more efficient than by simply u
NOMYO Router also supports OpenAI API compatible v1 backend servers. NOMYO Router also supports OpenAI API compatible v1 backend servers.
## OpenAI Responses API
In addition to Chat Completions, NOMYO Router exposes the OpenAI **Responses API**:
```
POST /v1/responses # create a response (stream or non-stream)
GET /v1/responses/{id} # retrieve a stored response
DELETE /v1/responses/{id} # delete a stored response
POST /v1/responses/{id}/cancel # cancel a background response
```
It works transparently across **all** backends. When the routed model lives on a native
Responses backend (external OpenAI) the request is forwarded as-is; for Ollama and llama-server the
router translates Responses ⇄ Chat Completions in both directions (request, response, and streaming
typed SSE events), so clients get a consistent `/v1/responses` surface regardless of backend.
### Conversation state (`store` / `previous_response_id`)
The router **owns conversation state itself** (persisted in its SQLite DB) rather than delegating to
the upstream provider, so `store` and `previous_response_id` behave identically on every backend.
On a follow-up request the router rehydrates the prior turns from its DB and expands them into the
conversation; outbound native calls always send `store=false`. Trade-off: this forgoes OpenAI's
server-side reasoning-state reuse in exchange for uniform, backend-agnostic chaining.
### Background mode
`background:true` (which requires `store:true`) returns immediately with `{"status":"queued"}`; the
request runs server-side and the client polls `GET /v1/responses/{id}` until the status reaches a
terminal state (`completed` / `failed` / `cancelled`). `POST /v1/responses/{id}/cancel` aborts it.
Limitations: streaming reconnect-resume via `starting_after` is not yet implemented. In a
multi-worker/replica deployment polling works via the shared DB, but `cancel` only reaches the
running task in the worker that started it (other workers just mark the stored row cancelled). A
background task interrupted by a server restart is reconciled to `failed` on the next startup.
## Semantic LLM Cache ## Semantic LLM Cache
NOMYO Router includes an optional semantic cache that serves repeated or semantically similar LLM requests from cache — no endpoint round-trip, no token cost, response in <10 ms. NOMYO Router includes an optional semantic cache that serves repeated or semantically similar LLM requests from cache — no endpoint round-trip, no token cost, response in <10 ms.
@ -207,7 +172,7 @@ Each request is keyed on `model + system_prompt` (exact) combined with a weighte
### Cached routes ### Cached routes
`/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions` · `/v1/responses` `/api/chat` · `/api/generate` · `/v1/chat/completions` · `/v1/completions`
### Cache management ### Cache management

View file

View file

@ -1,280 +0,0 @@
"""Management / observability routes.
Read-only endpoints used by the dashboard and external monitoring:
* usage counters and token-counts breakdown,
* conversation-affinity introspection,
* endpoint health summary,
* LLM-response cache stats and invalidation,
* SSE live-stream of usage updates,
* hostname and ``/health`` probe.
"""
import asyncio
import socket
import time
from typing import Optional
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache
from config import get_config
from db import get_db
from state import (
usage_counts,
token_usage_counts,
_affinity_map,
_affinity_lock,
)
from sse import subscribe, unsubscribe
from backends.normalize import _normalize_llama_model_name, is_llama_server, llama_endpoints
from backends.probe import _endpoint_health
router = APIRouter()
@router.get("/api/token_counts")
async def token_counts_proxy():
breakdown = []
total = 0
async for entry in get_db().load_token_counts():
total += entry['total_tokens']
breakdown.append({
"endpoint": entry["endpoint"],
"model": entry["model"],
"input_tokens": entry["input_tokens"],
"output_tokens": entry["output_tokens"],
"total_tokens": entry["total_tokens"],
})
return {"total_tokens": total, "breakdown": breakdown}
@router.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
"""
Aggregate time_series entries older than days into daily aggregates by endpoint/model/date.
"""
try:
body_bytes = await request.body()
if not body_bytes:
days = 30
trim_old = False
else:
payload = orjson.loads(body_bytes.decode("utf-8"))
days = int(payload.get("days", 30))
trim_old = bool(payload.get("trim_old", False))
except Exception:
days = 30
trim_old = False
aggregated = await get_db().aggregate_time_series_older_than(days, trim_old=trim_old)
return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
@router.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
"""
Return token usage statistics for a specific model.
"""
try:
body_bytes = await request.body()
if not model:
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
db = get_db()
token_data = await db.get_token_counts_for_model(model)
if not token_data:
raise HTTPException(
status_code=404, detail="No token data found for this model"
)
time_series = [
entry async for entry in db.get_time_series_for_model(model)
]
endpoint_distribution = await db.get_endpoint_distribution_for_model(model)
return {
'model': model,
'input_tokens': token_data['input_tokens'],
'output_tokens': token_data['output_tokens'],
'total_tokens': token_data['total_tokens'],
'time_series': time_series,
'endpoint_distribution': endpoint_distribution,
}
@router.get("/api/affinity_stats")
async def affinity_stats(request: Request):
"""
Aggregate live conversation-affinity pins, one entry per pinned conversation.
Each entry exposes only the endpoint, model, and remaining TTL in seconds
no fingerprints or content. When conversation_affinity is disabled the
`entries` list is always empty.
"""
config = get_config()
if not config.conversation_affinity:
return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []}
now = time.monotonic()
entries: list[dict] = []
async with _affinity_lock:
for fp, (ep, mdl, expires_at) in list(_affinity_map.items()):
remaining = expires_at - now
if remaining <= 0:
_affinity_map.pop(fp, None)
continue
# Mirror the normalisation used by /api/ps_details so the dashboard
# can join affinity entries to PS rows by (endpoint, model).
display_model = _normalize_llama_model_name(mdl) if is_llama_server(ep) else mdl
entries.append({
"endpoint": ep,
"model": display_model,
"remaining": round(remaining, 2),
})
return {
"enabled": True,
"ttl": config.conversation_affinity_ttl,
"entries": entries,
}
@router.get("/api/usage")
async def usage_proxy(request: Request):
"""
Return a snapshot of the usage counter for each endpoint.
Useful for debugging / monitoring.
"""
return {"usage_counts": usage_counts,
"token_usage_counts": token_usage_counts}
@router.get("/api/config")
async def config_proxy(request: Request):
"""
Return a simple JSON object that contains the configured
Ollama endpoints and llama_server_endpoints. The frontend uses this
to display which endpoints are being proxied and their health.
Status is "error" when either liveness (/api/version) or routing
health (/api/ps) fails see issue #83.
"""
config = get_config()
async def check(url: str) -> dict:
return {"url": url, **(await _endpoint_health(url, timeout=5))}
ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
llama_results = []
# llama-server and llama-swap render identically in the dashboard ("llama" rows),
# so health-check both and merge them into one list.
llama_eps = llama_endpoints(config)
if llama_eps:
llama_results = await asyncio.gather(
*[check(ep) for ep in llama_eps]
)
return {
"endpoints": ollama_results,
"llama_server_endpoints": llama_results,
"require_router_api_key": bool(config.router_api_key),
}
@router.get("/api/cache/stats")
async def cache_stats():
"""Return hit/miss counters and configuration for the LLM response cache."""
c = get_llm_cache()
if c is None:
return {"enabled": False}
return {"enabled": True, **c.stats()}
@router.post("/api/cache/invalidate")
async def cache_invalidate():
"""Clear all entries from the LLM response cache and reset counters."""
c = get_llm_cache()
if c is None:
return {"enabled": False, "cleared": False}
await c.clear()
return {"enabled": True, "cleared": True}
@router.get("/health")
async def health_proxy(request: Request):
"""
Healthcheck endpoint for monitoring the proxy.
* Queries each configured endpoint for both liveness and routing health:
Ollama endpoints are probed at `/api/version` AND `/api/ps`,
OpenAI-compatible endpoints at `/models`.
* Returns a JSON object containing:
- `status`: "ok" if every endpoint replied to every probe, otherwise "error".
- `endpoints`: a mapping of endpoint URL `{status, version|detail}`.
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
"""
config = get_config()
# Run all health checks in parallel.
# Ollama endpoints expose /api/version (liveness) and /api/ps (routing
# health — required by `choose_endpoint`). OpenAI-compatible endpoints
# (vLLM, llama-server, external) expose /models, which serves both
# purposes. Probing /api/version alone would miss the case where the
# Ollama process is up but /api/ps is failing — see issue #83.
all_endpoints = list(config.endpoints)
llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints]
all_endpoints += llama_eps_extra
probe_results = await asyncio.gather(
*(_endpoint_health(ep) for ep in all_endpoints),
)
health_summary = dict(zip(all_endpoints, probe_results))
overall_ok = all(entry.get("status") == "ok" for entry in probe_results)
response_payload = {
"status": "ok" if overall_ok else "error",
"endpoints": health_summary,
}
http_status = 200 if overall_ok else 503
return JSONResponse(content=response_payload, status_code=http_status)
@router.get("/api/hostname")
async def get_hostname():
"""Return the hostname of the machine running the router."""
return JSONResponse(content={"hostname": socket.gethostname()})
@router.get("/api/usage-stream")
async def usage_stream(request: Request):
"""
ServerSentEvents that emits a JSON payload every time the
global `usage_counts` dictionary changes.
"""
async def event_generator():
# The queue that receives *every* new snapshot
queue = await subscribe()
try:
while True:
# If the client disconnects, cancel the loop
if await request.is_disconnected():
break
data = await queue.get()
if data is None:
break
# Send the data as a single SSE message
yield f"data: {data}\n\n"
finally:
# Cleanup: unsubscribe from the broadcast channel
await unsubscribe(queue)
return StreamingResponse(event_generator(), media_type="text/event-stream")

File diff suppressed because it is too large Load diff

View file

@ -1,906 +0,0 @@
"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).
The chat-completions and completions handlers carry the full reactive-trim
logic for ``exceed_context_size_error`` plus connection-failure rerouting
(``_mark_backend_unhealthy``). The streaming branches assemble cached
responses on the fly so caching works for both streaming and non-streaming
clients.
"""
import asyncio
import base64
import math
import aiohttp
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache, openai_nonstream_to_sse
from config import get_config
from context_window import (
_count_message_tokens,
_trim_messages_for_context,
_calibrated_trim_target,
_endpoint_nctx,
_CTX_TRIM_SMALL_LIMIT,
)
from fingerprint import _conversation_fingerprint
from security import _mask_secrets
from state import token_queue, app_state, default_headers
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
from backends.normalize import (
dedupe_on_keys,
ep2base,
is_ext_openai_endpoint,
is_openai_compatible,
is_llama_server,
llama_endpoints,
_normalize_llama_model_name,
)
from backends.probe import fetch
from backends.sessions import _make_openai_client, get_session
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
from requests.rechunk import rechunk
from routing import choose_endpoint, decrement_usage
router = APIRouter()
async def create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model):
"""Call ``chat.completions.create`` with the router's resilience retries.
Encapsulates the recovery ladder shared by the chat-completions handler and
the translated ``/v1/responses`` path:
* ``does not support tools`` retry without ``tools``
* llama-server context exhaustion sliding-window message trim, with a
second retry that also strips ``tools``/``tool_choice``
* backend connection failure mark (endpoint, model) unhealthy so the next
request reroutes, then re-raise
* ``image input is not supported`` strip images and retry
On unrecoverable failure the endpoint usage counter is decremented and the
exception is re-raised. Returns the established async generator / response.
"""
config = get_config()
try:
async_gen = await oclient.chat.completions.create(**send_params)
except Exception as e:
_e_str = str(e)
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
if "does not support tools" in _e_str:
# Model doesn't support tools — retry without them
print(f"[ochat] retry: no tools", flush=True)
try:
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
async_gen = await oclient.chat.completions.create(**params_without_tools)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_ctx_err:
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0)
actual_tokens = err_detail.get("n_prompt_tokens", 0)
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
if not n_ctx_limit:
import re as _re
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
if _m:
n_ctx_limit = int(_m.group(1))
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
if _m:
actual_tokens = int(_m.group(1))
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
if not n_ctx_limit:
await decrement_usage(endpoint, tracking_model)
raise
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
msgs_to_trim = send_params.get("messages", [])
try:
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
except Exception as _helper_exc:
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
await decrement_usage(endpoint, tracking_model)
raise
dropped = len(msgs_to_trim) - len(trimmed_messages)
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
print(f"[ctx-trim] retry-1 ok", flush=True)
except Exception as e2:
_e2_str = str(e2)
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
# Still too large — tool definitions likely consuming too many tokens, strip them too
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
try:
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
print(f"[ctx-trim] retry-2 ok", flush=True)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_backend_connection_error(e):
# Upstream connection failed (e.g. llama-server in router mode
# whose delegated worker died). Mark (endpoint, model) so the
# next request reroutes; the client will retry this one.
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
await _mark_backend_unhealthy(endpoint, model, _e_str)
await decrement_usage(endpoint, tracking_model)
raise
elif "image input is not supported" in _e_str:
# Model doesn't support images — strip and retry
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
return async_gen
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
"""
Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.
"""
config = get_config()
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
doc = payload.get("input")
# Normalize multimodal input: extract only text parts for embedding models
if isinstance(doc, list):
normalized = []
for item in doc:
if isinstance(item, dict):
# Multimodal content part - extract text only, skip images
if item.get("type") == "text":
normalized.append(item.get("text", ""))
# Skip image_url and other non-text types
else:
normalized.append(item)
doc = normalized if len(normalized) != 1 else normalized[0]
elif isinstance(doc, dict) and doc.get("type") == "text":
doc = doc.get("text", "")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not doc:
raise HTTPException(
status_code=400, detail="Missing required field 'input'"
)
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint, tracking_model = await choose_endpoint(model)
if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key")
else:
api_key = "ollama"
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)
try:
async_gen = await oclient.embeddings.create(input=doc, model=model)
result = async_gen.model_dump()
for item in result.get("data", []):
emb = item.get("embedding")
if emb:
item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb]
return JSONResponse(content=result)
finally:
await decrement_usage(endpoint, tracking_model)
@router.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
"""
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
"""
config = get_config()
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
messages = payload.get("messages")
frequency_penalty = payload.get("frequency_penalty")
presence_penalty = payload.get("presence_penalty")
response_format = payload.get("response_format")
seed = payload.get("seed")
stop = payload.get("stop")
stream = payload.get("stream")
stream_options = payload.get("stream_options")
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
tools = payload.get("tools")
logprobs = payload.get("logprobs")
top_logprobs = payload.get("top_logprobs")
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not isinstance(messages, list):
raise HTTPException(
status_code=400, detail="Missing required field 'messages' (must be a list)"
)
if ":latest" in model:
model = model.split(":latest")
model = model[0]
messages = _strip_assistant_prefill(messages)
params = {
"messages": messages,
"model": model,
}
optional_params = {
"tools": tools,
"response_format": response_format,
"stream_options": stream_options or {"include_usage": True },
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"seed": seed,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"stop": stop,
"stream": stream,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
}
params.update({k: v for k, v in optional_params.items() if v is not None})
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# Reject unsupported image formats (SVG) before doing any work
for _msg in messages:
for _item in (_msg.get("content") or []) if isinstance(_msg.get("content"), list) else []:
if _item.get("type") == "image_url":
_url = (_item.get("image_url") or {}).get("url", "")
if _url.startswith("data:image/svg") or _url.lower().endswith(".svg"):
raise HTTPException(
status_code=400,
detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
)
# Cache lookup — before endpoint selection
_cache = get_llm_cache()
if _cache is not None and _cache_enabled:
_cached = await _cache.get_chat("openai_chat", model, messages)
if _cached is not None:
if stream:
_sse = openai_nonstream_to_sse(_cached, model)
async def _serve_cached_ochat_stream():
yield _sse
return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")
else:
async def _serve_cached_ochat_json():
yield _cached
return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")
# 2. Endpoint logic
_affinity_key = _conversation_fingerprint(model, messages, None)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Helpers and API call — done in handler scope so try/except works reliably
async def _normalize_images_in_messages(msgs: list) -> list:
"""Fetch remote image URLs and convert them to base64 data URLs so
Ollama/llama-server can handle them without making outbound HTTP requests."""
resolved = []
for msg in msgs:
content = msg.get("content")
if not isinstance(content, list):
resolved.append(msg)
continue
new_content = []
for item in content:
if item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url", "")
if url and not url.startswith("data:"):
try:
http: aiohttp.ClientSession = app_state["session"]
async with http.get(url) as resp:
ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
img_bytes = await resp.read()
b64 = base64.b64encode(img_bytes).decode("utf-8")
new_content.append({
"type": "image_url",
"image_url": {"url": f"data:{ctype};base64,{b64}"}
})
except Exception as _ie:
print(f"[image] Failed to fetch image URL: {_ie}")
new_content.append(item)
else:
new_content.append(item)
else:
new_content.append(item)
resolved.append({**msg, "content": new_content})
return resolved
# Make the API call in handler scope — try/except inside async generators is unreliable
# with Starlette's streaming machinery, so we resolve errors here before the generator starts.
send_params = params
if not is_ext_openai_endpoint(endpoint):
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
send_params = {**params, "messages": resolved_msgs}
# Proactive trim: only for small-ctx models we've already seen run out of space
_lookup_model = _normalize_llama_model_name(model) if is_llama_server(endpoint) else model
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
_pre_est = _count_message_tokens(send_params.get("messages", []))
if _pre_est > _pre_target:
_pre_msgs = send_params.get("messages", [])
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
_dropped = len(_pre_msgs) - len(_pre_trimmed)
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
send_params = {**send_params, "messages": _pre_trimmed}
async_gen = await create_chat_with_retries(oclient, send_params, endpoint, model, tracking_model)
# 4. Async generator — only streams the already-established async_gen
async def stream_ochat_response():
try:
if stream == True:
content_parts: list[str] = []
usage_snapshot: dict = {}
async for chunk in async_gen:
data = (
chunk.model_dump_json()
if hasattr(chunk, "model_dump_json")
else orjson.dumps(chunk)
)
if chunk.choices:
delta = chunk.choices[0].delta
has_content = delta.content is not None
has_reasoning = (
getattr(delta, "reasoning_content", None) is not None
or getattr(delta, "reasoning", None) is not None
)
has_tool_calls = getattr(delta, "tool_calls", None) is not None
if has_content or has_reasoning or has_tool_calls:
yield f"data: {data}\n\n".encode("utf-8")
if has_content and delta.content:
content_parts.append(delta.content)
elif chunk.usage is not None:
# Forward the usage-only final chunk (e.g. from llama-server)
yield f"data: {data}\n\n".encode("utf-8")
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Detect context exhaustion mid-generation for small-ctx models.
# Guard: skip if max_tokens was set in the request — finish_reason=length
# could just mean the caller's token budget was exhausted, not the context window.
_req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
_inferred_nctx = (prompt_tok + comp_tok) or 0
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
# Cache assembled streaming response — before [DONE] so it always runs
if _cache is not None and _cache_enabled and content_parts:
assembled = orjson.dumps({
"model": model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
**({"usage": usage_snapshot} if usage_snapshot else {}),
}) + b"\n"
try:
await _cache.set_chat("openai_chat", model, messages, assembled)
except Exception as _ce:
print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
yield b"data: [DONE]\n\n"
else:
prompt_tok = 0
comp_tok = 0
if async_gen.usage is not None:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
else orjson.dumps(async_gen)
)
cache_bytes = json_line.encode("utf-8") + b"\n"
yield cache_bytes
# Cache non-streaming response
if _cache is not None and _cache_enabled:
try:
await _cache.set_chat("openai_chat", model, messages, cache_bytes)
except Exception as _ce:
print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
stream_ochat_response(),
media_type="text/event-stream" if stream else "application/json",
)
@router.post("/v1/completions")
async def openai_completions_proxy(request: Request):
"""
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
"""
config = get_config()
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
prompt = payload.get("prompt")
frequency_penalty = payload.get("frequency_penalty")
presence_penalty = payload.get("presence_penalty")
seed = payload.get("seed")
stop = payload.get("stop")
stream = payload.get("stream")
stream_options = payload.get("stream_options")
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
suffix = payload.get("suffix")
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not prompt:
raise HTTPException(
status_code=400, detail="Missing required field 'prompt'"
)
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"prompt": prompt,
"model": model,
}
optional_params = {
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options or {"include_usage": True },
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"max_completion_tokens": max_completion_tokens,
"suffix": suffix
}
params.update({k: v for k, v in optional_params.items() if v is not None})
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# Cache lookup — completions prompt mapped to a single-turn messages list
_cache = get_llm_cache()
_compl_messages = [{"role": "user", "content": prompt}]
if _cache is not None and _cache_enabled:
_cached = await _cache.get_chat("openai_completions", model, _compl_messages)
if _cached is not None:
if stream:
_sse = openai_nonstream_to_sse(_cached, model)
async def _serve_cached_ocompl_stream():
yield _sse
return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")
else:
async def _serve_cached_ocompl_json():
yield _cached
return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")
# 2. Endpoint logic
_affinity_key = _conversation_fingerprint(model, None, prompt)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
# Make the API call in handler scope (try/except inside async generators is unreliable)
try:
async_gen = await oclient.completions.create(**params)
except Exception as e:
if _is_backend_connection_error(e):
print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
await _mark_backend_unhealthy(endpoint, model, str(e))
await decrement_usage(endpoint, tracking_model)
raise
async def stream_ocompletions_response(model=model):
try:
if stream == True:
text_parts: list[str] = []
usage_snapshot: dict = {}
async for chunk in async_gen:
data = (
chunk.model_dump_json()
if hasattr(chunk, "model_dump_json")
else orjson.dumps(chunk)
)
if chunk.choices:
choice = chunk.choices[0]
has_text = getattr(choice, "text", None) is not None
has_reasoning = (
getattr(choice, "reasoning_content", None) is not None
or getattr(choice, "reasoning", None) is not None
)
if has_text or has_reasoning or choice.finish_reason is not None:
yield f"data: {data}\n\n".encode("utf-8")
if has_text and choice.text:
text_parts.append(choice.text)
elif chunk.usage is not None:
# Forward the usage-only final chunk (e.g. from llama-server)
yield f"data: {data}\n\n".encode("utf-8")
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Cache assembled streaming response — before [DONE] so it always runs
if _cache is not None and _cache_enabled and text_parts:
assembled = orjson.dumps({
"model": model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
**({"usage": usage_snapshot} if usage_snapshot else {}),
}) + b"\n"
try:
await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
except Exception as _ce:
print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
# Final DONE event
yield b"data: [DONE]\n\n"
else:
prompt_tok = 0
comp_tok = 0
if async_gen.usage is not None:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
else orjson.dumps(async_gen)
)
cache_bytes = json_line.encode("utf-8") + b"\n"
yield cache_bytes
# Cache non-streaming response
if _cache is not None and _cache_enabled:
try:
await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
except Exception as _ce:
print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
stream_ocompletions_response(),
media_type="text/event-stream" if stream else "application/json",
)
@router.get("/v1/models")
async def openai_models_proxy(request: Request):
"""
Proxy an OpenAI API models request to Ollama and llama-server endpoints and reply with a unique list of models.
For Ollama endpoints: queries /api/tags (all models)
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
"""
config = get_config()
# 1. Query Ollama endpoints for all models via /api/tags
ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
# 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models
ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)]
# 3. Query llama-server / llama-swap endpoints for advertised models via /v1/models
# Also query endpoints that may not be in config.endpoints
all_llama_endpoints = llama_endpoints(config)
llama_tasks = [
fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
for ep in all_llama_endpoints
]
ollama_models = await asyncio.gather(*ollama_tasks) if ollama_tasks else []
ext_openai_models = await asyncio.gather(*ext_openai_tasks) if ext_openai_tasks else []
llama_models = await asyncio.gather(*llama_tasks) if llama_tasks else []
models = {'data': []}
# Add Ollama models (if any)
if ollama_models:
for modellist in ollama_models:
for model in modellist:
if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name
model['id'] = model.get('name', model.get('id', ''))
else:
model['name'] = model['id']
models['data'].append(model)
# Add external OpenAI models (if any)
if ext_openai_models:
for modellist in ext_openai_models:
for model in modellist:
if not "id" in model.keys():
model['id'] = model.get('name', model.get('id', ''))
else:
model['name'] = model['id']
models['data'].append(model)
# Add llama-server models (all available, not just loaded)
if llama_models:
for modellist in llama_models:
for model in modellist:
if not "id" in model.keys():
model['id'] = model.get('name', model.get('id', ''))
else:
model['name'] = model['id']
models['data'].append(model)
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
return JSONResponse(
content={"data": dedupe_on_keys(models['data'], ['name'])},
status_code=200,
)
@router.post("/v1/rerank")
@router.post("/rerank")
async def rerank_proxy(request: Request):
"""
Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.
Compatible with the Jina/Cohere rerank API convention used by llama-server,
vLLM, and services such as Cohere and Jina AI.
Ollama does not natively support reranking; requests routed to a plain Ollama
endpoint will receive a 501 Not Implemented response.
Request body:
model (str, required) reranker model name
query (str, required) search query
documents (list[str], required) candidate documents to rank
top_n (int, optional) limit returned results (default: all)
return_documents (bool, optional) include document text in results
max_tokens_per_doc (int, optional) truncation limit per document
Response (Jina/Cohere-compatible):
{
"id": "...",
"model": "...",
"usage": {"prompt_tokens": N, "total_tokens": N},
"results": [{"index": 0, "relevance_score": 0.95}, ...]
}
"""
config = get_config()
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
query = payload.get("query")
documents = payload.get("documents")
if not model:
raise HTTPException(status_code=400, detail="Missing required field 'model'")
if not query:
raise HTTPException(status_code=400, detail="Missing required field 'query'")
if not isinstance(documents, list) or not documents:
raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# Determine which endpoint serves this model
try:
endpoint, tracking_model = await choose_endpoint(model)
except RuntimeError as e:
raise HTTPException(status_code=404, detail=str(e))
# Ollama endpoints have no native rerank support
if not is_openai_compatible(endpoint):
await decrement_usage(endpoint, tracking_model)
raise HTTPException(
status_code=501,
detail=(
f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
"reranking. Use a llama-server or OpenAI-compatible endpoint with a "
"dedicated reranker model."
),
)
if ":latest" in model:
model = model.split(":latest")[0]
# Build upstream rerank request body forward only recognised fields
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
if optional_key in payload:
upstream_payload[optional_key] = payload[optional_key]
# Determine upstream URL:
# llama-server / llama-swap expose /v1/rerank (base already contains /v1)
# External OpenAI endpoints expose /rerank under their /v1 base
if is_llama_server(endpoint):
# llama-server / llama-swap: endpoint may or may not already contain /v1
if "/v1" in endpoint:
rerank_url = f"{endpoint}/rerank"
else:
rerank_url = f"{endpoint}/v1/rerank"
else:
# External OpenAI-compatible: ep2base gives us the /v1 base
rerank_url = f"{ep2base(endpoint)}/rerank"
api_key = config.api_keys.get(endpoint, "no-key")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
client: aiohttp.ClientSession = get_session(endpoint)
try:
async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
response_bytes = await resp.read()
if resp.status >= 400:
raise HTTPException(
status_code=resp.status,
detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
)
data = orjson.loads(response_bytes)
# Record token usage if the upstream returned a usage object
usage = data.get("usage") or {}
prompt_tok = usage.get("prompt_tokens") or 0
total_tok = usage.get("total_tokens") or 0
# For reranking there are no completion tokens; we record prompt tokens only
if prompt_tok or total_tok:
await token_queue.put((endpoint, tracking_model, prompt_tok, 0))
return JSONResponse(content=data)
finally:
await decrement_usage(endpoint, tracking_model)
async def _resolve_llama_swap_endpoint(model_id: str) -> str | None:
"""Pick the llama-swap endpoint that serves ``model_id``.
Prefers an endpoint that already has the worker running; falls back to any
that advertises the model. Returns None if none do.
"""
config = get_config()
swap_eps = config.llama_swap_endpoints
if not swap_eps:
return None
advertised = await asyncio.gather(
*[fetch.available_models(ep, config.api_keys.get(ep)) for ep in swap_eps]
)
candidates = [ep for ep, models in zip(swap_eps, advertised) if model_id in models]
if not candidates:
return None
if len(candidates) == 1:
return candidates[0]
loaded = await asyncio.gather(*[fetch.loaded_models(ep) for ep in candidates])
for ep, lm in zip(candidates, loaded):
if model_id in lm:
return ep
return candidates[0]
@router.api_route("/upstream/{model_id}/{path:path}", methods=["GET", "POST"])
async def llama_swap_upstream(model_id: str, path: str, request: Request):
"""Bypass llama-swap and reach a model's underlying llama-server worker directly
via llama-swap's ``/upstream/:model_id`` route.
Lets clients use llama-server features that llama-swap itself does not forward
(e.g. token-array prompts), while still letting the router pick the backend that
actually hosts the model. ``/upstream`` is a root route, so the ``/v1`` suffix is
stripped from the configured endpoint.
"""
config = get_config()
endpoint = await _resolve_llama_swap_endpoint(model_id)
if endpoint is None:
raise HTTPException(
status_code=404,
detail=f"No configured llama-swap endpoint serves model '{model_id}'.",
)
base_url = endpoint.rstrip("/").removesuffix("/v1")
url = f"{base_url}/upstream/{model_id}/{path}"
if request.url.query:
url = f"{url}?{request.url.query}"
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
content_type = request.headers.get("content-type")
if content_type:
headers["Content-Type"] = content_type
api_key = config.api_keys.get(endpoint)
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
body = await request.body()
client: aiohttp.ClientSession = get_session(endpoint)
try:
resp = await client.request(request.method, url, data=body or None, headers=headers)
except Exception as e:
raise HTTPException(status_code=502, detail=f"Upstream request to {url} failed: {e}")
async def _iter():
try:
async for chunk in resp.content.iter_any():
yield chunk
finally:
resp.release()
return StreamingResponse(
_iter(),
status_code=resp.status,
media_type=resp.headers.get("Content-Type"),
)

View file

@ -1,398 +0,0 @@
"""OpenAI **Responses API** routes (``/v1/responses`` and its retrieve / delete /
cancel companions).
The router speaks Chat Completions to its backends, so this layer:
* **native** (external OpenAI): forwards via ``oclient.responses.create`` and
streams the SDK's typed events straight back, rewriting the response ``id`` to
a router-owned ``resp_`` id so chaining stays router-managed.
* **translated** (Ollama / llama-server): converts the request to chat, reuses
the resilient ``create_chat_with_retries`` ladder, and re-emits the result as
Responses typed SSE events (``requests/responses.py``).
State (``store`` / ``previous_response_id``) and background-task status live in the
router's SQLite DB (``db.py``); the router mints and owns every response id.
"""
import asyncio
import secrets
import time
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache
from config import get_config
from db import get_db
from fingerprint import _conversation_fingerprint
from state import token_queue, default_headers
from backends.normalize import is_ext_openai_endpoint
from backends.sessions import _make_openai_client
from routing import choose_endpoint, decrement_usage
from api.openai import create_chat_with_retries
from requests.responses import (
ChatToResponsesStream,
build_response_object,
chat_message_to_output_items,
messages_to_responses_input,
responses_input_to_messages,
responses_object_to_sse,
tools_responses_to_chat,
usage_chat_to_responses,
)
router = APIRouter()
# In-memory handles for background tasks so /cancel can reach a running task in
# this worker. Cross-worker cancel falls back to marking the DB row cancelled.
_background_tasks: dict[str, asyncio.Task] = {}
# ---------------------------------------------------------------------------
# small helpers
# ---------------------------------------------------------------------------
def _usage_tokens(usage):
"""Return ``(prompt, completion)`` tokens from a chat- or responses-shaped usage."""
if not usage:
return 0, 0
if "input_tokens" in usage:
return usage.get("input_tokens", 0) or 0, usage.get("output_tokens", 0) or 0
return usage.get("prompt_tokens", 0) or 0, usage.get("completion_tokens", 0) or 0
def _text_format_to_response_format(text):
"""Map Responses ``text.format`` → Chat Completions ``response_format`` (best effort)."""
if not isinstance(text, dict):
return None
fmt = text.get("format")
if not isinstance(fmt, dict):
return None
ftype = fmt.get("type")
if ftype == "json_object":
return {"type": "json_object"}
if ftype == "json_schema":
return {"type": "json_schema", "json_schema": {
k: fmt[k] for k in ("name", "schema", "strict", "description") if k in fmt
}}
return None
def _native_usage_from_response(data):
return data.get("usage")
async def _resolve_history_messages(previous_response_id):
"""Rebuild prior-turn chat messages from the stored response chain."""
if not previous_response_id:
return []
db = get_db()
chain = await db.get_response_chain(previous_response_id)
messages = []
for turn in chain:
# Each turn stored the chat messages that produced it + its output items.
for m in turn.get("input_messages") or []:
messages.append(m)
for item in turn.get("output_items") or []:
if item.get("type") == "message":
text = "".join(
p.get("text", "") for p in item.get("content") or []
if p.get("type") == "output_text"
)
if text:
messages.append({"role": "assistant", "content": text})
elif item.get("type") == "function_call":
messages.append({
"role": "assistant", "content": None,
"tool_calls": [{"id": item.get("call_id"), "type": "function",
"function": {"name": item.get("name"),
"arguments": item.get("arguments", "")}}],
})
return messages
class _NativeStream:
"""Re-emit an SDK Responses event stream, rewriting the response id and
capturing the final output/usage for storage."""
def __init__(self, response_id):
self.response_id = response_id
self.output_items = []
self.usage = None
async def events(self, sdk_gen):
async for event in sdk_gen:
data = event.model_dump() if hasattr(event, "model_dump") else event
etype = data.get("type", "")
resp = data.get("response")
if isinstance(resp, dict) and resp.get("id"):
resp["id"] = self.response_id
if etype in ("response.completed", "response.incomplete", "response.failed") \
and isinstance(resp, dict):
self.output_items = resp.get("output", []) or []
self.usage = resp.get("usage")
yield f"event: {etype}\ndata: {orjson.dumps(data).decode('utf-8')}\n\n".encode("utf-8")
# ---------------------------------------------------------------------------
# backend execution (non-streaming, used by background + non-stream sync)
# ---------------------------------------------------------------------------
async def _run_to_completion(*, native, oclient, endpoint, model, tracking_model,
send_params, native_params):
"""Drive the backend to completion (no client streaming).
Returns ``(output_items, usage)`` where usage is responses-shaped. Caller is
responsible for ``decrement_usage`` (translated failures self-decrement inside
``create_chat_with_retries``)."""
if native:
resp_obj = await oclient.responses.create(stream=False, **native_params)
data = resp_obj.model_dump()
return data.get("output", []) or [], data.get("usage")
async_gen = await create_chat_with_retries(oclient, {**send_params, "stream": False},
endpoint, model, tracking_model)
message = async_gen.choices[0].message.model_dump() if async_gen.choices else {}
output_items = chat_message_to_output_items(message)
usage = usage_chat_to_responses(
async_gen.usage.model_dump() if async_gen.usage is not None else None
)
return output_items, usage
# ---------------------------------------------------------------------------
# POST /v1/responses
# ---------------------------------------------------------------------------
@router.post("/v1/responses")
async def openai_responses_proxy(request: Request):
config = get_config()
try:
payload = orjson.loads((await request.body()).decode("utf-8"))
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
model = payload.get("model")
input_data = payload.get("input")
instructions = payload.get("instructions")
stream = bool(payload.get("stream"))
store = payload.get("store", True)
background = bool(payload.get("background"))
previous_response_id = payload.get("previous_response_id")
tools = payload.get("tools")
metadata = payload.get("metadata") or {}
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
if not model:
raise HTTPException(status_code=400, detail="Missing required field 'model'")
if input_data is None:
raise HTTPException(status_code=400, detail="Missing required field 'input'")
if background and not store:
raise HTTPException(status_code=400, detail="background mode requires store=true")
if ":latest" in model:
model = model.split(":latest")[0]
# Resolve conversation: prior turns (from store) + this turn's input.
history = await _resolve_history_messages(previous_response_id)
messages = history + responses_input_to_messages(input_data, instructions)
response_id = f"resp_{secrets.token_hex(24)}"
created_at = int(time.time())
# Cache lookup (foreground only) — before endpoint selection.
_cache = get_llm_cache()
if _cache is not None and _cache_enabled and not background:
cached = await _cache.get_chat("openai_responses", model, messages)
if cached is not None:
resp_obj = orjson.loads(cached)
resp_obj["id"] = response_id
if stream:
async def _served_cached():
yield responses_object_to_sse(resp_obj)
return StreamingResponse(_served_cached(), media_type="text/event-stream")
return JSONResponse(content=resp_obj)
# Endpoint selection (reserves a slot — must be released exactly once).
_affinity_key = _conversation_fingerprint(model, messages, None)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
oclient = _make_openai_client(endpoint, default_headers=default_headers,
api_key=config.api_keys.get(endpoint, "no-key"))
native = is_ext_openai_endpoint(endpoint)
# Build backend params for both shapes.
send_params = {"messages": messages, "model": model}
_opt = {
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"max_tokens": payload.get("max_output_tokens"),
"tools": tools_responses_to_chat(tools),
"tool_choice": payload.get("tool_choice"),
"response_format": _text_format_to_response_format(payload.get("text")),
}
send_params.update({k: v for k, v in _opt.items() if v is not None})
native_instructions, native_input = messages_to_responses_input(messages)
native_params = {"model": model, "input": native_input, "store": False}
_nopt = {
"instructions": native_instructions,
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"max_output_tokens": payload.get("max_output_tokens"),
"tools": tools,
"tool_choice": payload.get("tool_choice"),
"text": payload.get("text"),
"reasoning": payload.get("reasoning"),
}
native_params.update({k: v for k, v in _nopt.items() if v is not None})
async def _persist(status, output_items=None, usage=None, error=None, insert=False):
if not store:
return
db = get_db()
if insert:
await db.store_response(
response_id, previous_response_id=previous_response_id, model=model,
status=status, created_at=created_at, input_messages=messages,
output_items=output_items, usage=usage, instructions=instructions, error=error)
else:
await db.update_response_status(response_id, status, output_items=output_items,
usage=usage, error=error)
async def _track(usage):
prompt_tok, comp_tok = _usage_tokens(usage)
if prompt_tok or comp_tok:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
async def _cache_store(output_items, usage):
if _cache is None or not _cache_enabled or not output_items:
return
obj = build_response_object(response_id=response_id, model=model,
output_items=output_items, usage=usage,
created_at=created_at,
previous_response_id=previous_response_id,
instructions=instructions, metadata=metadata)
try:
await _cache.set_chat("openai_responses", model, messages, orjson.dumps(obj))
except Exception as _ce:
print(f"[cache] set_chat (openai_responses) failed: {_ce}")
# ---- background: run detached, return queued immediately --------------
if background:
await _persist("queued", insert=True)
async def _bg_run():
try:
await get_db().update_response_status(response_id, "in_progress")
output_items, usage = await _run_to_completion(
native=native, oclient=oclient, endpoint=endpoint, model=model,
tracking_model=tracking_model, send_params=send_params,
native_params=native_params)
await _track(usage)
await _persist("completed", output_items=output_items, usage=usage)
await _cache_store(output_items, usage)
except asyncio.CancelledError:
await get_db().update_response_status(response_id, "cancelled")
raise
except Exception as e:
await get_db().update_response_status(
response_id, "failed",
error={"message": str(e)[:500], "type": type(e).__name__})
finally:
await decrement_usage(endpoint, tracking_model)
_background_tasks.pop(response_id, None)
task = asyncio.create_task(_bg_run())
_background_tasks[response_id] = task
queued = build_response_object(response_id=response_id, model=model, output_items=[],
status="queued", created_at=created_at,
previous_response_id=previous_response_id,
instructions=instructions, metadata=metadata)
return JSONResponse(content=queued, status_code=200)
# ---- streaming sync ----------------------------------------------------
if stream:
if native:
source = await oclient.responses.create(stream=True, **native_params)
translator = _NativeStream(response_id)
else:
source = await create_chat_with_retries(
oclient, {**send_params, "stream": True,
"stream_options": {"include_usage": True}},
endpoint, model, tracking_model)
translator = ChatToResponsesStream(
response_id, model, created_at=created_at,
previous_response_id=previous_response_id, instructions=instructions,
metadata=metadata)
async def _stream():
await _persist("in_progress", insert=True)
try:
async for sse in translator.events(source):
yield sse
await _track(translator.usage)
await _persist("completed", output_items=translator.output_items,
usage=translator.usage)
await _cache_store(translator.output_items, translator.usage)
finally:
await decrement_usage(endpoint, tracking_model)
return StreamingResponse(_stream(), media_type="text/event-stream")
# ---- non-streaming sync ------------------------------------------------
try:
output_items, usage = await _run_to_completion(
native=native, oclient=oclient, endpoint=endpoint, model=model,
tracking_model=tracking_model, send_params=send_params,
native_params=native_params)
await _track(usage)
await _persist("completed", output_items=output_items, usage=usage, insert=True)
await _cache_store(output_items, usage)
finally:
await decrement_usage(endpoint, tracking_model)
resp_obj = build_response_object(
response_id=response_id, model=model, output_items=output_items, usage=usage,
created_at=created_at, previous_response_id=previous_response_id,
instructions=instructions, metadata=metadata)
return JSONResponse(content=resp_obj)
# ---------------------------------------------------------------------------
# GET / DELETE / cancel
# ---------------------------------------------------------------------------
def _stored_to_response_object(row):
return build_response_object(
response_id=row["response_id"], model=row.get("model"),
output_items=row.get("output_items") or [], usage=row.get("usage"),
status=row.get("status") or "completed", created_at=row.get("created_at"),
previous_response_id=row.get("previous_response_id"),
instructions=row.get("instructions"), error=row.get("error"))
@router.get("/v1/responses/{response_id}")
async def get_response(response_id: str):
row = await get_db().get_response(response_id)
if row is None:
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
return JSONResponse(content=_stored_to_response_object(row))
@router.delete("/v1/responses/{response_id}")
async def delete_response(response_id: str):
deleted = await get_db().delete_response(response_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
return JSONResponse(content={"id": response_id, "object": "response.deleted", "deleted": True})
@router.post("/v1/responses/{response_id}/cancel")
async def cancel_response(response_id: str):
row = await get_db().get_response(response_id)
if row is None:
raise HTTPException(status_code=404, detail=f"Response '{response_id}' not found")
# Cancel the running task if it lives in this worker; otherwise just mark the
# DB row so a polling client sees a terminal state (cross-worker limitation).
task = _background_tasks.get(response_id)
if task is not None and not task.done():
task.cancel()
elif row.get("status") in ("queued", "in_progress"):
await get_db().update_response_status(response_id, "cancelled")
row = await get_db().get_response(response_id)
return JSONResponse(content=_stored_to_response_object(row))

View file

@ -1,30 +0,0 @@
"""Static-asset and dashboard routes."""
from pathlib import Path
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import HTMLResponse, RedirectResponse
# Directory containing static files (resolved relative to project root).
STATIC_DIR = Path(__file__).resolve().parent.parent / "static"
router = APIRouter()
@router.get("/favicon.ico")
async def redirect_favicon():
return RedirectResponse(url="/static/favicon.ico")
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""
Render the dynamic NOMYO Router dashboard listing the configured endpoints
and the models details, availability & task status.
"""
index_path = STATIC_DIR / "index.html"
try:
return HTMLResponse(content=index_path.read_text(encoding="utf-8"), status_code=200)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="Page not found")
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

View file

View file

@ -1,50 +0,0 @@
"""Backend control operations (model unload).
llama-server and llama-swap evict a resident model through different routes:
* llama-server ``POST {base}/models/unload`` with body ``{"model": id}``
* llama-swap ``POST {base}/api/models/unload/{id}`` (path parameter)
``unload_model`` dispatches on the configured backend type so callers don't
have to know which one they are talking to. Both routes live at the endpoint
root, so any ``/v1`` suffix is stripped first.
"""
from typing import Optional
import aiohttp
from config import get_config
from state import default_headers
from backends.sessions import get_probe_session
from backends.normalize import is_llama_swap
from backends.health import _format_connection_issue
async def unload_model(endpoint: str, model_id: str) -> bool:
"""Ask ``endpoint`` to unload ``model_id``. Returns True on a 2xx response.
``model_id`` must be the backend's native model identifier (the raw HF id
for llama-server / llama-swap), not the router-normalized display name.
"""
cfg = get_config()
base_url = endpoint.rstrip("/").removesuffix("/v1")
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
api_key: Optional[str] = cfg.api_keys.get(endpoint)
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
if is_llama_swap(endpoint):
url = f"{base_url}/api/models/unload/{model_id}"
json_body = None
else:
url = f"{base_url}/models/unload"
json_body = {"model": model_id}
client: aiohttp.ClientSession = get_probe_session(endpoint)
try:
async with client.post(url, json=json_body, headers=headers) as resp:
ok = resp.status < 400
print(f"[unload_model] {model_id} on {endpoint}: {resp.status}")
return ok
except Exception as e:
print(f"[unload_model] {_format_connection_issue(url, e)}")
return False

View file

@ -1,136 +0,0 @@
"""Backend health probes and error classification helpers.
Contains:
* cache-freshness check (``_is_fresh``)
* aiohttp response success assertion (``_ensure_success``)
* human-readable connection-issue formatter
* upstream-error detection that distinguishes connection failures from
legitimate 4xx responses (``_is_backend_connection_error``)
* per-(endpoint, model) unhealthy marker that feeds ``choose_endpoint``
* llama-server status interpretation (``_is_llama_model_loaded`` etc.)
"""
import asyncio
import time
from urllib.parse import urlparse
import aiohttp
import openai
from fastapi import HTTPException
from security import _mask_secrets
from state import _completion_error_cache, _completion_error_cache_lock
def _is_fresh(cached_at: float, ttl: int) -> bool:
return (time.time() - cached_at) < ttl
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
if resp.status >= 400:
text = await resp.text()
raise HTTPException(status_code=resp.status, detail=_mask_secrets(text))
def _format_connection_issue(url: str, error: Exception) -> str:
"""
Provide a human-friendly error string for connection failures so operators
know which endpoint and address failed from inside the container.
"""
parsed = urlparse(url)
host_hint = parsed.hostname or ""
port_hint = parsed.port or ""
if isinstance(error, aiohttp.ClientConnectorError):
resolved_host = getattr(error, "host", host_hint) or host_hint or "?"
resolved_port = getattr(error, "port", port_hint) or port_hint or "?"
parts = [
f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).",
"Ensure the endpoint address is reachable from within the container.",
]
if resolved_host in {"localhost", "127.0.0.1"}:
parts.append(
"Inside Docker, 'localhost' refers to the container itself; use "
"'host.docker.internal' or a Docker network alias if the service "
"runs on the host machine."
)
os_error = getattr(error, "os_error", None)
if isinstance(os_error, OSError):
errno = getattr(os_error, "errno", None)
strerror = os_error.strerror or str(os_error)
if errno is not None or strerror:
parts.append(f"OS error [{errno}]: {strerror}.")
elif os_error:
parts.append(f"OS error: {os_error}.")
parts.append(f"Original error: {error}.")
return " ".join(parts)
if isinstance(error, asyncio.TimeoutError):
return (
f"Timed out waiting for {url}. "
"The remote endpoint may be offline or slow to respond."
)
return f"Error while contacting {url}: {error}"
def _is_backend_connection_error(exc: Exception) -> bool:
"""True for upstream connection-class failures observed via the OpenAI client.
Targets the case where a llama-server in router mode keeps answering
/v1/models but its delegated worker for a specific model is dead, so
chat/completions calls return 5xx with 'proxy error: Could not establish
connection' (or the SDK raises APIConnectionError outright).
Excludes BadRequestError with exceed_context_size_error by design those
must stay on the reactive-trim path.
"""
if isinstance(exc, openai.APIConnectionError):
return True
if isinstance(exc, openai.InternalServerError):
msg = str(exc).lower()
return (
"proxy error" in msg
or "could not establish connection" in msg
or "connection refused" in msg
)
return False
async def _mark_backend_unhealthy(endpoint: str, model: str, reason: str = "") -> None:
"""Record (endpoint, model) as broken so choose_endpoint avoids it.
Cleared only by TTL the dead-worker failure mode is invisible to the
/v1/models / /api/ps probes that clear _loaded_error_cache, so we cannot
rely on a successful probe as a recovery signal.
"""
async with _completion_error_cache_lock:
_completion_error_cache[(endpoint, model)] = time.time()
print(f"[health] marked unhealthy ep={endpoint} model={model} reason={reason[:120]}", flush=True)
def _is_llama_model_loaded(item: dict) -> bool:
"""Return True if a llama-server /v1/models item has status 'loaded'.
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
If no status field is present, the model is always-loaded (not dynamically managed)."""
status = item.get("status")
if status is None:
return True # No status field: model is always loaded (e.g. single-model servers)
if isinstance(status, dict):
return status.get("value") == "loaded"
if isinstance(status, str):
return status == "loaded"
return False
def _is_llama_model_loaded_or_sleeping(item: dict) -> bool:
"""Return True if status is 'loaded' or 'sleeping'.
Newer llama-server versions report 'sleeping' in /v1/models when a model is idle;
ps_details needs to include these so _fetch_llama_props can detect and unload them."""
status = item.get("status")
if status is None:
return True
if isinstance(status, dict):
return status.get("value") in ("loaded", "sleeping")
if isinstance(status, str):
return status in ("loaded", "sleeping")
return False

View file

@ -1,132 +0,0 @@
"""Endpoint URL, model-name, and endpoint-classification helpers.
The endpoint classifiers read live config via ``get_config()`` so that the
startup-time rebind of ``config`` in router.py is picked up at call time.
"""
from config import get_config
def _normalize_llama_model_name(name: str) -> str:
"""Extract the model name from a huggingface-style identifier.
e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF'
"""
if "/" in name:
name = name.rsplit("/", 1)[1]
if ":" in name:
name = name.split(":")[0]
return name
def _extract_llama_quant(name: str) -> str:
"""Extract the quantization level from a huggingface-style identifier.
e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0'
Returns empty string if no quant suffix is present.
"""
if ":" in name:
return name.rsplit(":", 1)[1]
return ""
def ep2base(ep):
if "/v1" in ep:
base_url = ep
else:
base_url = ep + "/v1"
return base_url
def dedupe_on_keys(dicts, key_fields):
"""
Helper function to deduplicate endpoint details based on given dict keys.
"""
seen = set()
out = []
for d in dicts:
# Build a tuple of the values for the chosen keys
key = tuple(d.get(k) for k in key_fields)
if key not in seen:
seen.add(key)
out.append(d)
return out
def is_llama_swap(endpoint: str) -> bool:
"""True if the endpoint is a configured llama-swap front."""
return endpoint in get_config().llama_swap_endpoints
def is_llama_server(endpoint: str) -> bool:
"""True for a llama.cpp llama-server OR a llama-swap front.
Both speak the same OpenAI-compatible surface, so the router treats them
identically everywhere except loaded-model detection and model unload.
"""
cfg = get_config()
return endpoint in cfg.llama_server_endpoints or endpoint in cfg.llama_swap_endpoints
def llama_endpoints(cfg) -> list:
"""Combined, de-duplicated llama-server + llama-swap endpoints (order preserved)."""
return list(dict.fromkeys([*cfg.llama_server_endpoints, *cfg.llama_swap_endpoints]))
def is_ext_openai_endpoint(endpoint: str) -> bool:
"""
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama, llama-server or llama-swap).
Returns True for:
- External services like OpenAI.com, Groq, etc.
Returns False for:
- Ollama endpoints (without /v1, or with /v1 but default port 11434)
- llama-server / llama-swap endpoints (explicitly configured)
"""
# Check if it's a llama-server / llama-swap endpoint (has /v1 and is in a configured list)
if is_llama_server(endpoint):
return False
if "/v1" not in endpoint:
return False
base_endpoint = endpoint.replace('/v1', '')
if base_endpoint in get_config().endpoints:
return False # It's Ollama's /v1
# Check for default Ollama port
if ':11434' in endpoint:
return False # It's Ollama
return True # It's an external OpenAI endpoint
def is_openai_compatible(endpoint: str) -> bool:
"""
Return True if the endpoint speaks the OpenAI API (not native Ollama).
This includes external OpenAI endpoints AND llama-server / llama-swap endpoints.
"""
return "/v1" in endpoint or is_llama_server(endpoint)
def get_tracking_model(endpoint: str, model: str) -> str:
"""
Normalize model name for tracking purposes so it matches the PS table key.
- For llama-server endpoints: strips HF prefix and quantization suffix
- For Ollama endpoints: appends ":latest" if no version suffix is present
- For external OpenAI endpoints: returns as-is (not shown in PS)
This ensures consistent model naming across all routes for usage tracking.
"""
# External OpenAI endpoints are not shown in PS, keep as-is
if is_ext_openai_endpoint(endpoint):
return model
# llama-server / llama-swap endpoints use normalized names in PS
if is_llama_server(endpoint):
return _normalize_llama_model_name(model)
# Ollama endpoints: append ":latest" if no version suffix
if ":" not in model:
return model + ":latest"
return model

View file

@ -1,488 +0,0 @@
"""Backend probe / discovery primitives.
The ``fetch`` class wraps the three discovery paths the router uses:
* ``available_models`` what the endpoint advertises (Ollama ``/api/tags``
or OpenAI-style ``/v1/models``)
* ``loaded_models`` what is currently resident (Ollama ``/api/ps`` or
llama-server ``/v1/models`` filtered on ``status == "loaded"``)
* ``endpoint_details`` arbitrary detail fetch used by management routes
Each path goes through three layers of cache: success cache, error cache,
and an in-flight request map. Stale-while-revalidate refreshes happen in
background tasks tracked by the ``_bg_refresh_*`` maps in ``state``.
``_raw_probe`` and ``_endpoint_health`` are the lower-level dual probes
used by ``/health`` and ``/api/config`` to distinguish a healthy daemon
with a broken model-introspection path from a dead daemon.
"""
import asyncio
import time
from typing import List, Optional, Set
import aiohttp
from config import get_config
from state import (
_models_cache,
_models_cache_lock,
_loaded_models_cache,
_loaded_models_cache_lock,
_available_error_cache,
_available_error_cache_lock,
_loaded_error_cache,
_loaded_error_cache_lock,
_inflight_available_models,
_inflight_loaded_models,
_inflight_lock,
_bg_refresh_available,
_bg_refresh_loaded,
_bg_refresh_lock,
default_headers,
)
from backends.sessions import get_probe_session
from backends.health import (
_is_fresh,
_ensure_success,
_format_connection_issue,
_is_llama_model_loaded,
)
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible, is_llama_server, is_llama_swap
class fetch:
async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Internal function that performs the actual HTTP request to fetch available models.
This is called by available_models() after checking caches and in-flight requests.
"""
cfg = get_config()
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
ep_base = endpoint.rstrip("/")
if is_llama_server(endpoint) and "/v1" not in endpoint:
endpoint_url = f"{ep_base}/v1/models"
key = "data"
elif "/v1" in endpoint or is_llama_server(endpoint):
endpoint_url = f"{ep_base}/models"
key = "data"
else:
endpoint_url = f"{ep_base}/api/tags"
key = "models"
client: aiohttp.ClientSession = get_probe_session(endpoint)
try:
async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
async with _models_cache_lock:
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
message = _format_connection_issue(endpoint_url, e)
print(f"[fetch.available_models] {message}")
# Update error cache with lock protection
async with _available_error_cache_lock:
_available_error_cache[endpoint] = time.time()
return set()
async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None:
"""
Background task to refresh available models cache without blocking the caller.
Used for stale-while-revalidate pattern.
Deduplicates: only one background refresh runs per endpoint at a time.
"""
async with _bg_refresh_lock:
if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
return # A refresh is already running for this endpoint
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
_bg_refresh_available[endpoint] = task
try:
await task
except Exception as e:
# Silently fail - cache will remain stale but functional
print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
finally:
async with _bg_refresh_lock:
if _bg_refresh_available.get(endpoint) is task:
_bg_refresh_available.pop(endpoint, None)
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
every model that is installed on the Ollama instance, regardless of
whether the model is currently loaded into memory.
Uses request coalescing to prevent cache stampede: if multiple requests
arrive when cache is expired, only one actual HTTP request is made.
Uses stale-while-revalidate: when the cache is between 300-600s old,
the stale data is returned immediately while a background refresh runs.
This prevents model blackouts caused by transient timeouts.
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
# Check models cache with lock protection
async with _models_cache_lock:
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
# FRESH: <= 300s old - return immediately
if _is_fresh(cached_at, 300):
return models
# STALE: 300-600s old - return stale data and refresh in background
if _is_fresh(cached_at, 600):
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
return models # Return stale data immediately
# EXPIRED: > 600s old - too stale, must refresh synchronously
del _models_cache[endpoint]
# Check error cache with lock protection
async with _available_error_cache_lock:
if endpoint in _available_error_cache:
err_age = time.time() - _available_error_cache[endpoint]
if err_age < 30:
# Very fresh error (<30s) endpoint likely still down, bail fast
return set()
elif err_age < 300:
# Stale error (30-300s) endpoint may have recovered, probe in background
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
return set()
# Error expired (>300s) remove and fall through to fresh fetch
del _available_error_cache[endpoint]
# Request coalescing: check if another request is already fetching this endpoint
async with _inflight_lock:
if endpoint in _inflight_available_models:
# Another request is already fetching - wait for it
task = _inflight_available_models[endpoint]
else:
# Create new fetch task
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
_inflight_available_models[endpoint] = task
try:
# Wait for the fetch to complete (either ours or another request's)
result = await task
return result
finally:
# Clean up in-flight tracking (only if we created it)
async with _inflight_lock:
if _inflight_available_models.get(endpoint) == task:
_inflight_available_models.pop(endpoint, None)
async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
"""
Internal function that performs the actual HTTP request to fetch loaded models.
This is called by loaded_models() after checking caches and in-flight requests.
For Ollama endpoints: queries /api/ps and returns model names
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
"""
client: aiohttp.ClientSession = get_probe_session(endpoint)
cfg = get_config()
# llama-swap: loaded/running workers are reported at /running (state == "ready"),
# NOT via a status field on /v1/models (which it omits). /running is a root route,
# so strip any /v1 suffix from the configured endpoint.
if is_llama_swap(endpoint):
base_url = endpoint.rstrip("/").removesuffix("/v1")
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
api_key = cfg.api_keys.get(endpoint)
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
try:
async with client.get(f"{base_url}/running", headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
models = {
item.get("model")
for item in data.get("running", [])
if item.get("model") and item.get("state") == "ready"
}
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
async with _loaded_error_cache_lock:
_loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
message = _format_connection_issue(f"{base_url}/running", e)
print(f"[fetch.loaded_models] {message}")
async with _loaded_error_cache_lock:
_loaded_error_cache[endpoint] = time.time()
return set()
# Check if this is a llama-server endpoint
if endpoint in cfg.llama_server_endpoints:
# Query /v1/models for llama-server. Send the configured key as a
# Bearer token — current llama.cpp leaves /models public, but a
# build/config that protects it would otherwise 401 this probe.
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
api_key = cfg.api_keys.get(endpoint)
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
try:
async with client.get(f"{endpoint}/models", headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
# Filter for loaded models only
items = data.get("data", [])
models = {
item.get("id")
for item in items
if item.get("id") and _is_llama_model_loaded(item)
}
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
# Probe succeeded — clear any stale error so the endpoint
# becomes routable again.
async with _loaded_error_cache_lock:
_loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/models", e)
print(f"[fetch.loaded_models] {message}")
# Record the failure so `choose_endpoint` can avoid routing
# to an unhealthy backend and repeated probes short-circuit.
async with _loaded_error_cache_lock:
_loaded_error_cache[endpoint] = time.time()
return set()
else:
# Original Ollama /api/ps logic
try:
async with client.get(f"{endpoint}/api/ps") as resp:
await _ensure_success(resp)
data = await resp.json()
# The response format is:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
async with _loaded_error_cache_lock:
_loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/api/ps", e)
print(f"[fetch.loaded_models] {message}")
async with _loaded_error_cache_lock:
_loaded_error_cache[endpoint] = time.time()
return set()
async def _refresh_loaded_models(endpoint: str) -> None:
"""
Background task to refresh loaded models cache without blocking the caller.
Used for stale-while-revalidate pattern.
Deduplicates: only one background refresh runs per endpoint at a time.
"""
async with _bg_refresh_lock:
if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
return # A refresh is already running for this endpoint
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
_bg_refresh_loaded[endpoint] = task
try:
await task
except Exception as e:
# Silently fail - cache will remain stale but functional
print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
finally:
async with _bg_refresh_lock:
if _bg_refresh_loaded.get(endpoint) is task:
_bg_refresh_loaded.pop(endpoint, None)
async def loaded_models(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/ps and return a set of model names that are currently
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
Uses request coalescing to prevent cache stampede and stale-while-revalidate
to serve requests immediately even when cache is stale (refreshing in background).
"""
if is_ext_openai_endpoint(endpoint):
return set()
# Check loaded models cache with lock protection
async with _loaded_models_cache_lock:
if endpoint in _loaded_models_cache:
models, cached_at = _loaded_models_cache[endpoint]
# FRESH: < 10s old - return immediately
if _is_fresh(cached_at, 10):
return models
# STALE: 10-60s old - return stale data and refresh in background
if _is_fresh(cached_at, 60):
# Kick off background refresh (fire-and-forget)
asyncio.create_task(fetch._refresh_loaded_models(endpoint))
return models # Return stale data immediately
# EXPIRED: > 60s old - too stale, must refresh synchronously
del _loaded_models_cache[endpoint]
# Check error cache with lock protection
async with _loaded_error_cache_lock:
if endpoint in _loaded_error_cache:
if _is_fresh(_loaded_error_cache[endpoint], 300):
return set()
# Error expired - remove it
del _loaded_error_cache[endpoint]
# Request coalescing: check if another request is already fetching this endpoint
async with _inflight_lock:
if endpoint in _inflight_loaded_models:
# Another request is already fetching - wait for it
task = _inflight_loaded_models[endpoint]
else:
# Create new fetch task
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
_inflight_loaded_models[endpoint] = task
try:
# Wait for the fetch to complete (either ours or another request's)
result = await task
return result
finally:
# Clean up in-flight tracking (only if we created it)
async with _inflight_lock:
if _inflight_loaded_models.get(endpoint) == task:
_inflight_loaded_models.pop(endpoint, None)
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None, skip_error_cache: bool = False, timeout: float = None) -> List[dict]:
"""
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
When ``skip_error_cache`` is False (the default), the call is short-circuited
if the endpoint recently failed (recorded in ``_available_error_cache``).
Pass ``skip_error_cache=True`` from health-check routes that must always probe.
``timeout`` overrides the session default for this single request (seconds, total).
"""
# Fast-fail if the endpoint is known to be down (unless caller opts out)
if not skip_error_cache:
async with _available_error_cache_lock:
if endpoint in _available_error_cache:
if _is_fresh(_available_error_cache[endpoint], 300):
return []
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}"
client: aiohttp.ClientSession = get_probe_session(endpoint)
req_kwargs = {}
if timeout is not None:
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
try:
async with client.get(request_url, headers=headers, **req_kwargs) as resp:
await _ensure_success(resp)
data = await resp.json()
detail = data.get(detail, [])
return detail
except Exception as e:
# If anything goes wrong we cannot reply details
message = _format_connection_issue(request_url, e)
print(f"[fetch.endpoint_details] {message}")
if not skip_error_cache:
async with _available_error_cache_lock:
_available_error_cache[endpoint] = time.time()
return []
# -------------------------------------------------------------
# Endpoint health probes (shared by /api/config and /health)
# -------------------------------------------------------------
async def _raw_probe(
ep: str,
route: str,
api_key: Optional[str] = None,
timeout: Optional[float] = None,
) -> tuple[bool, object]:
"""Direct HTTP probe that distinguishes success from failure
(unlike `fetch.endpoint_details`, which returns [] on either).
Returns `(ok, payload_or_error_message)`.
"""
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
url = f"{ep.rstrip('/')}/{route.lstrip('/')}"
req_kwargs = {}
if timeout is not None:
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
try:
client: aiohttp.ClientSession = get_probe_session(ep)
async with client.get(url, headers=headers, **req_kwargs) as resp:
await _ensure_success(resp)
data = await resp.json()
return True, data
except Exception as exc:
return False, _format_connection_issue(url, exc)
async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
"""Probe an endpoint and return `{status, version?, detail?}`.
Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so
that a daemon which is reachable but has a broken model-introspection
path (issue #83) is reported as `error` rather than `ok`.
OpenAI-compatible endpoints use a single `/models` probe.
"""
if is_openai_compatible(ep):
ok, payload = await _raw_probe(
ep, "/models", get_config().api_keys.get(ep), timeout=timeout,
)
if ok:
return {"status": "ok", "version": "latest"}
return {"status": "error", "detail": str(payload)}
(version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather(
_raw_probe(ep, "/api/version", timeout=timeout),
_raw_probe(ep, "/api/ps", timeout=timeout),
)
version_value = (
version_payload.get("version")
if version_ok and isinstance(version_payload, dict)
else None
)
if version_ok and ps_ok:
return {"status": "ok", "version": version_value}
if not version_ok and not ps_ok:
return {"status": "error", "detail": str(version_payload)}
# Partial failure — daemon reachable but one probe failed. Report
# as "error" so callers can surface the issue; include `version` so
# the operator knows the daemon itself is alive.
if not ps_ok:
return {
"status": "error",
"version": version_value,
"detail": f"/api/ps: {ps_payload}",
}
return {
"status": "error",
"detail": f"/api/version: {version_payload}",
}

View file

@ -1,121 +0,0 @@
"""aiohttp / OpenAI client factories aware of Unix-socket endpoints.
Unix socket endpoints follow the ``.sock`` hostname convention (e.g.
``http://192.168.0.52.sock/v1``) and resolve to ``/run/user/<uid>/<host>``.
Their sessions/clients live in ``state.app_state`` so that startup can
populate them once and routes can reuse them.
"""
import os
import aiohttp
import ollama
import openai
from state import app_state
from backends.normalize import ep2base
def _is_unix_socket_endpoint(endpoint: str) -> bool:
"""Return True if endpoint uses Unix socket (.sock hostname convention).
Detects URLs like http://192.168.0.52.sock/v1 where the host ends with
.sock, indicating the connection should use a Unix domain socket at
/tmp/<host> instead of TCP.
"""
try:
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
return host.endswith(".sock")
except IndexError:
return False
def _get_socket_path(endpoint: str) -> str:
"""Derive Unix socket file path from a .sock endpoint URL.
http://192.168.0.52.sock/v1 -> /run/user/<uid>/192.168.0.52.sock
"""
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
return f"/run/user/{os.getuid()}/{host}"
def get_session(endpoint: str) -> aiohttp.ClientSession:
"""Return the appropriate aiohttp session for the given endpoint.
Unix socket endpoints (.sock) get their own UnixConnector session.
All other endpoints share the main TCP session.
"""
if _is_unix_socket_endpoint(endpoint):
sess = app_state["socket_sessions"].get(endpoint)
if sess is not None:
return sess
return app_state["session"]
def get_probe_session(endpoint: str) -> aiohttp.ClientSession:
"""Return the session used for lightweight health/introspection probes.
Probes (available/loaded models, endpoint health) run on a connection
pool kept separate from the proxy/streaming session, so a burst of
long-lived completion requests cannot starve them otherwise a probe
would queue waiting for a connection, hit its deadline, and mark a
perfectly healthy endpoint as unavailable under load.
Unix socket endpoints keep their dedicated per-endpoint session. TCP
endpoints use the shared probe session, falling back to the main
session when the probe pool has not been initialised (e.g. in tests).
"""
if _is_unix_socket_endpoint(endpoint):
sess = app_state["socket_sessions"].get(endpoint)
if sess is not None:
return sess
return app_state.get("probe_session") or app_state["session"]
def get_ollama_client(endpoint: str) -> ollama.AsyncClient:
"""Return a cached ``ollama.AsyncClient`` for the endpoint, creating it once.
``ollama.AsyncClient`` wraps an ``httpx.AsyncClient`` whose construction
builds an SSL context and reloads the OS trust store (~40 ms). It is safe to
reuse concurrently, so we keep one per endpoint instead of building a fresh
one on every request otherwise that 40 ms of CPU runs on the event loop
per request and caps single-worker throughput at ~25 req/s.
"""
cache = app_state["ollama_clients"]
client = cache.get(endpoint)
if client is None:
client = ollama.AsyncClient(host=endpoint)
cache[endpoint] = client
return client
def _make_openai_client(
endpoint: str,
default_headers: dict | None = None,
api_key: str = "no-key",
) -> openai.AsyncOpenAI:
"""Return a cached AsyncOpenAI client configured for the given endpoint.
Clients are cached per ``(endpoint, api_key)`` and reused across requests:
constructing one builds an SSL context and reloads the OS trust store
(~40 ms), which serializes the event loop if done per request. For Unix
socket endpoints, injects the pre-created httpx UDS transport so the OpenAI
SDK connects via the socket instead of TCP.
"""
cache = app_state["openai_clients"]
cache_key = (endpoint, api_key)
client = cache.get(cache_key)
if client is not None:
return client
base_url = ep2base(endpoint)
kwargs: dict = {"api_key": api_key}
if default_headers is not None:
kwargs["default_headers"] = default_headers
if _is_unix_socket_endpoint(endpoint):
http_client = app_state["httpx_clients"].get(endpoint)
if http_client is not None:
kwargs["http_client"] = http_client
base_url = "http://localhost/v1"
client = openai.AsyncOpenAI(base_url=base_url, **kwargs)
cache[cache_key] = client
return client

143
config.py
View file

@ -1,143 +0,0 @@
"""Router configuration loader.
Pydantic ``BaseSettings`` model populated from YAML (path resolved via
``_config_path_from_env``) with ``${VAR}`` expansion, plus env-var overrides
under the ``NOMYO_ROUTER_`` prefix.
"""
import os
import re
from pathlib import Path
from typing import Dict, List, Optional
import yaml
from pydantic import Field
from pydantic_settings import BaseSettings
class Config(BaseSettings):
# List of Ollama endpoints
endpoints: list[str] = Field(
default_factory=lambda: [
"http://localhost:11434",
]
)
# List of llama-server endpoints (OpenAI-compatible with /v1/models status info)
llama_server_endpoints: List[str] = Field(default_factory=list)
# List of llama-swap endpoints (OpenAI-compatible front for multiple llama-server
# workers). Same surface as llama_server_endpoints, but loaded models are read from
# /running (not /v1/models status) and unload uses POST /api/models/unload/:model_id.
llama_swap_endpoints: List[str] = Field(default_factory=list)
# Max concurrent connections per endpointmodel pair, see OLLAMA_NUM_PARALLEL
max_concurrent_connections: int = 1
# Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}}
endpoint_config: Dict[str, Dict] = Field(default_factory=dict)
# When True, config order = priority; routes by utilization ratio + config index (WRR)
priority_routing: bool = Field(default=False)
# Conversation affinity: route the same conversation back to the endpoint that
# previously served it, to keep the llama.cpp / Ollama prompt cache (KV cache) warm.
# Soft preference — falls back to the standard algorithm when the affine endpoint
# is saturated or no longer has the model loaded.
conversation_affinity: bool = Field(default=False)
# TTL (seconds) for affinity entries. Defaults to Ollama's default keep_alive (5 min):
# if the backend has already evicted the model, the KV cache is cold anyway.
conversation_affinity_ttl: int = Field(default=300)
api_keys: Dict[str, str] = Field(default_factory=dict)
# Optional router-level API key used to gate access to this service and dashboard
router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY")
# Database configuration
db_path: str = Field(default=os.getenv("NOMYO_ROUTER_DB_PATH", "token_counts.db"))
# Semantic LLM Cache configuration
cache_enabled: bool = Field(default=False)
# Backend: "memory" (default, in-process), "sqlite" (persistent), "redis" (distributed)
cache_backend: str = Field(default="memory")
# Cosine similarity threshold: 1.0 = exact match only, <1.0 = semantic (requires :semantic image)
cache_similarity: float = Field(default=1.0)
# TTL in seconds; None = cache forever
cache_ttl: Optional[int] = Field(default=3600)
# SQLite backend: path to cache database file
cache_db_path: str = Field(default="llm_cache.db")
# Redis backend: connection URL
cache_redis_url: str = Field(default="redis://localhost:6379/0")
# Weight of BM25-weighted chat-history embedding vs last-user-message embedding
# 0.3 = 30% history context signal, 70% question signal
cache_history_weight: float = Field(default=0.3)
class Config:
# YAML loading is handled manually via Config.from_yaml(); env vars use this prefix.
env_prefix = "NOMYO_ROUTER_"
@classmethod
def _expand_env_refs(cls, obj):
"""Recursively replace `${VAR}` with os.getenv('VAR')."""
if isinstance(obj, dict):
return {k: cls._expand_env_refs(v) for k, v in obj.items()}
if isinstance(obj, list):
return [cls._expand_env_refs(v) for v in obj]
if isinstance(obj, str):
# Only expand if it is exactly ${VAR}
m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
if m:
return os.getenv(m.group(1), "")
return obj
@classmethod
def from_yaml(cls, path: Path) -> "Config":
"""Load the YAML file and create the Config instance."""
if path.exists():
with path.open("r", encoding="utf-8") as fp:
data = yaml.safe_load(fp) or {}
cleaned = cls._expand_env_refs(data)
if isinstance(cleaned, dict):
# Accept hyphenated config key and map it to the field name
key_aliases = [
# canonical field name
"router_api_key",
# lowercase, hyphen/underscore variants
"nomyo-router-api-key",
"nomyo_router_api_key",
"nomyo-router_api_key",
"nomyo_router-api_key",
# uppercase env-style variants
"NOMYO-ROUTER_API_KEY",
"NOMYO_ROUTER_API_KEY",
]
for alias in key_aliases:
if alias in cleaned:
cleaned["router_api_key"] = cleaned.get("router_api_key", cleaned.pop(alias))
break
# If not present in YAML (or empty), fall back to env var explicitly
if not cleaned.get("router_api_key"):
env_key = os.getenv("NOMYO_ROUTER_API_KEY")
if env_key:
cleaned["router_api_key"] = env_key
return cls(**cleaned)
return cls()
def _config_path_from_env() -> Path:
"""
Resolve the configuration file path. Defaults to `config.yaml`
in the current working directory unless NOMYO_ROUTER_CONFIG_PATH
is set.
"""
candidate = os.getenv("NOMYO_ROUTER_CONFIG_PATH")
if candidate:
return Path(candidate).expanduser()
return Path("config.yaml")
# ------------------------------------------------------------------
# Shared config accessor
# ------------------------------------------------------------------
# Submodules read config at call time via get_config() instead of importing
# a bound name. The single source of truth is ``router.config`` — the lazy
# import below resolves it after router.py has finished loading, and lets
# tests that ``patch.object(router, "config", cfg)`` flow through.
def get_config() -> "Config":
"""Return the currently active Config from router.py."""
import router # lazy to avoid module-load circular import
return router.config

View file

@ -6,15 +6,7 @@ endpoints:
- https://api.openai.com/v1 - https://api.openai.com/v1
llama_server_endpoints: llama_server_endpoints:
- http://192.168.0.51:8889/v1 - http://192.168.0.50:8889/v1
# llama-swap endpoints (OpenAI-compatible front for multiple llama-server workers).
# Same surface as llama_server_endpoints, but the router reads loaded/running workers
# from /running (state == "ready") instead of a /v1/models status field, and unloads via
# POST /api/models/unload/:model_id. The router also exposes /upstream/:model_id/<path>
# to bypass llama-swap and reach a model's underlying llama-server worker directly.
llama_swap_endpoints:
- http://192.168.0.52:8890/v1
# Maximum concurrent connections *per endpointmodel pair* (equals to OLLAMA_NUM_PARALLEL) # Maximum concurrent connections *per endpointmodel pair* (equals to OLLAMA_NUM_PARALLEL)
# This is the global default; individual endpoints can override it via endpoint_config below. # This is the global default; individual endpoints can override it via endpoint_config below.
@ -65,8 +57,7 @@ api_keys:
"http://192.168.0.51:11434": "ollama" "http://192.168.0.51:11434": "ollama"
"http://192.168.0.52:11434": "ollama" "http://192.168.0.52:11434": "ollama"
"https://api.openai.com/v1": "${OPENAI_KEY}" "https://api.openai.com/v1": "${OPENAI_KEY}"
"http://192.168.0.51:8889/v1": "llama" "http://192.168.0.50:8889/v1": "llama"
"http://192.168.0.52:8889/v1": "llama-swap"
# ------------------------------------------------------------- # -------------------------------------------------------------
# Semantic LLM Cache (optional — disabled by default) # Semantic LLM Cache (optional — disabled by default)

View file

@ -1,120 +0,0 @@
"""Sliding-window context-trim helpers.
Mirrors what llama.cpp's context-shift used to do: count tokens with tiktoken
(cl100k_base) when available, drop oldest non-system messages until the prompt
fits inside (n_ctx - safety_margin).
Also owns the per-(endpoint, model) n_ctx cache that the routes populate from
exceed_context_size_error bodies and from finish_reason=="length" signals.
"""
import os
# Point tiktoken at the vendored cl100k_base vocab so the encoding loads offline,
# without a network download. The download would otherwise fail anyway: this repo
# has a top-level `requests` package that shadows the pip `requests` tiktoken's
# downloader imports, so get_encoding() would silently fall back to char/4. See
# vendor/tiktoken/. setdefault lets an explicit env override win.
os.environ.setdefault(
"TIKTOKEN_CACHE_DIR",
os.path.join(os.path.dirname(os.path.abspath(__file__)), "vendor", "tiktoken"),
)
try:
import tiktoken as _tiktoken
_tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
except Exception:
_tiktoken_enc = None
def _count_message_tokens(messages: list) -> int:
"""Approximate token count for a message list.
Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
Falls back to char/4 heuristic if tiktoken is unavailable.
Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.
"""
if _tiktoken_enc is None:
return sum(len(str(m.get("content", ""))) for m in messages) // 4
total = 2 # priming tokens
for msg in messages:
total += 4 # per-message role/separator overhead
content = msg.get("content", "")
if isinstance(content, str):
total += len(_tiktoken_enc.encode(content))
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
total += len(_tiktoken_enc.encode(part.get("text", "")))
return total
def _trim_messages_for_context(
messages: list,
n_ctx: int,
safety_margin: int = None,
target_tokens: int = None,
) -> list:
"""Sliding-window trim — mirrors what llama.cpp context-shift used to do.
Keeps all system messages and the most recent non-system messages that fit
within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped
first (FIFO). The last message is always preserved.
safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated
response, including RAG tool results and tool call JSON synthesis.
target_tokens: if provided, overrides the (n_ctx - safety_margin) target.
Pass a calibrated value when actual n_prompt_tokens is known from the error
body so that tiktoken underestimation vs the backend tokenizer is corrected.
"""
if target_tokens is not None:
target = target_tokens
else:
if safety_margin is None:
safety_margin = n_ctx // 4
target = n_ctx - safety_margin
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
while len(non_system) > 1:
if _count_message_tokens(system_msgs + non_system) <= target:
break
non_system.pop(0) # drop oldest non-system message
# Ensure the first non-system message is a user message (chat templates require it).
# Drop any leading assistant/tool messages that were left after trimming.
while non_system and non_system[0].get("role") != "user":
non_system.pop(0)
return system_msgs + non_system
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
"""Return a tiktoken-scale trim target based on how much backend tokens must be shed.
actual_tokens includes messages + tool schemas + overhead as counted by the backend.
_count_message_tokens only counts message text, so we cannot derive an accurate
per-token scale from the ratio. Instead we compute the *delta* we need to remove
in backend space, then convert just that delta to tiktoken scale (×1.2 buffer).
Example: actual=17993, n_ctx=16384, headroom=4096 need to shed 5705 backend
tokens shed 6846 tiktoken tokens from messages.
"""
cur_tiktoken = _count_message_tokens(msgs)
headroom = n_ctx // 4 # reserve for generated output
max_prompt = n_ctx - headroom # desired max backend tokens in prompt
to_shed = max(0, actual_tokens - max_prompt) # backend tokens we must drop
# Convert to tiktoken scale with 20% buffer (tiktoken underestimates llama by ~15-20%)
tiktoken_to_shed = int(to_shed * 1.2)
return max(1, cur_tiktoken - tiktoken_to_shed)
# Per-(endpoint, model) n_ctx cache.
# Populated from two sources:
# 1. 400 exceed_context_size_error body → n_ctx field
# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens
# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT,
# so large-context models (200k+ for coding) are never touched.
_endpoint_nctx: dict[tuple[str, str], int] = {}
_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this

186
db.py
View file

@ -1,20 +1,9 @@
import aiosqlite, asyncio, orjson import aiosqlite, asyncio
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
from datetime import datetime, timezone from datetime import datetime, timezone
from collections import defaultdict from collections import defaultdict
def get_db() -> "TokenDatabase":
"""Return the live TokenDatabase instance held by router.py.
Resolved lazily so submodules can access it without import cycles, and
so test patches of ``router.db`` flow through to all callers.
"""
import router # lazy to avoid module-load circular import
return router.db
class TokenDatabase: class TokenDatabase:
def __init__(self, db_path: str = "token_counts.db"): def __init__(self, db_path: str = "token_counts.db"):
self.db_path = db_path self.db_path = db_path
@ -75,24 +64,6 @@ class TokenDatabase:
''') ''')
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)') await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_timestamp ON token_time_series(timestamp)')
await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)') await db.execute('CREATE INDEX IF NOT EXISTS idx_token_time_series_model_ts ON token_time_series(model, timestamp)')
# Responses API state — the router owns conversation state for the
# /v1/responses family (store / previous_response_id) and tracks
# background-task status here so polling survives across workers.
await db.execute('''
CREATE TABLE IF NOT EXISTS stored_responses (
response_id TEXT PRIMARY KEY,
previous_response_id TEXT,
model TEXT,
status TEXT,
created_at INTEGER,
input_messages TEXT,
output_items TEXT,
usage TEXT,
instructions TEXT,
error TEXT
)
''')
await db.execute('CREATE INDEX IF NOT EXISTS idx_stored_responses_prev ON stored_responses(previous_response_id)')
await db.commit() await db.commit()
async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int): async def update_token_counts(self, endpoint: str, model: str, input_tokens: int, output_tokens: int):
@ -337,158 +308,3 @@ class TokenDatabase:
await db.commit() await db.commit()
return aggregated_count return aggregated_count
# -----------------------------------------------------------------
# Responses API state (store / previous_response_id / background)
# -----------------------------------------------------------------
@staticmethod
def _row_to_response(row) -> dict:
"""Map a stored_responses row to a plain dict, decoding JSON columns."""
def _loads(val):
if val is None:
return None
try:
return orjson.loads(val)
except (orjson.JSONDecodeError, TypeError):
return None
return {
'response_id': row[0],
'previous_response_id': row[1],
'model': row[2],
'status': row[3],
'created_at': row[4],
'input_messages': _loads(row[5]),
'output_items': _loads(row[6]),
'usage': _loads(row[7]),
'instructions': row[8],
'error': _loads(row[9]),
}
async def store_response(
self,
response_id: str,
*,
previous_response_id: Optional[str],
model: str,
status: str,
created_at: int,
input_messages: list,
output_items: Optional[list] = None,
usage: Optional[dict] = None,
instructions: Optional[str] = None,
error: Optional[dict] = None,
):
"""Insert or replace a stored Responses-API response row."""
db = await self._get_connection()
async with self._operation_lock:
await db.execute('''
INSERT INTO stored_responses
(response_id, previous_response_id, model, status, created_at,
input_messages, output_items, usage, instructions, error)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT (response_id) DO UPDATE SET
previous_response_id = excluded.previous_response_id,
model = excluded.model,
status = excluded.status,
created_at = excluded.created_at,
input_messages = excluded.input_messages,
output_items = excluded.output_items,
usage = excluded.usage,
instructions = excluded.instructions,
error = excluded.error
''', (
response_id, previous_response_id, model, status, created_at,
orjson.dumps(input_messages).decode("utf-8"),
orjson.dumps(output_items).decode("utf-8") if output_items is not None else None,
orjson.dumps(usage).decode("utf-8") if usage is not None else None,
instructions,
orjson.dumps(error).decode("utf-8") if error is not None else None,
))
await db.commit()
async def update_response_status(
self,
response_id: str,
status: str,
*,
output_items: Optional[list] = None,
usage: Optional[dict] = None,
error: Optional[dict] = None,
):
"""Update the status (and optionally output/usage/error) of a stored response."""
db = await self._get_connection()
async with self._operation_lock:
await db.execute('''
UPDATE stored_responses
SET status = ?,
output_items = COALESCE(?, output_items),
usage = COALESCE(?, usage),
error = COALESCE(?, error)
WHERE response_id = ?
''', (
status,
orjson.dumps(output_items).decode("utf-8") if output_items is not None else None,
orjson.dumps(usage).decode("utf-8") if usage is not None else None,
orjson.dumps(error).decode("utf-8") if error is not None else None,
response_id,
))
await db.commit()
async def get_response(self, response_id: str) -> Optional[dict]:
"""Return a stored response as a dict, or None if not found."""
db = await self._get_connection()
async with self._operation_lock:
async with db.execute('''
SELECT response_id, previous_response_id, model, status, created_at,
input_messages, output_items, usage, instructions, error
FROM stored_responses WHERE response_id = ?
''', (response_id,)) as cursor:
row = await cursor.fetchone()
return self._row_to_response(row) if row is not None else None
async def delete_response(self, response_id: str) -> bool:
"""Delete a stored response. Returns True if a row was removed."""
db = await self._get_connection()
async with self._operation_lock:
cursor = await db.execute(
'DELETE FROM stored_responses WHERE response_id = ?', (response_id,)
)
await db.commit()
return cursor.rowcount > 0
async def get_response_chain(self, response_id: str, max_turns: int = 50) -> list:
"""Walk previous_response_id back to the root, returned oldest-first.
Bounded to ``max_turns`` so a pathological chain cannot stall a request.
Missing links terminate the walk gracefully.
"""
chain: list = []
seen: set = set()
current = response_id
while current and current not in seen and len(chain) < max_turns:
seen.add(current)
resp = await self.get_response(current)
if resp is None:
break
chain.append(resp)
current = resp.get('previous_response_id')
chain.reverse()
return chain
async def fail_orphaned_responses(self) -> int:
"""Mark non-terminal responses as failed (called on startup).
A background task lives in a worker's event loop; a process restart loses
it while the DB row stays ``queued``/``in_progress`` forever. Reconcile
those to ``failed`` so polling clients get a terminal state.
"""
db = await self._get_connection()
async with self._operation_lock:
cursor = await db.execute('''
UPDATE stored_responses
SET status = 'failed',
error = ?
WHERE status IN ('queued', 'in_progress')
''', (orjson.dumps({"message": "Response interrupted by server restart", "type": "server_error"}).decode("utf-8"),))
await db.commit()
return cursor.rowcount

View file

@ -206,8 +206,6 @@ The `/health` endpoint provides comprehensive health status:
} }
``` ```
For Ollama endpoints the probe is a parallel check of `/api/version` (liveness) and `/api/ps` (the route used by `choose_endpoint` when selecting a backend for a request). Reporting `ok` only when both succeed prevents the router from advertising an endpoint as healthy while completion calls dead-end on `/api/ps`. The same dual probe backs `/api/config`, which the dashboard uses to render endpoint health.
## Database Schema ## Database Schema
The router uses SQLite for persistent storage: The router uses SQLite for persistent storage:

View file

@ -78,37 +78,6 @@ endpoints:
- OpenAI-compatible endpoints use `/v1` prefix - OpenAI-compatible endpoints use `/v1` prefix
- The router automatically detects endpoint type based on URL pattern - The router automatically detects endpoint type based on URL pattern
### `llama_server_endpoints`
**Type**: `list[str]` (optional)
**Default**: `[]`
**Description**: List of [llama.cpp `llama-server`](https://github.com/ggml-org/llama.cpp) endpoints (OpenAI-compatible, configured with the `/v1` suffix). The router reads each backend's loaded models from `/v1/models` (entries with `status == "loaded"`) and unloads idle models via `POST /models/unload`.
```yaml
llama_server_endpoints:
- http://192.168.0.50:8889/v1
```
### `llama_swap_endpoints`
**Type**: `list[str]` (optional)
**Default**: `[]`
**Description**: List of [llama-swap](https://github.com/mostlygeek/llama-swap) endpoints (OpenAI-compatible, configured with the `/v1` suffix). llama-swap fronts multiple `llama-server` workers behind one address. It is treated like `llama_server_endpoints` for routing, model discovery, and reranking, but differs in two ways the router handles automatically:
- **Loaded-model detection** — llama-swap's `/v1/models` omits the per-model `status` field, so running workers are read from `GET /running` (entries with `state == "ready"`).
- **Model unload** — done via `POST /api/models/unload/:model_id` (path parameter), not the `llama-server` body form.
The router also exposes a passthrough route, `GET|POST /upstream/:model_id/<path>`, which forwards directly to a model's underlying `llama-server` worker (via llama-swap's `/upstream`), letting clients use `llama-server` features that llama-swap does not forward (e.g. token-array prompts).
```yaml
llama_swap_endpoints:
- http://192.168.0.50:8890/v1
```
### `max_concurrent_connections` ### `max_concurrent_connections`
**Type**: `int` **Type**: `int`

View file

@ -29,10 +29,6 @@ Response:
- `200`: All endpoints healthy - `200`: All endpoints healthy
- `503`: One or more endpoints unhealthy - `503`: One or more endpoints unhealthy
**Probe scope per endpoint**:
- **Ollama endpoints** are probed at both `/api/version` (liveness) and `/api/ps` (model-introspection used by the router). If either fails the endpoint is reported as `error`; the response still includes `version` when the daemon is reachable so operators can tell a partial failure from a full outage. The `detail` field names the failing probe, e.g. `"/api/ps: 502 …"`.
- **OpenAI-compatible / llama-server endpoints** are probed at `/models`.
### Current Usage ### Current Usage
```bash ```bash
@ -137,8 +133,6 @@ Response:
} }
``` ```
Uses the same dual-probe logic as `/health` (Ollama: `/api/version` + `/api/ps`; OpenAI-compatible: `/models`). An endpoint will report `error` whenever either probe fails. The dashboard renders the `detail` field as a tooltip on the status cell.
### Cache Statistics ### Cache Statistics
```bash ```bash

View file

@ -1,35 +0,0 @@
"""Conversation fingerprinting for prompt-cache-aware routing."""
import hashlib
from typing import Optional
def _conversation_fingerprint(model: str, messages: Optional[list],
prompt: Optional[str]) -> Optional[str]:
"""
Stable hash over (model, first system + first user turn). That prefix
determines whether the backend's prompt cache is reusable; later turns
don't influence the routing decision because they extend the same prefix.
Returns None when there is no usable prefix.
"""
parts: list[str] = [model or "_"]
if messages:
for m in messages:
role = m.get("role") if isinstance(m, dict) else None
if role not in ("system", "user"):
continue
content = m.get("content")
if isinstance(content, list): # OpenAI multimodal parts
content = "".join(
p.get("text", "") for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
if not isinstance(content, str):
continue
parts.append(f"{role}:{content}")
if role == "user":
break
elif prompt:
parts.append(f"user:{prompt}")
else:
return None
return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest()

View file

@ -1,66 +0,0 @@
"""Image and timestamp helpers used by the Ollama/OpenAI request pipeline."""
import base64
import io
import time
from datetime import datetime, timezone
from PIL import Image
def iso8601_ns():
ns = time.time_ns()
sec, ns_rem = divmod(ns, 1_000_000_000)
dt = datetime.fromtimestamp(sec, tz=timezone.utc)
return (
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
f"{ns_rem:09d}Z"
)
def is_base64(image_string):
try:
if isinstance(image_string, str) and base64.b64encode(base64.b64decode(image_string)) == image_string.encode():
return True
except Exception:
return False
def resize_image_if_needed(image_data):
try:
# Check if already data-url
if image_data.startswith("data:"):
try:
header, image_data = image_data.split(",", 1)
except ValueError:
pass
# Decode the base64 image data
image_bytes = base64.b64decode(image_data)
with Image.open(io.BytesIO(image_bytes)) as image:
if image.mode not in ("RGB", "L"):
image = image.convert("RGB")
# Get current size
width, height = image.size
# Calculate the new dimensions while maintaining aspect ratio
if width > 512 or height > 512:
aspect_ratio = width / height
if aspect_ratio > 1: # Width is larger
new_width = 512
new_height = int(512 / aspect_ratio)
else: # Height is larger
new_height = 512
new_width = int(512 * aspect_ratio)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Encode the resized image back to base64
buffered = io.BytesIO()
image.save(buffered, format="PNG")
resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
return resized_image_data
except Exception as e:
print(f"Error processing image: {e}")
return None

View file

View file

@ -1,218 +0,0 @@
"""High-level chat request orchestrator.
``_make_chat_request`` is the shared core that:
* picks an endpoint via ``choose_endpoint`` (which atomically reserves a slot),
* dispatches to either the native Ollama client or an OpenAI-compatible
client based on endpoint type,
* applies reactive context trimming when the backend rejects with
``exceed_context_size_error``,
* counts tokens for billing/SSE,
* always releases the reservation via ``decrement_usage`` in ``finally``.
``_make_moe_requests`` builds on it to implement the
"3 responses + 3 critiques + 1 final" mixture-of-experts dance.
"""
import asyncio
import re
import time
import ollama
import enhance
from config import get_config
from state import default_headers, token_queue
from context_window import _trim_messages_for_context, _calibrated_trim_target
from backends.normalize import is_openai_compatible
from backends.sessions import _make_openai_client
from routing import choose_endpoint, decrement_usage
from requests.messages import (
get_last_user_content,
transform_tool_calls_to_openai,
transform_images_to_data_urls,
_strip_assistant_prefill,
_strip_images_from_messages,
_accumulate_openai_tc_delta,
_build_ollama_tool_calls,
)
from requests.rechunk import rechunk
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make a chat request to a specific endpoint.
Handles endpoint selection, client creation, usage tracking, and request execution.
"""
config = get_config()
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
model = model.split(":latest")[0]
if messages:
if any("images" in m for m in messages):
messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
messages = transform_tool_calls_to_openai(messages)
messages = _strip_assistant_prefill(messages)
params = {
"messages": messages,
"model": model,
}
optional_params = {
"tools": tools,
"stream": stream,
"stream_options": {"include_usage": True} if stream else None,
"max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
"frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
"presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
"seed": options.get("seed") if options and "seed" in options else None,
"stop": options.get("stop") if options and "stop" in options else None,
"top_p": options.get("top_p") if options and "top_p" in options else None,
"temperature": options.get("temperature") if options and "temperature" in options else None,
"response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
}
params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
try:
if use_openai:
start_ts = time.perf_counter()
try:
response = await oclient.chat.completions.create(**params)
except Exception as e:
_e_str = str(e)
print(f"[_make_chat_request] caught {type(e).__name__}: {_e_str[:200]}")
if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0)
actual_tokens = err_detail.get("n_prompt_tokens", 0)
if not n_ctx_limit:
_m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
if _m:
n_ctx_limit = int(_m.group(1))
_m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
if _m:
actual_tokens = int(_m.group(1))
if not n_ctx_limit:
raise
msgs_to_trim = params.get("messages", [])
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
try:
response = await oclient.chat.completions.create(**{**params, "messages": trimmed})
except Exception as e2:
if "exceed_context_size_error" in str(e2) or "exceeds the available context size" in str(e2):
print(f"[_make_chat_request] Context still exceeded after trimming, also stripping tools")
params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
response = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
else:
raise
elif "image input is not supported" in _e_str:
print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
response = await oclient.chat.completions.create(**params)
else:
raise
if stream:
# For streaming, we need to collect all chunks
chunks = []
tc_acc = {} # accumulate tool-call deltas
async for chunk in response:
chunks.append(chunk)
_accumulate_openai_tc_delta(chunk, tc_acc)
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Convert to Ollama format
if chunks:
response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
# Inject fully-accumulated tool calls into the final response
if tc_acc and response.message:
response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
else:
prompt_tok = 0
comp_tok = 0
if response.usage is not None:
prompt_tok = response.usage.prompt_tokens or 0
comp_tok = response.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(response)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
else:
response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
if stream:
# For streaming, collect all chunks
chunks = []
async for chunk in response:
chunks.append(chunk)
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if chunks:
response = chunks[-1]
else:
prompt_tok = response.prompt_eval_count or 0
comp_tok = response.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
return response
finally:
await decrement_usage(endpoint, tracking_model)
async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make MOE (Multiple Opinions Ensemble) requests.
Generates 3 responses, 3 critiques, and returns the final selected response.
"""
query = get_last_user_content(messages)
if not query:
raise ValueError("No user query found in messages")
if options is None:
options = {}
options["temperature"] = 1
moe_reqs = []
# Generate 3 responses — choose_endpoint is called inside _make_chat_request and
# atomically reserves a slot, so all 3 tasks see each other's load immediately.
response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
responses = await asyncio.gather(response1_task, response2_task, response3_task)
for n, r in enumerate(responses):
moe_req = enhance.moe(query, n, r.message.content)
moe_reqs.append(moe_req)
# Generate 3 critiques
critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
# Select final response
m = enhance.moe_select_candidate(query, critiques)
# Generate final response
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)

View file

@ -1,187 +0,0 @@
"""Message-shape transforms used across the chat/completions paths.
Covers the directions between Ollama's native message format and the
OpenAI Chat Completions format:
* tool-call normalization (Ollama OpenAI),
* images encoded as base64 lists OpenAI multimodal ``image_url`` parts,
* trailing-assistant prefill strip (rejected by Claude/OpenAI),
* streaming tool_calls accumulation across deltas,
* logprobs translation (OpenAI choice Ollama ``Logprob``).
"""
import secrets
import ollama
import orjson
from ollama._types import TokenLogprob, Logprob
from images import is_base64, resize_image_if_needed
def get_last_user_content(messages):
"""
Given a list of dicts (e.g., messages from an API),
return the 'content' of the last dict whose 'role' is 'user'.
If no such dict exists, return None.
"""
# Reverse iterate so we stop at the first match
for msg in reversed(messages):
if msg.get("role") == "user":
return msg.get("content")
return None
def _strip_assistant_prefill(messages: list) -> list:
"""Remove a trailing assistant message used as prefill.
OpenAI-compatible endpoints (including Claude) do not support prefill and
will reject requests where the last message has role 'assistant'."""
if messages and messages[-1].get("role") == "assistant":
return messages[:-1]
return messages
def transform_tool_calls_to_openai(message_list):
"""
Ensure tool_calls in assistant messages conform to the OpenAI format:
- Each tool call must have "type": "function"
- Each tool call must have an "id"
- arguments must be a JSON string, not a dict
Also ensure tool-role messages have a tool_call_id.
"""
# Track generated IDs so tool-role messages can reference them
last_tool_call_ids = {}
for msg in message_list:
role = msg.get("role")
if role == "assistant" and "tool_calls" in msg:
for tc in msg["tool_calls"]:
if "type" not in tc:
tc["type"] = "function"
if "id" not in tc:
tc["id"] = f"call_{secrets.token_hex(16)}"
func = tc.get("function", {})
if isinstance(func.get("arguments"), dict):
func["arguments"] = orjson.dumps(func["arguments"]).decode("utf-8")
# Remember the id for the following tool-role message
name = func.get("name")
if name:
last_tool_call_ids[name] = tc["id"]
elif role == "tool":
if "tool_call_id" not in msg:
# Try to match by name from a preceding assistant tool_call
name = msg.get("name") or msg.get("tool_name")
if name and name in last_tool_call_ids:
msg["tool_call_id"] = last_tool_call_ids.pop(name)
return message_list
def transform_images_to_data_urls(message_list):
for message in message_list:
if "images" in message:
images = message.pop("images")
if not isinstance(images, list):
continue
new_content = []
for image in images: #TODO: quality downsize if images are too big to fit into model context window size
if not is_base64(image):
raise ValueError(f"Image string is not a valid base64 encoded string.")
resized_image = resize_image_if_needed(image)
if resized_image:
data_url = f"data:image/png;base64,{resized_image}"
#new_content.append({
# "type": "text",
# "text": ""
#})
new_content.append({
"type": "image_url",
"image_url": {
"url": data_url
}
})
message["content"] = new_content
return message_list
def _strip_images_from_messages(messages: list) -> list:
"""Remove image_url parts from message content, keeping only text."""
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_only = [p for p in content if p.get("type") != "image_url"]
if len(text_only) == 1 and text_only[0].get("type") == "text":
content = text_only[0]["text"]
else:
content = text_only
result.append({**msg, "content": content})
else:
result.append(msg)
return result
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
``accumulator`` is a dict mapping tool-call *index* to
``{"id": str, "name": str, "arguments": str}`` where ``arguments``
is the concatenation of all JSON fragments seen so far.
"""
if not chunk.choices:
return
delta = chunk.choices[0].delta
tc_deltas = getattr(delta, "tool_calls", None)
if not tc_deltas:
return
for tc in tc_deltas:
idx = tc.index
if idx not in accumulator:
accumulator[idx] = {
"id": getattr(tc, "id", None) or f"call_{secrets.token_hex(16)}",
"name": tc.function.name if tc.function else None,
"arguments": "",
}
else:
if getattr(tc, "id", None):
accumulator[idx]["id"] = tc.id
if tc.function and tc.function.name:
accumulator[idx]["name"] = tc.function.name
if tc.function and tc.function.arguments:
accumulator[idx]["arguments"] += tc.function.arguments
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
"""Convert accumulated tool-call data into Ollama-format tool_calls list."""
if not accumulator:
return None
result = []
for idx in sorted(accumulator.keys()):
tc = accumulator[idx]
try:
args = orjson.loads(tc["arguments"]) if tc["arguments"] else {}
except (orjson.JSONDecodeError, TypeError):
args = {}
result.append(ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=args)
))
return result
def _convert_openai_logprobs(choice) -> list | None:
"""Convert OpenAI logprobs from a choice into Ollama Logprob objects."""
lp = getattr(choice, "logprobs", None)
if lp is None:
return None
content = getattr(lp, "content", None)
if not content:
return None
result = []
for entry in content:
top = [
TokenLogprob(token=alt.token, logprob=alt.logprob)
for alt in (entry.top_logprobs or [])
]
result.append(Logprob(
token=entry.token,
logprob=entry.logprob,
top_logprobs=top or None,
))
return result

View file

@ -1,151 +0,0 @@
"""OpenAI → Ollama response shape converters.
Methods on the ``rechunk`` class are called as bare functions
(``rechunk.openai_chat_completion2ollama(...)``) there is no instance
state. The class is just a namespace.
``extract_usage_from_llama_timings`` reads the ``timings`` field that
llama-server returns in place of OpenAI's ``usage`` so the router can still
count tokens for those backends.
"""
import time
import ollama
import orjson
from images import iso8601_ns
from requests.messages import _convert_openai_logprobs
class rechunk:
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
now = time.perf_counter()
if chunk.choices == [] and chunk.usage is not None:
return ollama.ChatResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True,
done_reason='stop',
total_duration=int((now - start_ts) * 1_000_000_000),
load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens),
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
eval_count=int(chunk.usage.completion_tokens),
eval_duration=int((now - start_ts) * 1_000_000_000),
message=ollama.Message(role="assistant", content=""),
)
with_thinking = chunk.choices[0] if chunk.choices[0] else None
if stream == True:
thinking = (getattr(with_thinking.delta, "reasoning_content", None) or getattr(with_thinking.delta, "reasoning", None)) if with_thinking else None
role = chunk.choices[0].delta.role or "assistant"
content = chunk.choices[0].delta.content or ''
else:
thinking = (getattr(with_thinking.message, "reasoning_content", None) or getattr(with_thinking.message, "reasoning", None)) if with_thinking else None
role = chunk.choices[0].message.role or "assistant"
content = chunk.choices[0].message.content or ''
# Convert OpenAI tool_calls to Ollama format
# In streaming mode, tool_calls arrive as partial deltas across multiple chunks
# (name only in first delta, arguments as incremental JSON fragments).
# Callers must accumulate deltas and inject the final result; skip here.
ollama_tool_calls = None
if not stream:
raw_tool_calls = getattr(with_thinking.message, "tool_calls", None) if with_thinking else None
if raw_tool_calls:
ollama_tool_calls = []
for tc in raw_tool_calls:
try:
args = orjson.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else (tc.function.arguments or {})
except (orjson.JSONDecodeError, TypeError):
args = {}
ollama_tool_calls.append(ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(name=tc.function.name, arguments=args)
))
# Convert OpenAI logprobs to Ollama format
ollama_logprobs = _convert_openai_logprobs(with_thinking) if with_thinking else None
assistant_msg = ollama.Message(
role=role,
content=content,
thinking=thinking,
images=None,
tool_name=None,
tool_calls=ollama_tool_calls)
rechunk = ollama.ChatResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
message=assistant_msg,
logprobs=ollama_logprobs)
return rechunk
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
now = time.perf_counter()
with_thinking = chunk.choices[0] if chunk.choices[0] else None
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
rechunk = ollama.GenerateResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=10000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
response=chunk.choices[0].text or '',
thinking=thinking)
return rechunk
def openai_embeddings2ollama(chunk: dict) -> ollama.EmbeddingsResponse:
rechunk = ollama.EmbeddingsResponse(embedding=chunk.data[0].embedding)
return rechunk
def openai_embed2ollama(chunk: dict, model: str) -> ollama.EmbedResponse:
rechunk = ollama.EmbedResponse(
model=model,
created_at=iso8601_ns(),
done=None,
done_reason=None,
total_duration=None,
load_duration=None,
prompt_eval_count=None,
prompt_eval_duration=None,
eval_count=None,
eval_duration=None,
embeddings=[chunk.data[0].embedding])
return rechunk
def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
"""Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
llama-server returns a ``timings`` dict instead of the standard OpenAI
``usage`` field::
"timings": {
"cache_n": 236, // prompt tokens reused from cache
"prompt_n": 1, // prompt tokens processed
"predicted_n": 35 // predicted (completion) tokens
}
prompt_tokens = prompt_n + cache_n
completion_tokens = predicted_n
Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
timings are found.
"""
timings = getattr(obj, "timings", None)
if timings is None:
return None
if isinstance(timings, dict):
prompt_n = timings.get("prompt_n", 0) or 0
cache_n = timings.get("cache_n", 0) or 0
predicted_n = timings.get("predicted_n", 0) or 0
return (prompt_n + cache_n, predicted_n)
return None

View file

@ -1,492 +0,0 @@
"""Translation between the OpenAI **Responses API** and **Chat Completions**.
The router speaks Chat Completions to every backend (Ollama, llama-server,
external OpenAI). To expose ``/v1/responses`` transparently on top of that, this
module converts in both directions:
* request: Responses ``input`` / ``instructions`` / ``tools`` chat ``messages`` / ``tools``
* response: chat ``choices[0].message`` Responses ``output`` items
* stream: chat completion deltas Responses typed SSE events
Pure functions / a stream-translator class no I/O, mirroring the style of
``requests/messages.py``. The native passthrough path (external OpenAI) does not
use this module; it forwards the SDK's Responses objects directly.
"""
import secrets
import time
import orjson
from requests.messages import _accumulate_openai_tc_delta
# ---------------------------------------------------------------------------
# Request direction: Responses → Chat Completions
# ---------------------------------------------------------------------------
def _responses_content_to_chat(content):
"""Convert a Responses message ``content`` into Chat Completions content.
Collapses a single text part to a plain string (what most backends expect);
keeps a multimodal list otherwise.
"""
if content is None or isinstance(content, str):
return content
if not isinstance(content, list):
return str(content)
parts = []
for p in content:
if not isinstance(p, dict):
parts.append({"type": "text", "text": str(p)})
continue
ptype = p.get("type")
if ptype in ("input_text", "output_text", "text"):
parts.append({"type": "text", "text": p.get("text", "")})
elif ptype in ("input_image", "image_url"):
url = p.get("image_url")
if isinstance(url, dict):
url = url.get("url")
if url:
parts.append({"type": "image_url", "image_url": {"url": url}})
# input_file / refusal / reasoning parts have no chat equivalent → skip
if len(parts) == 1 and parts[0].get("type") == "text":
return parts[0]["text"]
return parts
def _input_item_to_message(item):
"""Convert a single Responses ``input`` item to a chat message (or None)."""
if isinstance(item, str):
return {"role": "user", "content": item}
if not isinstance(item, dict):
return None
itype = item.get("type")
if itype == "function_call":
return {
"role": "assistant",
"content": None,
"tool_calls": [{
"id": item.get("call_id") or item.get("id"),
"type": "function",
"function": {
"name": item.get("name"),
"arguments": item.get("arguments", ""),
},
}],
}
if itype == "function_call_output":
output = item.get("output", "")
if not isinstance(output, str):
output = orjson.dumps(output).decode("utf-8")
return {
"role": "tool",
"tool_call_id": item.get("call_id") or item.get("id"),
"content": output,
}
if itype in ("reasoning",):
# No Chat Completions equivalent — drop.
return None
# "message" item or a bare {role, content} chat-style item
role = item.get("role")
if role is None:
return None
return {"role": role, "content": _responses_content_to_chat(item.get("content"))}
def responses_input_to_messages(input_data, instructions=None):
"""Build a Chat Completions ``messages`` list from Responses ``input``.
``instructions`` becomes a leading system message; a string ``input`` becomes
a single user message; a list ``input`` is mapped item-by-item.
"""
messages = []
if instructions:
messages.append({"role": "system", "content": instructions})
if input_data is None:
return messages
if isinstance(input_data, str):
messages.append({"role": "user", "content": input_data})
return messages
if isinstance(input_data, list):
for item in input_data:
msg = _input_item_to_message(item)
if msg is not None:
messages.append(msg)
return messages
def _chat_content_to_responses_parts(content, assistant=False):
"""Convert chat message content → Responses content parts."""
text_type = "output_text" if assistant else "input_text"
if content is None:
return []
if isinstance(content, str):
return [{"type": text_type, "text": content}]
parts = []
for p in content if isinstance(content, list) else []:
if not isinstance(p, dict):
parts.append({"type": text_type, "text": str(p)})
elif p.get("type") == "text":
parts.append({"type": text_type, "text": p.get("text", "")})
elif p.get("type") == "image_url":
url = (p.get("image_url") or {}).get("url")
if url:
parts.append({"type": "input_image", "image_url": url})
return parts
def messages_to_responses_input(messages):
"""Convert chat messages → ``(instructions, Responses input items)``.
Used for the native passthrough path: history that the router has resolved in
chat-message space is re-expressed as Responses ``input``. Leading/standalone
system messages are merged into ``instructions``.
"""
instructions_parts = []
items = []
for m in messages:
role = m.get("role")
if role == "system":
c = m.get("content")
instructions_parts.append(c if isinstance(c, str) else orjson.dumps(c).decode("utf-8"))
continue
if role == "tool":
out = m.get("content")
if not isinstance(out, str):
out = orjson.dumps(out).decode("utf-8")
items.append({"type": "function_call_output",
"call_id": m.get("tool_call_id"), "output": out})
continue
if role == "assistant" and m.get("tool_calls"):
for tc in m["tool_calls"]:
fn = tc.get("function", {})
items.append({"type": "function_call", "call_id": tc.get("id"),
"name": fn.get("name"), "arguments": fn.get("arguments", "")})
if m.get("content"):
items.append({"role": "assistant",
"content": _chat_content_to_responses_parts(m["content"], assistant=True)})
continue
items.append({"role": role,
"content": _chat_content_to_responses_parts(m.get("content"),
assistant=(role == "assistant"))})
instructions = "\n\n".join(p for p in instructions_parts if p) or None
return instructions, items
def responses_object_to_sse(resp):
"""Render a *finished* Responses object as a valid SSE event stream.
Used to serve cache/store hits to streaming clients without a backend call.
"""
seq = [-1]
def ev(etype, payload):
seq[0] += 1
body = {"type": etype, "sequence_number": seq[0], **payload}
return f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8")
parts_out = []
in_progress = {**resp, "status": "in_progress", "output": [], "output_text": ""}
parts_out.append(ev("response.created", {"response": in_progress}))
parts_out.append(ev("response.in_progress", {"response": in_progress}))
for oi, item in enumerate(resp.get("output", [])):
parts_out.append(ev("response.output_item.added",
{"output_index": oi, "item": {**item, "status": "in_progress"}}))
if item.get("type") == "message":
for ci, part in enumerate(item.get("content", [])):
if part.get("type") == "output_text":
iid = item.get("id")
parts_out.append(ev("response.content_part.added", {
"item_id": iid, "output_index": oi, "content_index": ci,
"part": {"type": "output_text", "text": "", "annotations": []}}))
parts_out.append(ev("response.output_text.delta", {
"item_id": iid, "output_index": oi, "content_index": ci,
"delta": part.get("text", "")}))
parts_out.append(ev("response.output_text.done", {
"item_id": iid, "output_index": oi, "content_index": ci,
"text": part.get("text", "")}))
parts_out.append(ev("response.content_part.done", {
"item_id": iid, "output_index": oi, "content_index": ci, "part": part}))
parts_out.append(ev("response.output_item.done", {"output_index": oi, "item": item}))
parts_out.append(ev("response.completed", {"response": resp}))
return b"".join(parts_out)
def tools_responses_to_chat(tools):
"""Map Responses tool definitions (flattened) → Chat Completions (nested)."""
if not tools:
return None
out = []
for t in tools:
if isinstance(t, dict) and t.get("type") == "function" and "function" not in t:
fn = {k: t[k] for k in ("name", "description", "parameters", "strict") if k in t}
out.append({"type": "function", "function": fn})
else:
out.append(t)
return out
# ---------------------------------------------------------------------------
# Response direction: Chat Completions → Responses
# ---------------------------------------------------------------------------
def _new_id(prefix):
return f"{prefix}_{secrets.token_hex(16)}"
def chat_message_to_output_items(message):
"""Convert an assistant chat message (dict) into Responses output items."""
items = []
content = message.get("content")
if content:
items.append({
"type": "message",
"id": _new_id("msg"),
"status": "completed",
"role": "assistant",
"content": [{"type": "output_text", "text": content, "annotations": []}],
})
for tc in message.get("tool_calls") or []:
fn = tc.get("function", {})
items.append({
"type": "function_call",
"id": _new_id("fc"),
"call_id": tc.get("id"),
"name": fn.get("name"),
"arguments": fn.get("arguments", ""),
"status": "completed",
})
return items
def usage_chat_to_responses(usage):
"""Map chat usage ``{prompt_tokens, completion_tokens}`` → Responses usage."""
if not usage:
return None
prompt = usage.get("prompt_tokens") or 0
completion = usage.get("completion_tokens") or 0
return {
"input_tokens": prompt,
"output_tokens": completion,
"total_tokens": usage.get("total_tokens") or (prompt + completion),
}
def output_items_to_text(output_items):
"""Concatenate the ``output_text`` parts of all message items."""
chunks = []
for item in output_items or []:
if item.get("type") != "message":
continue
for part in item.get("content") or []:
if part.get("type") == "output_text":
chunks.append(part.get("text", ""))
return "".join(chunks)
def build_response_object(
*,
response_id,
model,
output_items=None,
usage=None,
status="completed",
created_at=None,
previous_response_id=None,
instructions=None,
error=None,
metadata=None,
):
"""Assemble a full ``object:"response"`` body for a non-streaming reply."""
output_items = output_items or []
return {
"id": response_id,
"object": "response",
"created_at": created_at or int(time.time()),
"status": status,
"model": model,
"output": output_items,
"output_text": output_items_to_text(output_items),
"instructions": instructions,
"previous_response_id": previous_response_id,
"usage": usage_chat_to_responses(usage) if usage and "input_tokens" not in usage else usage,
"error": error,
"metadata": metadata or {},
}
# ---------------------------------------------------------------------------
# Streaming direction: Chat Completions deltas → Responses typed SSE events
# ---------------------------------------------------------------------------
class ChatToResponsesStream:
"""Translate a Chat Completions streaming generator into Responses events.
Usage::
translator = ChatToResponsesStream(response_id, model, created_at)
async for sse_bytes in translator.events(chat_async_gen):
yield sse_bytes
# translator.output_items / translator.usage now populated for storage
Emits the ordered event family
``response.created`` ``response.in_progress``
(``response.output_item.added`` ``response.content_part.added``
``response.output_text.delta``* ``response.output_text.done``
``response.content_part.done`` ``response.output_item.done``) and/or
function-call item events ``response.completed`` (carrying usage).
"""
def __init__(self, response_id, model, created_at=None,
previous_response_id=None, instructions=None, metadata=None):
self.response_id = response_id
self.model = model
self.created_at = created_at or int(time.time())
self.previous_response_id = previous_response_id
self.instructions = instructions
self.metadata = metadata or {}
self.seq = -1
self.output_items = []
self.usage = None
def _snapshot(self, status, output=None):
return build_response_object(
response_id=self.response_id,
model=self.model,
output_items=output if output is not None else [],
usage=self.usage,
status=status,
created_at=self.created_at,
previous_response_id=self.previous_response_id,
instructions=self.instructions,
metadata=self.metadata,
)
def _event(self, etype, payload):
self.seq += 1
body = {"type": etype, "sequence_number": self.seq, **payload}
return f"event: {etype}\ndata: {orjson.dumps(body).decode('utf-8')}\n\n".encode("utf-8")
async def events(self, async_gen):
yield self._event("response.created", {"response": self._snapshot("in_progress")})
yield self._event("response.in_progress", {"response": self._snapshot("in_progress")})
next_oi = 0
# text message state
msg_item_id = None
msg_oi = None
text_parts = []
# function-call state, keyed by chat tool_call index
tc_state = {} # idx -> {oi, item_id, call_id, name, args}
async for chunk in async_gen:
usage = getattr(chunk, "usage", None)
if usage is not None:
self.usage = {
"prompt_tokens": usage.prompt_tokens or 0,
"completion_tokens": usage.completion_tokens or 0,
}
choices = getattr(chunk, "choices", None)
if not choices:
continue
delta = choices[0].delta
content_piece = getattr(delta, "content", None)
if content_piece:
if msg_item_id is None:
msg_item_id = _new_id("msg")
msg_oi = next_oi
next_oi += 1
item = {
"id": msg_item_id, "type": "message", "status": "in_progress",
"role": "assistant", "content": [],
}
yield self._event("response.output_item.added",
{"output_index": msg_oi, "item": item})
yield self._event("response.content_part.added", {
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
"part": {"type": "output_text", "text": "", "annotations": []},
})
text_parts.append(content_piece)
yield self._event("response.output_text.delta", {
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
"delta": content_piece,
})
for tc in getattr(delta, "tool_calls", None) or []:
idx = tc.index
fn = getattr(tc, "function", None)
if idx not in tc_state:
item_id = _new_id("fc")
state = {
"oi": next_oi, "item_id": item_id,
"call_id": getattr(tc, "id", None) or _new_id("call"),
"name": (fn.name if fn else None), "args": "",
}
next_oi += 1
tc_state[idx] = state
yield self._event("response.output_item.added", {
"output_index": state["oi"],
"item": {
"id": item_id, "type": "function_call", "status": "in_progress",
"call_id": state["call_id"], "name": state["name"], "arguments": "",
},
})
else:
state = tc_state[idx]
if getattr(tc, "id", None):
state["call_id"] = tc.id
if fn and fn.name:
state["name"] = fn.name
if fn and fn.arguments:
state["args"] += fn.arguments
yield self._event("response.function_call_arguments.delta", {
"item_id": state["item_id"], "output_index": state["oi"],
"delta": fn.arguments,
})
# finalize message item
if msg_item_id is not None:
full_text = "".join(text_parts)
yield self._event("response.output_text.done", {
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
"text": full_text,
})
done_part = {"type": "output_text", "text": full_text, "annotations": []}
yield self._event("response.content_part.done", {
"item_id": msg_item_id, "output_index": msg_oi, "content_index": 0,
"part": done_part,
})
msg_item = {
"id": msg_item_id, "type": "message", "status": "completed",
"role": "assistant", "content": [done_part],
}
yield self._event("response.output_item.done",
{"output_index": msg_oi, "item": msg_item})
# finalize function-call items (in output-index order)
tc_items = {}
for idx, state in tc_state.items():
yield self._event("response.function_call_arguments.done", {
"item_id": state["item_id"], "output_index": state["oi"],
"arguments": state["args"],
})
fc_item = {
"id": state["item_id"], "type": "function_call", "status": "completed",
"call_id": state["call_id"], "name": state["name"], "arguments": state["args"],
}
tc_items[state["oi"]] = fc_item
yield self._event("response.output_item.done",
{"output_index": state["oi"], "item": fc_item})
# assemble final output items ordered by output index
ordered = []
if msg_item_id is not None:
ordered.append((msg_oi, msg_item))
ordered.extend(tc_items.items())
self.output_items = [item for _, item in sorted(ordered, key=lambda kv: kv[0])]
yield self._event("response.completed",
{"response": self._snapshot("completed", self.output_items)})

View file

@ -1,44 +1,44 @@
aiohappyeyeballs==2.6.2 aiohappyeyeballs==2.6.1
aiohttp==3.14.1 aiohttp==3.13.5
aiosignal==1.4.0 aiosignal==1.4.0
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.14.0 anyio==4.13.0
async-timeout==5.0.1 async-timeout==5.0.1
attrs==26.1.0 attrs==26.1.0
certifi==2026.6.17 certifi==2026.4.22
click==8.4.1 click==8.3.3
distro==1.9.0 distro==1.9.0
exceptiongroup==1.3.1 exceptiongroup==1.3.1
fastapi==0.138.0 fastapi==0.136.1
fastapi-sse==1.1.1 fastapi-sse==1.1.1
frozenlist==1.8.0 frozenlist==1.8.0
h11==0.16.0 h11==0.16.0
httpcore==1.0.9 httpcore==1.0.9
httpx==0.28.1 httpx==0.28.1
idna==3.18 idna==3.15
jiter==0.15.0 jiter==0.14.0
multidict==6.7.1 multidict==6.7.1
ollama==0.6.2 ollama==0.6.2
openai==2.43.0 openai==1.109.1
orjson>=3.11.5 orjson>=3.11.5
numpy>=1.26 numpy>=1.26
pillow==12.2.0 pillow==12.2.0
propcache==0.5.2 propcache==0.5.2
pydantic==2.13.4 pydantic==2.13.4
pydantic-settings==2.14.2 pydantic-settings==2.14.1
pydantic_core==2.46.4 pydantic_core==2.46.4
python-dotenv==1.2.2 python-dotenv==1.2.2
PyYAML==6.0.3 PyYAML==6.0.3
sniffio==1.3.1 sniffio==1.3.1
starlette>=1.0.1 starlette==0.52.1
truststore==0.10.4 truststore==0.10.4
tiktoken==0.13.0 tiktoken==0.13.0
tqdm==4.68.3 tqdm==4.67.3
typing-inspection==0.4.2 typing-inspection==0.4.2
typing_extensions==4.15.0 typing_extensions==4.15.0
uvicorn==0.49.0 uvicorn==0.47.0
uvloop uvloop
yarl==1.24.2 yarl==1.23.0
aiosqlite aiosqlite
# Semantic LLM cache — base install (exact-match mode, no heavy ML deps) # Semantic LLM cache — base install (exact-match mode, no heavy ML deps)
# For semantic mode use the :semantic Docker image tag (adds sentence-transformers + torch) # For semantic mode use the :semantic Docker image tag (adds sentence-transformers + torch)

4090
router.py

File diff suppressed because it is too large Load diff

View file

@ -1,325 +0,0 @@
"""Endpoint selection (load-balancing + conversation affinity).
``choose_endpoint`` is the heart of routing it picks an endpoint that
advertises the model, prefers ones with the model already loaded and a free
slot, applies the conversation-affinity hint when available, and honors
config-order priority routing when ``priority_routing`` is set.
``increment_usage`` / ``decrement_usage`` keep the per-(endpoint, model)
counter that drives utilization-based selection; they fan out an SSE
snapshot on every change.
"""
import asyncio
import random
import time
from typing import Optional
from config import get_config
from state import (
usage_counts,
usage_lock,
_loaded_error_cache,
_loaded_error_cache_lock,
_completion_error_cache,
_completion_error_cache_lock,
_COMPLETION_ERROR_TTL,
_affinity_map,
_affinity_lock,
_AFFINITY_MAX_ENTRIES,
)
from sse import _capture_snapshot, _distribute_snapshot
from backends.health import _is_fresh
from backends.normalize import (
is_ext_openai_endpoint,
is_openai_compatible,
is_llama_server,
llama_endpoints,
get_tracking_model,
)
from backends.probe import fetch
async def increment_usage(endpoint: str, model: str) -> None:
async with usage_lock:
usage_counts[endpoint][model] += 1
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
async def decrement_usage(endpoint: str, model: str) -> None:
async with usage_lock:
# Avoid negative counts
current = usage_counts[endpoint].get(model, 0)
if current > 0:
usage_counts[endpoint][model] = current - 1
# Optionally, clean up zero entries
if usage_counts[endpoint].get(model, 0) == 0:
usage_counts[endpoint].pop(model, None)
#if not usage_counts[endpoint]:
# usage_counts.pop(endpoint, None)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
def get_max_connections(ep: str) -> int:
"""Per-endpoint max_concurrent_connections, falling back to the global value."""
cfg = get_config()
return cfg.endpoint_config.get(ep, {}).get(
"max_concurrent_connections", cfg.max_concurrent_connections
)
async def choose_endpoint(model: str, reserve: bool = True,
affinity_key: Optional[str] = None) -> tuple[str, str]:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpointmodel pair **and**
ensuring that the chosen endpoint actually *advertises* the model.
The selection algorithm:
1 Query every endpoint for its advertised models (`/api/tags`).
2 Build a list of endpoints that contain the requested model.
2.5 If conversation affinity is enabled and the caller passes
``affinity_key``, prefer the endpoint that previously served the
same conversation but only when it still has the model loaded
and a free slot. Otherwise fall through to the standard logic.
3 For those endpoints, find those that have the model loaded
(`/api/ps`) *and* still have a free slot.
4 If none are both loaded and free, fall back to any endpoint
from the filtered list that simply has a free slot and randomly
select one.
5 If all are saturated, pick any endpoint from the filtered list
(the request will queue on that endpoint).
6 If no endpoint advertises the model at all, raise an error.
"""
config = get_config()
# 1⃣ Gather advertisedmodel sets for all endpoints concurrently
# Include config.endpoints plus any llama-server / llama-swap endpoints
llama_eps_extra = [ep for ep in llama_endpoints(config) if ep not in config.endpoints]
all_endpoints = config.endpoints + llama_eps_extra
tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)]
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in config.endpoints if is_openai_compatible(ep)]
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra]
advertised_sets = await asyncio.gather(*tag_tasks)
# 2⃣ Filter endpoints that advertise the requested model
candidate_endpoints = [
ep for ep, models in zip(all_endpoints, advertised_sets)
if model in models
]
# 6
if not candidate_endpoints:
if ":latest" in model: #ollama naming convention not applicable to openai/llama-server
model_without_latest = model.split(":latest")[0]
candidate_endpoints = [
ep for ep, models in zip(all_endpoints, advertised_sets)
if model_without_latest in models and (is_ext_openai_endpoint(ep) or is_llama_server(ep))
]
if not candidate_endpoints:
# Only add :latest suffix if model doesn't already have a version suffix
if ":" not in model:
model = model + ":latest"
candidate_endpoints = [
ep for ep, models in zip(all_endpoints, advertised_sets)
if model in models
]
if not candidate_endpoints:
raise RuntimeError(
f"None of the configured endpoints ({', '.join(all_endpoints)}) "
f"advertise the model '{model}'."
)
# 3⃣ Among the candidates, find those that have the model *loaded*
# (concurrently, but only for the filtered list)
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
# 3⃣.5 Exclude endpoints whose loaded-model probe has been failing
# recently. Without this filter, an endpoint where `/api/ps` returns 5xx
# would appear with an empty loaded set but pass through to the
# free-slot fallback (step 4) — sending completion calls to an
# unhealthy backend. See issue #83.
async with _loaded_error_cache_lock:
unhealthy = {
ep for ep, ts in _loaded_error_cache.items()
if _is_fresh(ts, 300)
}
if unhealthy:
filtered = [
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
if ep not in unhealthy
]
if filtered:
candidate_endpoints = [ep for ep, _ in filtered]
loaded_sets = [models for _, models in filtered]
# If *every* candidate is unhealthy we still fall through with the
# original list — refusing to route is worse than retrying a
# possibly-recovered backend.
# 3⃣.6 Exclude (endpoint, model) pairs whose completion path has recently
# failed with a backend connection error (e.g. llama-server in router mode
# whose delegated worker for *this* model died). /v1/models keeps reporting
# OK in that case, so the probe-level filter above cannot catch it.
async with _completion_error_cache_lock:
completion_broken = {
ep for (ep, m), ts in _completion_error_cache.items()
if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL)
}
if completion_broken:
filtered = [
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
if ep not in completion_broken
]
if filtered:
candidate_endpoints = [ep for ep, _ in filtered]
loaded_sets = [models for _, models in filtered]
# Same fallback: if every candidate is broken for this model, fall
# through and let the upstream retry — possibly the operator restarted
# the dead worker.
# Look up a possible affinity hint *before* taking usage_lock. The two
# locks are never held together to avoid lock-ordering issues.
affine_ep: Optional[str] = None
if config.conversation_affinity and affinity_key:
async with _affinity_lock:
entry = _affinity_map.get(affinity_key)
if entry is not None:
ep, _stored_model, expires_at = entry
if expires_at < time.monotonic():
_affinity_map.pop(affinity_key, None)
else:
affine_ep = ep
# Protect all reads/writes of usage_counts with the lock so that selection
# and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock:
# Helper: current usage for (endpoint, model) using the same normalized key
# that increment_usage/decrement_usage store — raw model names differ from
# tracking names for llama-server (HF prefix / quant suffix stripped).
def tracking_usage(ep: str) -> int:
return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)
def utilization_ratio(ep: str) -> float:
return tracking_usage(ep) / get_max_connections(ep)
def total_load(ep: str) -> int:
"""Sum of in-flight requests across *all* models on the endpoint."""
return sum(usage_counts.get(ep, {}).values())
# How many models each candidate currently has *resident* (from the
# /api/ps probe). With infinite keep-alive a model stays loaded long
# after its in-flight count drops to zero, so this is the signal that
# spreads *distinct* models across backends.
ep_loaded_counts = {
ep: len(models) for ep, models in zip(candidate_endpoints, loaded_sets)
}
def loaded_count(ep: str) -> int:
return ep_loaded_counts.get(ep, 0)
def pick_least_loaded(eps: list[str]) -> str:
"""Pick the least-committed endpoint, breaking ties at random.
Ordering key is ``(total_load, loaded_count)``:
* ``total_load`` (in-flight requests across *all* models) keeps a
request off a backend already busy with a *different* model
otherwise the per-model count reads zero everywhere and the
ranking is discarded (cold model B landing on the box serving A).
* ``loaded_count`` (number of *resident* models) then spreads
distinct models across backends. Two different cold models (27b,
35b) requested back-to-back must not pile onto the same box: once
27b is resident there, that box has loaded_count 1 while the idle
backends have 0, so the next cold model prefers an empty backend
even though every backend reports zero in-flight load.
``random.choice`` only breaks genuine ties on both keys, so a single
idle cluster still distributes the very first cold model evenly."""
best = min((total_load(ep), loaded_count(ep)) for ep in eps)
tied = [ep for ep in eps if (total_load(ep), loaded_count(ep)) == best]
return random.choice(tied)
# Priority map: position in all_endpoints list (lower = higher priority)
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
selected: Optional[str] = None
# 2⃣.5 Conversation affinity preference — only honour the hint when
# the affine endpoint still advertises the model loaded *and* has a
# free slot. Otherwise fall back to the standard algorithm.
if affine_ep:
ep_loaded = {
ep: set(models)
for ep, models in zip(candidate_endpoints, loaded_sets)
}
if (affine_ep in candidate_endpoints
and model in ep_loaded.get(affine_ep, set())
and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
selected = affine_ep
if selected is None:
# 3⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and tracking_usage(ep) < get_max_connections(ep)
]
if loaded_and_free:
if config.priority_routing:
# WRR: sort by config order first (stable), then by utilization ratio.
# Stable sort preserves priority for equal-ratio endpoints.
loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
loaded_and_free.sort(key=utilization_ratio)
selected = loaded_and_free[0]
else:
# All endpoints here already have the model loaded, so there
# is no model-switching cost to optimise for. Pick the least
# *total*-loaded one (tie broken at random) so we steer away
# from a backend busy serving other models.
selected = pick_least_loaded(loaded_and_free)
else:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < get_max_connections(ep)
]
if endpoints_with_free_slot:
if config.priority_routing:
endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
endpoints_with_free_slot.sort(key=utilization_ratio)
selected = endpoints_with_free_slot[0]
else:
# Prefer the endpoint with the lowest *total* load so the
# cold-start cost lands on genuinely idle hardware rather
# than a backend already busy with a different model.
selected = pick_least_loaded(endpoints_with_free_slot)
else:
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
if config.priority_routing:
selected = min(
candidate_endpoints,
key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
)
else:
selected = min(candidate_endpoints, key=tracking_usage)
tracking_model = get_tracking_model(selected, model)
snapshot = None
if reserve:
usage_counts[selected][tracking_model] += 1
snapshot = _capture_snapshot()
if snapshot is not None:
await _distribute_snapshot(snapshot)
# Record / refresh affinity *after* releasing usage_lock.
if reserve and config.conversation_affinity and affinity_key:
expires_at = time.monotonic() + config.conversation_affinity_ttl
async with _affinity_lock:
_affinity_map[affinity_key] = (selected, model, expires_at)
if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
now = time.monotonic()
for k in [k for k, v in _affinity_map.items() if v[2] < now]:
_affinity_map.pop(k, None)
return selected, tracking_model

View file

@ -1,14 +0,0 @@
"""Secret-masking helpers used when logging or surfacing backend errors."""
import re
def _mask_secrets(text: str) -> str:
"""
Mask common API key patterns to avoid leaking secrets in logs or error payloads.
"""
if not text:
return text
# OpenAI-style keys (sk-...) and generic "api key" mentions
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
return text

62
sse.py
View file

@ -1,62 +0,0 @@
"""Server-sent-events plumbing.
Captures the current ``usage_counts`` / ``token_usage_counts`` snapshot and
fan-outs it to every subscribed asyncio.Queue. Routes that need a live
dashboard feed call ``subscribe`` / ``unsubscribe`` to obtain a queue.
"""
import asyncio
from typing import Dict
import orjson
from state import (
usage_counts,
token_usage_counts,
_subscribers,
_subscribers_lock,
)
def _capture_snapshot() -> str:
"""Capture current usage counts as a JSON string. Caller must hold at least one of usage_lock/token_usage_lock."""
return orjson.dumps({
"usage_counts": dict(usage_counts),
"token_usage_counts": dict(token_usage_counts)
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
async def _distribute_snapshot(snapshot: str) -> None:
"""Push a pre-captured snapshot to all SSE subscribers. Must be called outside any usage lock."""
async with _subscribers_lock:
for q in _subscribers:
if q.full():
try:
await q.get()
except asyncio.QueueEmpty:
pass
await q.put(snapshot)
async def close_all_sse_queues():
for q in list(_subscribers):
# sentinel value that the generator will recognise
await q.put(None)
async def subscribe() -> asyncio.Queue:
"""
Returns a new Queue that will receive every snapshot.
"""
q: asyncio.Queue = asyncio.Queue(maxsize=10)
async with _subscribers_lock:
_subscribers.add(q)
return q
async def unsubscribe(q: asyncio.Queue):
async with _subscribers_lock:
_subscribers.discard(q)
async def get_usage_counts() -> Dict:
return dict(usage_counts) # shallow copy

116
state.py
View file

@ -1,116 +0,0 @@
"""Shared mutable router state.
All process-wide caches, locks, in-flight task maps, queues, counters and
buffers used by the router live here. These names are only ever *mutated*
(dict/set updates, lock acquisitions, queue put/get) never rebound so
importing them via ``from state import `` is safe from every module.
Rebound singletons (``config``, ``db``, ``token_worker_task``,
``flush_task``) intentionally stay in router.py so their reassignment on
startup is visible to all callers.
"""
import asyncio
from collections import defaultdict
from typing import Dict, Set
# ------------------------------------------------------------------
# Inmemory caches
# ------------------------------------------------------------------
# Successful results are cached for 300s
_models_cache: dict[str, tuple[Set[str], float]] = {}
_loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
# Transient errors are cached separately per concern so that a failure
# in one path does not poison the other.
_available_error_cache: dict[str, float] = {}
_loaded_error_cache: dict[str, float] = {}
# Per-(endpoint, model) completion-path failures. A llama-server in router
# mode can keep returning /v1/models 200 OK after its delegated worker for
# a specific model dies — the probe-level caches above will not catch this.
# We record signals observed during actual completion attempts so
# choose_endpoint can avoid the affected (endpoint, model) pair without
# poisoning unrelated models on the same backend.
_completion_error_cache: dict[tuple[str, str], float] = {}
_COMPLETION_ERROR_TTL = 300
# ------------------------------------------------------------------
# Cache locks
# ------------------------------------------------------------------
_models_cache_lock = asyncio.Lock()
_loaded_models_cache_lock = asyncio.Lock()
_available_error_cache_lock = asyncio.Lock()
_loaded_error_cache_lock = asyncio.Lock()
_completion_error_cache_lock = asyncio.Lock()
# ------------------------------------------------------------------
# In-flight request tracking (prevents cache stampede)
# ------------------------------------------------------------------
_inflight_available_models: dict[str, asyncio.Task] = {}
_inflight_loaded_models: dict[str, asyncio.Task] = {}
_inflight_lock = asyncio.Lock()
_bg_refresh_available: dict[str, asyncio.Task] = {}
_bg_refresh_loaded: dict[str, asyncio.Task] = {}
_bg_refresh_lock = asyncio.Lock()
# ------------------------------------------------------------------
# Queues
# ------------------------------------------------------------------
_subscribers: Set[asyncio.Queue] = set()
_subscribers_lock = asyncio.Lock()
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
# ------------------------------------------------------------------
# HTTP client / connector cache
# ------------------------------------------------------------------
app_state = {
"session": None,
"connector": None,
"probe_session": None, # dedicated session for health/introspection probes
"probe_connector": None, # connection pool isolated from proxy traffic
"socket_sessions": {}, # endpoint -> aiohttp.ClientSession(UnixConnector) for .sock endpoints
"httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints
# Long-lived backend clients, reused across requests. Constructing these is
# expensive (~40 ms each — every new client builds an SSL context and reloads
# the OS trust store via truststore), so building one per request serializes
# the event loop and caps throughput. Created once at startup, closed on
# shutdown. See backends.sessions.get_ollama_client / _make_openai_client.
"ollama_clients": {}, # endpoint -> ollama.AsyncClient
"openai_clients": {}, # (endpoint, api_key) -> openai.AsyncOpenAI
}
# Default outbound HTTP headers attached to every backend request.
default_headers = {
"HTTP-Referer": "https://nomyo.ai",
"Referer": "https://nomyo.ai",
"X-Title": "NOMYO Router",
}
# ------------------------------------------------------------------
# Token Count Buffer (for write-behind pattern)
# ------------------------------------------------------------------
# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
# Time series buffer with timestamp
time_series_buffer: list[dict[str, int | str]] = []
# Lock to protect buffer access from race conditions
buffer_lock = asyncio.Lock()
# Configuration for periodic flushing
FLUSH_INTERVAL = 10 # seconds
# ------------------------------------------------------------------
# Perendpoint permodel active connection counters
# ------------------------------------------------------------------
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
usage_lock = asyncio.Lock() # protects access to usage_counts
token_usage_lock = asyncio.Lock()
# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic).
# Keeps the same conversation pinned to the endpoint that already has its
# KV-cache prefix warm. Model is stored so the dashboard can aggregate live
# entries per (endpoint, model) without recomputing fingerprints.
# Never held together with usage_lock.
_affinity_map: Dict[str, tuple[str, str, float]] = {}
_affinity_lock = asyncio.Lock()
_AFFINITY_MAX_ENTRIES = 10000

View file

@ -192,10 +192,6 @@
color: #8b0000; color: #8b0000;
font-weight: bold; font-weight: bold;
} }
.status-error[title] {
cursor: help;
text-decoration: underline dotted;
}
.copy-link, .copy-link,
.delete-link, .delete-link,
.show-link, .show-link,
@ -740,16 +736,6 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
return await resp.json(); return await resp.json();
} }
function escapeHtml(value) {
if (value === null || value === undefined) return "";
return String(value)
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#39;");
}
function toggleDarkMode() { function toggleDarkMode() {
document.documentElement.classList.toggle("dark-mode"); document.documentElement.classList.toggle("dark-mode");
} }
@ -766,24 +752,40 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
// Build HTML for both endpoints and llama_server_endpoints // Build HTML for both endpoints and llama_server_endpoints
let html = ""; let html = "";
const renderRow = (e) => { // Add Ollama endpoints
const statusClass = html += data.endpoints
e.status === "ok" ? "status-ok" : "status-error"; .map((e) => {
const version = e.version || "N/A"; const statusClass =
const titleAttr = e.detail e.status === "ok"
? ` title="${escapeHtml(e.detail)}"` ? "status-ok"
: ""; : "status-error";
return ` const version = e.version || "N/A";
return `
<tr> <tr>
<td class="endpoint">${escapeHtml(e.url)}</td> <td class="endpoint">${e.url}</td>
<td class="status ${statusClass}"${titleAttr}>${escapeHtml(e.status)}</td> <td class="status ${statusClass}">${e.status}</td>
<td class="version">${escapeHtml(version)}</td> <td class="version">${version}</td>
</tr>`; </tr>`;
}; })
.join("");
html += data.endpoints.map(renderRow).join("");
// Add llama-server endpoints
if (data.llama_server_endpoints && data.llama_server_endpoints.length > 0) { if (data.llama_server_endpoints && data.llama_server_endpoints.length > 0) {
html += data.llama_server_endpoints.map(renderRow).join(""); html += data.llama_server_endpoints
.map((e) => {
const statusClass =
e.status === "ok"
? "status-ok"
: "status-error";
const version = e.version || "N/A";
return `
<tr>
<td class="endpoint">${e.url}</td>
<td class="status ${statusClass}">${e.status}</td>
<td class="version">${version}</td>
</tr>`;
})
.join("");
} }
body.innerHTML = html; body.innerHTML = html;
@ -1171,13 +1173,11 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
} }
}; };
source.onerror = (err) => { source.onerror = async (err) => {
// EventSource auto-reconnects on transient drops as long as we console.error("SSE connection error. Retrying...", err);
// don't close it. Don't treat a dropped stream as an auth failure: source.close();
// auth prompting is handled by loadEndpoints()/authedFetch() on the await showApiKeyModal("Enter the NOMYO Router API key to view live usage.");
// REST endpoints. A genuine 401 closes the stream permanently here loadUsage();
// (no reconnect loop), and the REST path surfaces the modal.
console.error("SSE connection error; awaiting auto-reconnect.", err);
}; };
window.addEventListener("beforeunload", () => source.close()); window.addEventListener("beforeunload", () => source.close());
} }

View file

@ -1,17 +0,0 @@
endpoints:
- http://192.168.0.51:12434
llama_server_endpoints:
- http://192.168.0.51:12434/v1
llama_swap_endpoints:
- http://192.168.0.51:12435/v1
max_concurrent_connections: 2
api_keys:
"http://192.168.0.51:12434": "ollama"
"http://192.168.0.51:12434/v1": "llama"
"http://192.168.0.51:12435/v1": "llama-swap"
cache_enabled: false

View file

@ -1,236 +0,0 @@
"""
Test configuration for nomyo-router.
Run from project root:
pytest test/ -v
pytest test/ -m "not integration" # skip real-server tests
pytest test/ -m integration -v # only real-server tests
Environment variables:
NOMYO_TEST_OLLAMA Ollama endpoint (default: http://192.168.0.50:12434)
NOMYO_TEST_LLAMA llama-server endpoint (default: http://192.168.0.50:12434/v1)
NOMYO_TEST_MODEL_CHAT chat model to use (auto-discovered if unset)
NOMYO_TEST_EMBED_MODEL embedding model (auto-discovered if unset)
"""
import asyncio
import os
import ssl
import sys
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import aiohttp
import httpx
import pytest
_TEST_DIR = Path(__file__).parent
# Must be set before importing router so module-level Config.from_yaml + Config field
# defaults pick these up. db_path is intentionally absent from config_test.yaml so the
# env-var default wins — keeps tests portable across CI runners (Linux/macOS/Windows).
os.environ.setdefault("NOMYO_ROUTER_CONFIG_PATH", str(_TEST_DIR / "config_test.yaml"))
os.environ.setdefault(
"NOMYO_ROUTER_DB_PATH",
str(Path(tempfile.gettempdir()) / "nomyo_router_test_tokens.db"),
)
sys.path.insert(0, str(_TEST_DIR.parent))
import router # noqa: E402
TEST_OLLAMA = os.getenv("NOMYO_TEST_OLLAMA", "http://192.168.0.51:12434")
TEST_LLAMA = os.getenv("NOMYO_TEST_LLAMA", "http://192.168.0.51:12434/v1")
def pytest_configure(config):
config.addinivalue_line(
"markers",
"integration: tests that require a real backend at 192.168.0.50:12434",
)
# ── Config mocks ─────────────────────────────────────────────────────────────
@pytest.fixture
def mock_config():
"""Minimal config pointing at TEST_OLLAMA / TEST_LLAMA."""
cfg = MagicMock()
cfg.endpoints = [TEST_OLLAMA]
cfg.llama_server_endpoints = [TEST_LLAMA]
cfg.llama_swap_endpoints = []
cfg.api_keys = {TEST_OLLAMA: "ollama", TEST_LLAMA: "llama"}
cfg.max_concurrent_connections = 2
cfg.router_api_key = None
cfg.cache_enabled = False
return cfg
@pytest.fixture
def mock_config_no_llama():
"""Config with Ollama only, no llama-server."""
cfg = MagicMock()
cfg.endpoints = [TEST_OLLAMA]
cfg.llama_server_endpoints = []
cfg.llama_swap_endpoints = []
cfg.api_keys = {TEST_OLLAMA: "ollama"}
cfg.max_concurrent_connections = 2
cfg.router_api_key = None
cfg.cache_enabled = False
return cfg
@pytest.fixture
def mock_config_with_key():
"""Config with router_api_key set (enables auth middleware)."""
cfg = MagicMock()
cfg.endpoints = [TEST_OLLAMA]
cfg.llama_server_endpoints = []
cfg.llama_swap_endpoints = []
cfg.api_keys = {}
cfg.max_concurrent_connections = 2
cfg.router_api_key = "test-secret-key"
cfg.cache_enabled = False
return cfg
# ── aiohttp session (used by fetch tests + choose_endpoint tests) ─────────────
@pytest.fixture
async def aio_session():
"""Real aiohttp session stored in app_state; intercepted by aioresponses."""
ssl_ctx = ssl.create_default_context()
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
session = aiohttp.ClientSession(connector=conn)
router.app_state["session"] = session
# Clear caches to prevent test bleed
router._models_cache.clear()
router._loaded_models_cache.clear()
router._available_error_cache.clear()
router._loaded_error_cache.clear()
router._inflight_available_models.clear()
router._inflight_loaded_models.clear()
router._bg_refresh_available.clear()
router._bg_refresh_loaded.clear()
yield session
await session.close()
router.app_state["session"] = None
# ── Validation-only HTTP client (no real backend needed) ──────────────────────
@pytest.fixture
async def client(mock_config, tmp_path):
"""httpx client for validation/auth tests — no real backend calls made."""
from db import TokenDatabase
ssl_ctx = ssl.create_default_context()
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
session = aiohttp.ClientSession(connector=conn)
db_inst = TokenDatabase(str(tmp_path / "test.db"))
await db_inst.init_db()
old_session = router.app_state.get("session")
old_db = router.db
router.app_state["session"] = session
router.db = db_inst
with patch.object(router, "config", mock_config):
transport = httpx.ASGITransport(app=router.app)
async with httpx.AsyncClient(
transport=transport, base_url="http://test", timeout=10.0
) as c:
yield c
await session.close()
router.app_state["session"] = old_session
router.db = old_db
@pytest.fixture
async def client_auth(mock_config_with_key, tmp_path):
"""httpx client with router_api_key configured (for auth middleware tests)."""
from db import TokenDatabase
ssl_ctx = ssl.create_default_context()
conn = aiohttp.TCPConnector(ssl=ssl_ctx)
session = aiohttp.ClientSession(connector=conn)
db_inst = TokenDatabase(str(tmp_path / "test_auth.db"))
await db_inst.init_db()
old_session = router.app_state.get("session")
old_db = router.db
router.app_state["session"] = session
router.db = db_inst
with patch.object(router, "config", mock_config_with_key):
transport = httpx.ASGITransport(app=router.app)
async with httpx.AsyncClient(
transport=transport, base_url="http://test", timeout=10.0
) as c:
yield c
await session.close()
router.app_state["session"] = old_session
router.db = old_db
# ── Integration client (full startup with real backend) ──────────────────────
@pytest.fixture(scope="module")
async def integration_client():
"""Full app startup pointing at the real test server."""
await router.startup_event()
transport = httpx.ASGITransport(app=router.app)
async with httpx.AsyncClient(
transport=transport,
base_url="http://test",
timeout=httpx.Timeout(60.0),
) as c:
yield c
await router.shutdown_event()
# ── Model discovery fixtures ──────────────────────────────────────────────────
@pytest.fixture(scope="module")
async def chat_model(integration_client):
"""Return a chat/generation model name available on the test server."""
env_model = os.getenv("NOMYO_TEST_MODEL_CHAT")
if env_model:
return env_model
resp = await integration_client.get("/api/tags")
if resp.status_code != 200:
pytest.skip("Cannot reach test server")
models = resp.json().get("models", [])
# Prefer small models for faster tests
for m in models:
name = m.get("name", "")
if any(x in name.lower() for x in ["0.5b", "1b", "3b", "1.5b", "2b"]):
return name
if models:
return models[0]["name"]
pytest.skip("No chat models available on test server")
@pytest.fixture(scope="module")
async def embed_model(integration_client):
"""Return an embedding model name available on the test server."""
env_model = os.getenv("NOMYO_TEST_EMBED_MODEL")
if env_model:
return env_model
resp = await integration_client.get("/api/tags")
if resp.status_code != 200:
pytest.skip("Cannot reach test server")
models = resp.json().get("models", [])
for m in models:
name = m.get("name", "")
if any(x in name.lower() for x in ["embed", "nomic", "minilm", "bge", "e5"]):
return name
pytest.skip("No embedding model available on test server")

View file

@ -1,138 +0,0 @@
# Load testing the NOMYO Router
`loadtest.py` is a self-contained load generator (asyncio + httpx) with a built-in
**mock backend** so you can measure the router's own concurrency ceiling on a given
machine — independent of real GPU/backend compute.
It answers the question *"how many concurrent connections can the router sustain
on this box?"* by hammering it with N concurrent virtual clients and reporting
throughput, latency percentiles and (for streaming) time-to-first-token.
Run everything from the project root with the project venv active:
```bash
source ~/.venv/nomyo-router/bin/activate # whatever venv has the router deps
```
## The three modes
### 1. `--mock-backend` (recommended) — fully self-contained
Spawns a fast fake Ollama/OpenAI backend **and** the router (wired to it via a
temporary config), drives load against the router, then tears both down. Because
the backend is trivial, the numbers reflect the **router's proxy overhead**, not
model inference time.
```bash
python test/load/loadtest.py --mock-backend --stream --concurrency 128 --duration 30
```
### 2. Default — drive an already-running router
```bash
python test/load/loadtest.py --url http://127.0.0.1:12434 \
--api ollama --stream --concurrency 64 --duration 30 --model llama3
```
### 3. `--serve-mock` — just the mock backend
Run only the fake backend and point your own router `config.yaml` at it
(`endpoints: [http://127.0.0.1:11434]`):
```bash
python test/load/loadtest.py --serve-mock --mock-port 11434 --mock-tokens 64
```
## Finding the concurrency knee
`--ramp` sweeps several concurrency levels and prints a table. The knee is where
`req/s` stops rising and `p99` latency starts climbing sharply:
```bash
python test/load/loadtest.py --mock-backend --stream \
--ramp 8,32,64,128,256 --duration 15
```
```
conc req ok err req/s p50ms p90ms p99ms maxms ttftP50 ttftP99
---------------------------------------------------------------------------------------------
8 120 120 0 19.8 404.6 448.3 478.6 501.4 358.4 391.7
32 140 140 0 21.5 1487.1 1641.8 2341.8 2397.4 1269.8 1476.3
64 148 148 0 21.3 2953.0 4632.5 5204.3 5267.0 1207.8 3031.7
128 168 168 0 19.0 6376.4 8608.9 8726.9 8739.8 2843.1 8348.6
```
> Reading the table above: throughput stays flat (~20 req/s) while latency grows
> linearly with concurrency — the classic signature of a **single-worker
> serialization bottleneck**. Raising `--router-workers` lets throughput scale
> across CPU cores; the per-worker ceiling is what each table row measures.
## Streaming vs non-streaming, Ollama vs OpenAI
| flag | effect |
|------|--------|
| `--stream` / `--no-stream` | streamed response (default) vs a single buffered response |
| `--api ollama` | drives `POST /api/chat` (default) |
| `--api openai` | drives `POST /v1/chat/completions` |
Streaming runs additionally report **TTFT** (time-to-first-token), which isolates
prefill/routing latency from total stream duration.
## Shaping the mock backend (the "fake GPU")
The mock's latency is fully configurable, so you can model anything from an
instant echo (measure pure proxy overhead) to a slow, long-streaming model
(measure how many slow streams the box holds open at once):
| flag | meaning |
|------|---------|
| `--mock-ttft-ms` | prefill latency before the first token (ms) |
| `--mock-tokens` | number of completion tokens emitted |
| `--mock-tok-ms` | per-token decode delay (ms) — inverse of tokens/sec |
| `--mock-models` | comma-separated model names advertised in `/api/tags` & `/api/ps` |
Example — simulate a realistic 40 tok/s model with 300 ms prefill emitting 200
tokens, and see how many concurrent such streams the router holds:
```bash
python test/load/loadtest.py --mock-backend --stream --ramp 16,64,256 \
--mock-ttft-ms 300 --mock-tokens 200 --mock-tok-ms 25 --duration 20
```
## Load shape & misc flags
| flag | default | meaning |
|------|---------|---------|
| `--concurrency N` | 32 | concurrent virtual clients |
| `--duration S` | 20 | seconds per stage (ignored if `--requests` set) |
| `--requests N` | — | send exactly N requests instead of timing out |
| `--warmup S` | 2 | unmeasured warmup before each stage (hot caches/connections) |
| `--timeout S` | 120 | per-request timeout |
| `--model NAME` | `mock` | model name requested (must match what the backend advertises) |
| `--prompt STR` | … | user prompt sent in every request |
| `--json PATH` | — | also write the full results as JSON |
### `--mock-backend` orchestration knobs
| flag | default | meaning |
|------|---------|---------|
| `--router-workers N` | 1 | `uvicorn --workers` for the spawned router |
| `--router-max-conc N` | = peak concurrency | `max_concurrent_connections` in the generated config (so the router doesn't queue unless you want it to) |
| `--router-port` / `--mock-port` | auto | fix the ports instead of auto-picking free ones |
| `--keep-config` | off | keep the generated temp `config.yaml` for inspection |
## Notes & caveats
- **Single-machine bias.** With `--mock-backend`, the driver, router and mock all
share the same CPU, so they compete for cores. For an upper-bound number, run
the driver on a separate machine against a real router (`--url`), or pin
processes to different cores.
- The generated config sets `conversation_affinity: false` and
`cache_enabled: false` to measure the raw proxy path. The temp config and a
throwaway token DB (under the system temp dir) are deleted on exit.
- To measure the router's *admission* limit instead of raw throughput, set
`--router-max-conc` low (e.g. `2`) — requests beyond the limit queue on the
least-busy endpoint rather than erroring.
- Requires the router's own dependencies (`aiohttp`, `httpx`, `uvicorn`, …); it
reuses the project venv, no extra packages needed.
```

View file

@ -1,803 +0,0 @@
#!/usr/bin/env python3
"""
NOMYO Router load test asyncio + httpx driver with a built-in mock backend.
Three modes
-----------
1. Drive an already-running router (default)::
python test/load/loadtest.py --url http://127.0.0.1:12434 \
--concurrency 64 --duration 30 --stream
2. Fully self-contained "mock backend" mode spins up a fast fake Ollama/OpenAI
backend AND the router (wired to that backend via a temp config), load-tests
them, then tears both down. This isolates the *router's* proxy overhead from
real GPU compute, so the numbers tell you how many concurrent connections the
router itself can sustain on this machine::
python test/load/loadtest.py --mock-backend \
--concurrency 128 --duration 30 --stream
3. Run just the mock backend (point your own router config at it)::
python test/load/loadtest.py --serve-mock --mock-port 11434
Both streaming and non-streaming are supported (--stream / --no-stream), against
either the Ollama API (--api ollama -> POST /api/chat) or the OpenAI-compatible
API (--api openai -> POST /v1/chat/completions).
Finding the concurrency ceiling
-------------------------------
Use --ramp to sweep concurrency levels and print a table; the "knee" is where
p99 latency climbs sharply or req/s stops increasing::
python test/load/loadtest.py --mock-backend --stream \
--ramp 16,32,64,128,256 --duration 15
The mock backend is a configurable "fake GPU": --mock-ttft-ms (prefill latency),
--mock-tokens (completion length) and --mock-tok-ms (per-token decode delay) let
you model anything from an instant echo (measure pure proxy overhead) to a slow,
long-streaming model (measure how many slow streams the box holds open).
"""
from __future__ import annotations
import argparse
import asyncio
import contextlib
import json
import math
import os
import signal
import socket
import statistics
import sys
import tempfile
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import httpx
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
_WORDS = (
"the quick brown fox jumps over the lazy dog while a router proxies many "
"concurrent streaming completions across several ollama and openai backends "
"without dropping a single token under sustained synthetic load testing"
).split()
def _gen_text(n_tokens: int) -> str:
"""Deterministic pseudo-completion of roughly ``n_tokens`` space-separated tokens."""
return " ".join(_WORDS[i % len(_WORDS)] for i in range(max(0, n_tokens)))
def _rfc3339_now() -> str:
# Ollama-style timestamp, e.g. 2024-01-01T00:00:00.000000Z
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"
def _count_prompt_tokens(messages: list) -> int:
total = 0
for m in messages or []:
c = m.get("content")
if isinstance(c, str):
total += len(c.split())
elif isinstance(c, list):
for part in c:
if isinstance(part, dict) and isinstance(part.get("text"), str):
total += len(part["text"].split())
return max(1, total)
def _free_port() -> int:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("127.0.0.1", 0))
port = s.getsockname()[1]
s.close()
return port
# ===========================================================================
# Mock backend (a fast, configurable fake Ollama + OpenAI-compatible server)
# ===========================================================================
def build_mock_app(models: list[str], ttft_ms: float, tokens: int, tok_ms: float):
"""Construct the aiohttp mock-backend application.
Serves the native-Ollama surface the router uses for discovery and the
`/api/chat` path (`/api/version`, `/api/tags`, `/api/ps`, `/api/chat`,
`/api/generate`) plus the OpenAI-compatible surface used by the
`/v1/chat/completions` path (`/v1/models`, `/v1/chat/completions`,
`/v1/completions`).
"""
from aiohttp import web # imported lazily so the driver has no hard aiohttp dep
ttft = ttft_ms / 1000.0
tok_delay = tok_ms / 1000.0
def _tag_entry(name: str) -> dict:
return {
"name": name,
"model": name,
"modified_at": _rfc3339_now(),
"size": 4_000_000_000,
"digest": "0" * 64,
"details": {
"parent_model": "",
"format": "gguf",
"family": "mock",
"families": ["mock"],
"parameter_size": "7B",
"quantization_level": "Q4_0",
},
}
async def version(_req):
return web.json_response({"version": "0.0.0-nomyo-mock"})
async def tags(_req):
return web.json_response({"models": [_tag_entry(m) for m in models]})
async def ps(_req):
# Report every advertised model as loaded with VRAM so choose_endpoint
# treats this endpoint as "loaded + free".
out = []
for m in models:
e = _tag_entry(m)
e["size_vram"] = e["size"]
e["expires_at"] = "2999-01-01T00:00:00Z"
out.append(e)
return web.json_response({"models": out})
async def v1_models(_req):
now = int(time.time())
return web.json_response({
"object": "list",
"data": [{"id": m, "object": "model", "created": now, "owned_by": "mock"} for m in models],
})
# ----- Ollama /api/chat -------------------------------------------------
async def api_chat(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", True)
prompt_tok = _count_prompt_tokens(payload.get("messages", []))
t0 = time.perf_counter()
if stream:
resp = web.StreamResponse(
status=200, headers={"Content-Type": "application/x-ndjson"}
)
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
line = {
"model": model,
"created_at": _rfc3339_now(),
"message": {"role": "assistant", "content": _WORDS[i % len(_WORDS)] + " "},
"done": False,
}
await resp.write(json.dumps(line).encode() + b"\n")
dur_ns = int((time.perf_counter() - t0) * 1e9)
final = {
"model": model,
"created_at": _rfc3339_now(),
"message": {"role": "assistant", "content": ""},
"done": True,
"done_reason": "stop",
"total_duration": dur_ns,
"load_duration": 0,
"prompt_eval_count": prompt_tok,
"prompt_eval_duration": int(ttft * 1e9),
"eval_count": tokens,
"eval_duration": dur_ns,
}
await resp.write(json.dumps(final).encode() + b"\n")
await resp.write_eof()
return resp
# non-streaming: simulate the whole generation latency, then one object
await asyncio.sleep(ttft + tokens * tok_delay)
dur_ns = int((time.perf_counter() - t0) * 1e9)
return web.json_response({
"model": model,
"created_at": _rfc3339_now(),
"message": {"role": "assistant", "content": _gen_text(tokens)},
"done": True,
"done_reason": "stop",
"total_duration": dur_ns,
"load_duration": 0,
"prompt_eval_count": prompt_tok,
"prompt_eval_duration": int(ttft * 1e9),
"eval_count": tokens,
"eval_duration": dur_ns,
})
# ----- Ollama /api/generate --------------------------------------------
async def api_generate(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", True)
prompt_tok = max(1, len(str(payload.get("prompt", "")).split()))
t0 = time.perf_counter()
if stream:
resp = web.StreamResponse(status=200, headers={"Content-Type": "application/x-ndjson"})
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
await resp.write(json.dumps({
"model": model, "created_at": _rfc3339_now(),
"response": _WORDS[i % len(_WORDS)] + " ", "done": False,
}).encode() + b"\n")
dur_ns = int((time.perf_counter() - t0) * 1e9)
await resp.write(json.dumps({
"model": model, "created_at": _rfc3339_now(), "response": "", "done": True,
"done_reason": "stop", "total_duration": dur_ns,
"prompt_eval_count": prompt_tok, "eval_count": tokens, "eval_duration": dur_ns,
}).encode() + b"\n")
await resp.write_eof()
return resp
await asyncio.sleep(ttft + tokens * tok_delay)
dur_ns = int((time.perf_counter() - t0) * 1e9)
return web.json_response({
"model": model, "created_at": _rfc3339_now(), "response": _gen_text(tokens),
"done": True, "done_reason": "stop", "total_duration": dur_ns,
"prompt_eval_count": prompt_tok, "eval_count": tokens, "eval_duration": dur_ns,
})
# ----- OpenAI /v1/chat/completions -------------------------------------
async def v1_chat(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", False)
want_usage = bool((payload.get("stream_options") or {}).get("include_usage"))
prompt_tok = _count_prompt_tokens(payload.get("messages", []))
created = int(time.time())
cid = "chatcmpl-mock"
if stream:
resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
def _sse(obj: dict) -> bytes:
return b"data: " + json.dumps(obj).encode() + b"\n\n"
# first chunk carries the role
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
}))
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": {"content": _WORDS[i % len(_WORDS)] + " "}, "finish_reason": None}],
}))
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}))
if want_usage:
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [],
"usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
"total_tokens": prompt_tok + tokens},
}))
await resp.write(b"data: [DONE]\n\n")
await resp.write_eof()
return resp
await asyncio.sleep(ttft + tokens * tok_delay)
return web.json_response({
"id": cid, "object": "chat.completion", "created": created, "model": model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": _gen_text(tokens)},
"finish_reason": "stop", "logprobs": None}],
"usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
"total_tokens": prompt_tok + tokens},
})
# ----- OpenAI /v1/completions ------------------------------------------
async def v1_completions(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", False)
prompt_tok = max(1, len(str(payload.get("prompt", "")).split()))
created = int(time.time())
cid = "cmpl-mock"
if stream:
resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
await resp.write(b"data: " + json.dumps({
"id": cid, "object": "text_completion", "created": created, "model": model,
"choices": [{"index": 0, "text": _WORDS[i % len(_WORDS)] + " ", "finish_reason": None}],
}).encode() + b"\n\n")
await resp.write(b"data: [DONE]\n\n")
await resp.write_eof()
return resp
await asyncio.sleep(ttft + tokens * tok_delay)
return web.json_response({
"id": cid, "object": "text_completion", "created": created, "model": model,
"choices": [{"index": 0, "text": _gen_text(tokens), "finish_reason": "stop"}],
"usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
"total_tokens": prompt_tok + tokens},
})
app = web.Application(client_max_size=64 * 1024 * 1024)
app.add_routes([
web.get("/api/version", version),
web.get("/api/tags", tags),
web.get("/api/ps", ps),
web.post("/api/chat", api_chat),
web.post("/api/generate", api_generate),
web.get("/v1/models", v1_models),
web.post("/v1/chat/completions", v1_chat),
web.post("/v1/completions", v1_completions),
])
return app
def serve_mock(args) -> None:
from aiohttp import web
models = [m.strip() for m in args.mock_models.split(",") if m.strip()]
app = build_mock_app(models, args.mock_ttft_ms, args.mock_tokens, args.mock_tok_ms)
print(f"[mock] serving models={models} on http://{args.mock_host}:{args.mock_port} "
f"(ttft={args.mock_ttft_ms}ms tokens={args.mock_tokens} tok={args.mock_tok_ms}ms)",
flush=True)
web.run_app(app, host=args.mock_host, port=args.mock_port, print=None)
# ===========================================================================
# Load driver
# ===========================================================================
@dataclass
class Sample:
ok: bool
status: int
latency: float # full request wall time (s)
ttft: Optional[float] # time-to-first-byte for streaming (s), else None
err: Optional[str] = None
@dataclass
class Stats:
concurrency: int
wall: float = 0.0
samples: list = field(default_factory=list)
@property
def ok(self) -> list:
return [s for s in self.samples if s.ok]
@property
def n_total(self) -> int:
return len(self.samples)
@property
def n_ok(self) -> int:
return len(self.ok)
def _pct(values: list[float], p: float) -> float:
if not values:
return float("nan")
s = sorted(values)
if len(s) == 1:
return s[0]
k = (len(s) - 1) * (p / 100.0)
lo = math.floor(k)
hi = math.ceil(k)
if lo == hi:
return s[int(k)]
return s[lo] + (s[hi] - s[lo]) * (k - lo)
def _build_request(args):
"""Return (path, json_payload) for a single request."""
messages = [{"role": "user", "content": args.prompt}]
if args.api == "openai":
path = "/v1/chat/completions"
body = {"model": args.model, "messages": messages, "stream": args.stream}
else:
path = "/api/chat"
body = {"model": args.model, "messages": messages, "stream": args.stream}
return path, body
async def _one_request(client: httpx.AsyncClient, url: str, body: dict, stream: bool) -> Sample:
t0 = time.perf_counter()
try:
if stream:
ttft = None
async with client.stream("POST", url, json=body) as resp:
status = resp.status_code
async for _chunk in resp.aiter_bytes():
if ttft is None:
ttft = time.perf_counter() - t0
# drain complete
lat = time.perf_counter() - t0
ok = 200 <= status < 300
return Sample(ok=ok, status=status, latency=lat, ttft=ttft,
err=None if ok else f"HTTP {status}")
else:
resp = await client.post(url, json=body)
lat = time.perf_counter() - t0
ok = 200 <= resp.status_code < 300
# touch body so the full response is received
_ = resp.content
return Sample(ok=ok, status=resp.status_code, latency=lat, ttft=None,
err=None if ok else f"HTTP {resp.status_code}")
except Exception as e: # noqa: BLE001 — record any transport error as a failed sample
lat = time.perf_counter() - t0
return Sample(ok=False, status=0, latency=lat, ttft=None,
err=f"{type(e).__name__}: {str(e)[:120]}")
async def run_stage(args, concurrency: int) -> Stats:
path, body = _build_request(args)
url = args.url.rstrip("/") + path
stats = Stats(concurrency=concurrency)
limits = httpx.Limits(max_connections=concurrency + 50,
max_keepalive_connections=concurrency + 50)
timeout = httpx.Timeout(args.timeout, connect=15.0)
async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
# warmup (unmeasured): make a few requests so caches/connections are hot
if args.warmup > 0:
warm_deadline = time.perf_counter() + args.warmup
async def _warm():
while time.perf_counter() < warm_deadline:
await _one_request(client, url, body, args.stream)
await asyncio.gather(*[_warm() for _ in range(min(concurrency, 8))])
use_duration = args.requests is None
deadline = time.perf_counter() + args.duration if use_duration else None
remaining = args.requests if not use_duration else None
remaining_lock = asyncio.Lock()
async def worker():
nonlocal remaining
while True:
if use_duration:
if time.perf_counter() >= deadline:
return
else:
async with remaining_lock:
if remaining <= 0:
return
remaining -= 1
s = await _one_request(client, url, body, args.stream)
stats.samples.append(s)
wall0 = time.perf_counter()
await asyncio.gather(*[worker() for _ in range(concurrency)])
stats.wall = time.perf_counter() - wall0
return stats
def _print_stage(stats: Stats, args, header: bool) -> None:
lat = [s.latency * 1000 for s in stats.ok]
ttfts = [s.ttft * 1000 for s in stats.ok if s.ttft is not None]
rps = stats.n_ok / stats.wall if stats.wall else 0.0
errs = stats.n_total - stats.n_ok
if args.ramp:
if header:
cols = f"{'conc':>5} {'req':>7} {'ok':>7} {'err':>5} {'req/s':>9} " \
f"{'p50ms':>8} {'p90ms':>8} {'p99ms':>9} {'maxms':>9}"
if args.stream:
cols += f" {'ttftP50':>8} {'ttftP99':>8}"
print(cols)
print("-" * len(cols))
row = (f"{stats.concurrency:>5} {stats.n_total:>7} {stats.n_ok:>7} {errs:>5} "
f"{rps:>9.1f} {_pct(lat,50):>8.1f} {_pct(lat,90):>8.1f} "
f"{_pct(lat,99):>9.1f} {(max(lat) if lat else float('nan')):>9.1f}")
if args.stream:
row += f" {_pct(ttfts,50):>8.1f} {_pct(ttfts,99):>8.1f}"
print(row, flush=True)
return
# single-stage detailed report
print(f"\n=== Results (concurrency={stats.concurrency}, "
f"{'stream' if args.stream else 'non-stream'}, api={args.api}) ===")
print(f" wall time : {stats.wall:8.2f} s")
print(f" requests : {stats.n_total} total, {stats.n_ok} ok, {errs} failed")
print(f" throughput : {rps:8.1f} req/s")
if lat:
print(f" latency p50 : {_pct(lat,50):8.1f} ms")
print(f" p90 : {_pct(lat,90):8.1f} ms")
print(f" p95 : {_pct(lat,95):8.1f} ms")
print(f" p99 : {_pct(lat,99):8.1f} ms")
print(f" max : {max(lat):8.1f} ms")
print(f" mean : {statistics.mean(lat):8.1f} ms")
if ttfts:
print(f" TTFT p50 : {_pct(ttfts,50):8.1f} ms")
print(f" p90 : {_pct(ttfts,90):8.1f} ms")
print(f" p99 : {_pct(ttfts,99):8.1f} ms")
if errs:
by_err: dict[str, int] = {}
for s in stats.samples:
if not s.ok:
by_err[s.err or "unknown"] = by_err.get(s.err or "unknown", 0) + 1
print(" errors:")
for k, v in sorted(by_err.items(), key=lambda kv: -kv[1]):
print(f" {v:>6} {k}")
async def run_driver(args) -> list[Stats]:
stages = ([int(x) for x in args.ramp.split(",")] if args.ramp else [args.concurrency])
results: list[Stats] = []
for i, c in enumerate(stages):
stats = await run_stage(args, c)
_print_stage(stats, args, header=(i == 0))
results.append(stats)
return results
# ===========================================================================
# Orchestration: --mock-backend (spawn mock + router, run, tear down)
# ===========================================================================
PROJECT_ROOT = Path(__file__).resolve().parents[2]
async def _wait_http_ok(url: str, timeout: float, accept=(200,)) -> bool:
deadline = time.perf_counter() + timeout
async with httpx.AsyncClient(timeout=5.0) as client:
while time.perf_counter() < deadline:
try:
r = await client.get(url)
if r.status_code in accept:
return True
except Exception:
pass
await asyncio.sleep(0.25)
return False
def _write_temp_config(mock_url: str, models: list[str], max_conc: int) -> Path:
fd, path = tempfile.mkstemp(prefix="nomyo_loadtest_", suffix=".yaml")
os.close(fd)
cfg = (
"# Auto-generated by test/load/loadtest.py --mock-backend. Safe to delete.\n"
"endpoints:\n"
f" - {mock_url}\n"
"llama_server_endpoints: []\n"
f"max_concurrent_connections: {max_conc}\n"
"priority_routing: false\n"
"conversation_affinity: false\n"
"cache_enabled: false\n"
"nomyo-router-api-key: \"\"\n"
"api_keys:\n"
f" \"{mock_url}\": \"mock\"\n"
)
Path(path).write_text(cfg)
return Path(path)
async def run_with_mock_backend(args) -> list[Stats]:
mock_port = args.mock_port or _free_port()
router_port = args.router_port or _free_port()
mock_url = f"http://127.0.0.1:{mock_port}"
router_url = f"http://127.0.0.1:{router_port}"
models = [m.strip() for m in args.mock_models.split(",") if m.strip()]
# Size the router's per-endpoint admission limit so it does not artificially
# serialize the load (unless the user explicitly wants to measure that).
peak = max([int(x) for x in args.ramp.split(",")]) if args.ramp else args.concurrency
max_conc = args.router_max_conc if args.router_max_conc else max(peak, 1)
cfg_path = _write_temp_config(mock_url, models, max_conc)
db_path = Path(tempfile.gettempdir()) / f"nomyo_loadtest_{os.getpid()}.db"
env = dict(os.environ)
env["NOMYO_ROUTER_CONFIG_PATH"] = str(cfg_path)
env["NOMYO_ROUTER_DB_PATH"] = str(db_path)
mock_proc = None
router_proc = None
try:
# 1. mock backend first, so the router never caches it as "down"
mock_cmd = [
sys.executable, str(Path(__file__).resolve()), "--serve-mock",
"--mock-host", "127.0.0.1", "--mock-port", str(mock_port),
"--mock-models", args.mock_models,
"--mock-ttft-ms", str(args.mock_ttft_ms),
"--mock-tokens", str(args.mock_tokens),
"--mock-tok-ms", str(args.mock_tok_ms),
]
print(f"[orchestrator] starting mock backend: {mock_url}", flush=True)
mock_proc = await asyncio.create_subprocess_exec(*mock_cmd)
if not await _wait_http_ok(f"{mock_url}/api/version", timeout=15):
raise RuntimeError("mock backend did not become ready")
# 2. router
router_cmd = [
sys.executable, "-m", "uvicorn", "router:app",
"--host", "127.0.0.1", "--port", str(router_port),
# Per-request access logging is pure noise (and overhead) under load.
"--no-access-log",
]
if args.router_workers and args.router_workers > 1:
router_cmd += ["--workers", str(args.router_workers)]
print(f"[orchestrator] starting router: {router_url} "
f"(workers={args.router_workers}, max_concurrent_connections={max_conc})", flush=True)
router_proc = await asyncio.create_subprocess_exec(
*router_cmd, cwd=str(PROJECT_ROOT), env=env
)
# /health returns 200 only once it can reach the (healthy) mock backend
if not await _wait_http_ok(f"{router_url}/health", timeout=40, accept=(200,)):
raise RuntimeError("router did not become healthy")
print("[orchestrator] router healthy — starting load\n", flush=True)
# 3. drive load against the router
args.url = router_url
return await run_driver(args)
finally:
for name, proc in (("router", router_proc), ("mock", mock_proc)):
if proc and proc.returncode is None:
with contextlib.suppress(ProcessLookupError):
proc.send_signal(signal.SIGINT)
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(proc.wait(), timeout=8)
if proc.returncode is None:
with contextlib.suppress(ProcessLookupError):
proc.kill()
print(f"[orchestrator] stopped {name}", flush=True)
if args.keep_config:
print(f"[orchestrator] kept config: {cfg_path}", flush=True)
else:
with contextlib.suppress(FileNotFoundError):
cfg_path.unlink()
for suffix in ("", "-shm", "-wal"):
with contextlib.suppress(FileNotFoundError):
Path(str(db_path) + suffix).unlink()
# ===========================================================================
# CLI
# ===========================================================================
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description="NOMYO Router load test (asyncio + httpx) with a built-in mock backend.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
mode = p.add_argument_group("mode")
mode.add_argument("--serve-mock", action="store_true",
help="Run ONLY the mock backend (foreground) and exit on Ctrl-C.")
mode.add_argument("--mock-backend", action="store_true",
help="Spawn mock backend + router, load-test them, then tear down.")
tgt = p.add_argument_group("target (driver)")
tgt.add_argument("--url", default="http://127.0.0.1:12434",
help="Router base URL to drive (ignored with --mock-backend).")
tgt.add_argument("--api", choices=["ollama", "openai"], default="ollama",
help="ollama -> POST /api/chat ; openai -> POST /v1/chat/completions")
tgt.add_argument("--model", default="mock", help="Model name to request.")
tgt.add_argument("--prompt", default="Say hello and count to ten.",
help="User prompt sent in every request.")
stream_grp = tgt.add_mutually_exclusive_group()
stream_grp.add_argument("--stream", dest="stream", action="store_true",
help="Stream the response (default).")
stream_grp.add_argument("--no-stream", dest="stream", action="store_false",
help="Request a single non-streamed response.")
p.set_defaults(stream=True)
load = p.add_argument_group("load shape")
load.add_argument("--concurrency", type=int, default=32,
help="Number of concurrent virtual clients.")
load.add_argument("--duration", type=float, default=20.0,
help="Seconds to run each stage (ignored if --requests given).")
load.add_argument("--requests", type=int, default=None,
help="Send exactly N requests instead of running for --duration.")
load.add_argument("--ramp", default=None,
help="Comma-separated concurrency stages, e.g. 16,32,64,128 "
"(prints a table to find the knee).")
load.add_argument("--warmup", type=float, default=2.0,
help="Seconds of unmeasured warmup before each stage.")
load.add_argument("--timeout", type=float, default=120.0,
help="Per-request timeout (seconds).")
load.add_argument("--json", dest="json_out", default=None,
help="Also write the results as JSON to this path.")
mock = p.add_argument_group("mock backend tuning")
mock.add_argument("--mock-host", default="127.0.0.1")
mock.add_argument("--mock-port", type=int, default=0,
help="Mock backend port (0 = auto-pick a free port).")
mock.add_argument("--mock-models", default="mock",
help="Comma-separated model names the mock advertises.")
mock.add_argument("--mock-ttft-ms", type=float, default=0.0,
help="Simulated prefill latency before the first token (ms).")
mock.add_argument("--mock-tokens", type=int, default=64,
help="Completion length in tokens the mock emits.")
mock.add_argument("--mock-tok-ms", type=float, default=0.0,
help="Simulated per-token decode delay (ms) = inverse of tok/s.")
orch = p.add_argument_group("router orchestration (--mock-backend only)")
orch.add_argument("--router-port", type=int, default=0,
help="Router port (0 = auto-pick a free port).")
orch.add_argument("--router-workers", type=int, default=1,
help="uvicorn --workers for the spawned router.")
orch.add_argument("--router-max-conc", type=int, default=0,
help="max_concurrent_connections in the generated config "
"(0 = match peak concurrency so the router does not queue).")
orch.add_argument("--keep-config", action="store_true",
help="Do not delete the generated temp config on exit.")
return p
def _dump_json(path: str, args, results: list[Stats]) -> None:
out = {
"config": {k: getattr(args, k) for k in (
"api", "model", "stream", "duration", "requests", "warmup", "timeout",
"mock_tokens", "mock_ttft_ms", "mock_tok_ms")},
"stages": [],
}
for st in results:
lat = [s.latency * 1000 for s in st.ok]
ttfts = [s.ttft * 1000 for s in st.ok if s.ttft is not None]
out["stages"].append({
"concurrency": st.concurrency,
"wall_s": st.wall,
"requests": st.n_total,
"ok": st.n_ok,
"errors": st.n_total - st.n_ok,
"rps": (st.n_ok / st.wall) if st.wall else 0.0,
"latency_ms": {p: _pct(lat, p) for p in (50, 90, 95, 99)} | (
{"max": max(lat), "mean": statistics.mean(lat)} if lat else {}),
"ttft_ms": {p: _pct(ttfts, p) for p in (50, 90, 99)} if ttfts else {},
})
Path(path).write_text(json.dumps(out, indent=2))
print(f"\n[driver] wrote JSON results to {path}", flush=True)
def main() -> None:
args = build_parser().parse_args()
if args.serve_mock:
try:
serve_mock(args)
except KeyboardInterrupt:
pass
return
if args.requests is not None and args.requests <= 0:
print("--requests must be > 0", file=sys.stderr)
sys.exit(2)
if args.mock_backend:
results = asyncio.run(run_with_mock_backend(args)) or []
else:
results = asyncio.run(run_driver(args))
if args.json_out and results:
_dump_json(args.json_out, args, results)
if __name__ == "__main__":
main()

View file

@ -1,7 +0,0 @@
[pytest]
asyncio_mode = auto
markers =
integration: tests that require a real backend at 192.168.0.51:12434
testpaths = .
filterwarnings =
ignore::pytest.PytestUnhandledThreadExceptionWarning

View file

@ -1,3 +0,0 @@
pytest>=8.0
pytest-asyncio>=0.24
pytest-cov>=5.0

View file

@ -1,60 +0,0 @@
# Testing nomyo-router
## Setup
Install test dependencies (from the project root):
```bash
pip install -r test/requirements_test.txt
```
## Running tests
All commands run from the `test/` directory:
```bash
cd test
```
**All non-integration tests** (no backend required):
```bash
pytest -m "not integration" -v
```
**Integration tests only** (requires backend at `192.168.0.51:12434`):
```bash
pytest -m integration -v
```
**Everything:**
```bash
pytest -v
```
## Test structure
| File | What it covers | Backend needed |
|---|---|---|
| `test_unit_helpers.py` | Pure helper functions (`_mask_secrets`, `_is_fresh`, `ep2base`, etc.) | No |
| `test_unit_transforms.py` | Message transform functions (tool calls, image stripping, etc.) | No |
| `test_unit_context.py` | Context window trimming logic | No |
| `test_fetch.py` | `fetch.available_models` / `fetch.loaded_models` with mocked HTTP | No |
| `test_choose_endpoint.py` | `choose_endpoint` routing logic with mocked fetch layer | No |
| `test_api_validation.py` | HTTP 400/401/403 validation and auth middleware (in-process app) | No |
| `test_api_integration.py` | Full request/response against a real Ollama/llama-server backend | **Yes** |
## Integration test backend
Integration tests start the router in-process via `startup_event()` and route traffic
through `httpx.ASGITransport` — no separately running router instance is needed.
They do require a reachable Ollama or llama-server backend. Override the defaults via
environment variables:
```bash
export NOMYO_TEST_OLLAMA=http://192.168.0.51:12434
export NOMYO_TEST_EMBED_MODEL=nomic-embed-text # optional, auto-discovered otherwise
export NOMYO_TEST_MODEL_CHAT=llama3.2 # optional, auto-discovered otherwise
```
If the backend is unreachable, integration tests are automatically skipped.

View file

@ -1,304 +0,0 @@
"""
Integration tests against the real backend at 192.168.0.50:12434.
Run with:
pytest test/test_api_integration.py -v -m integration
All tests in this file are marked @pytest.mark.integration.
They require the test server to be reachable and to have at least one
chat model and one embedding model available.
Env vars to pin specific models:
NOMYO_TEST_MODEL_CHAT e.g. qwen2.5:1.5b
NOMYO_TEST_EMBED_MODEL e.g. nomic-embed-text:latest
"""
import json
import pytest
pytestmark = pytest.mark.integration
# ── Health / discovery routes ─────────────────────────────────────────────────
class TestDiscoveryRoutes:
async def test_version(self, integration_client):
resp = await integration_client.get("/api/version")
assert resp.status_code == 200
data = resp.json()
assert "version" in data
assert isinstance(data["version"], str)
async def test_tags_returns_models(self, integration_client):
resp = await integration_client.get("/api/tags")
assert resp.status_code == 200
data = resp.json()
assert "models" in data
assert isinstance(data["models"], list)
assert len(data["models"]) > 0
async def test_ps_returns_list(self, integration_client):
resp = await integration_client.get("/api/ps")
assert resp.status_code == 200
data = resp.json()
assert "models" in data
assert isinstance(data["models"], list)
async def test_v1_models_returns_data(self, integration_client):
resp = await integration_client.get("/v1/models")
assert resp.status_code == 200
data = resp.json()
assert "data" in data
assert isinstance(data["data"], list)
async def test_usage_returns_counts(self, integration_client):
resp = await integration_client.get("/api/usage")
assert resp.status_code == 200
data = resp.json()
assert "usage_counts" in data
assert "token_usage_counts" in data
async def test_config_returns_endpoints(self, integration_client):
resp = await integration_client.get("/api/config")
assert resp.status_code == 200
data = resp.json()
assert "endpoints" in data
async def test_hostname(self, integration_client):
resp = await integration_client.get("/api/hostname")
assert resp.status_code == 200
assert "hostname" in resp.json()
async def test_health(self, integration_client):
resp = await integration_client.get("/health")
assert resp.status_code in (200, 503)
data = resp.json()
assert data["status"] in ("ok", "error")
assert "endpoints" in data
async def test_cache_stats(self, integration_client):
resp = await integration_client.get("/api/cache/stats")
assert resp.status_code == 200
data = resp.json()
assert "enabled" in data
# ── /api/chat ─────────────────────────────────────────────────────────────────
class TestApiChat:
async def test_non_streaming(self, integration_client, chat_model):
resp = await integration_client.post(
"/api/chat",
json={
"model": chat_model,
"stream": False,
"messages": [{"role": "user", "content": "Reply with exactly: OK"}],
"options": {"num_predict": 10},
},
)
assert resp.status_code == 200
data = resp.json()
assert "message" in data
assert "content" in data["message"]
async def test_streaming_ndjson(self, integration_client, chat_model):
resp = await integration_client.post(
"/api/chat",
json={
"model": chat_model,
"stream": True,
"messages": [{"role": "user", "content": "Say hi"}],
"options": {"num_predict": 5},
},
)
assert resp.status_code == 200
lines = [l for l in resp.text.strip().split("\n") if l.strip()]
assert len(lines) >= 1
for line in lines:
obj = json.loads(line)
assert "model" in obj
async def test_non_streaming_has_token_counts(self, integration_client, chat_model):
resp = await integration_client.post(
"/api/chat",
json={
"model": chat_model,
"stream": False,
"messages": [{"role": "user", "content": "Count to 3"}],
"options": {"num_predict": 20},
},
)
assert resp.status_code == 200
data = resp.json()
assert data.get("done") is True
# Token counts should be present in the final chunk
assert data.get("prompt_eval_count", 0) >= 0
async def test_system_message_honoured(self, integration_client, chat_model):
resp = await integration_client.post(
"/api/chat",
json={
"model": chat_model,
"stream": False,
"messages": [
{"role": "system", "content": "You are a helpful assistant. Always reply with exactly: PONG"},
{"role": "user", "content": "PING"},
],
"options": {"num_predict": 10},
},
)
assert resp.status_code == 200
content = resp.json()["message"]["content"]
assert isinstance(content, str)
assert len(content) > 0
# ── /api/generate ─────────────────────────────────────────────────────────────
class TestApiGenerate:
async def test_non_streaming(self, integration_client, chat_model):
resp = await integration_client.post(
"/api/generate",
json={
"model": chat_model,
"prompt": "Complete: The sky is",
"stream": False,
"options": {"num_predict": 5},
},
)
assert resp.status_code == 200
data = resp.json()
assert "response" in data
async def test_streaming(self, integration_client, chat_model):
resp = await integration_client.post(
"/api/generate",
json={
"model": chat_model,
"prompt": "One plus one equals",
"stream": True,
"options": {"num_predict": 5},
},
)
assert resp.status_code == 200
lines = [l for l in resp.text.strip().split("\n") if l.strip()]
assert len(lines) >= 1
# ── /api/embed ────────────────────────────────────────────────────────────────
class TestApiEmbed:
async def test_embed_single_string(self, integration_client, embed_model):
resp = await integration_client.post(
"/api/embed",
json={"model": embed_model, "input": "The quick brown fox"},
)
assert resp.status_code == 200
data = resp.json()
assert "embeddings" in data
assert isinstance(data["embeddings"], list)
assert len(data["embeddings"]) == 1
assert len(data["embeddings"][0]) > 0
async def test_embed_multiple_inputs(self, integration_client, embed_model):
resp = await integration_client.post(
"/api/embed",
json={"model": embed_model, "input": ["sentence one", "sentence two"]},
)
assert resp.status_code == 200
data = resp.json()
assert "embeddings" in data
assert len(data["embeddings"]) == 2
# ── /v1/chat/completions ──────────────────────────────────────────────────────
class TestOpenAIChatCompletions:
async def test_non_streaming(self, integration_client, chat_model):
model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
resp = await integration_client.post(
"/v1/chat/completions",
json={
"model": model,
"messages": [{"role": "user", "content": "Reply OK"}],
"max_tokens": 10,
"stream": False,
},
)
assert resp.status_code == 200
data = resp.json()
assert "choices" in data
assert len(data["choices"]) > 0
assert "message" in data["choices"][0]
async def test_streaming_sse(self, integration_client, chat_model):
model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
resp = await integration_client.post(
"/v1/chat/completions",
json={
"model": model,
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 5,
"stream": True,
},
)
assert resp.status_code == 200
# Response should be SSE format
assert "data:" in resp.text or "[DONE]" in resp.text
async def test_non_streaming_has_usage(self, integration_client, chat_model):
model = chat_model.replace(":latest", "") if ":latest" in chat_model else chat_model
resp = await integration_client.post(
"/v1/chat/completions",
json={
"model": model,
"messages": [{"role": "user", "content": "Say yes"}],
"max_tokens": 5,
"stream": False,
},
)
assert resp.status_code == 200
data = resp.json()
if "usage" in data and data["usage"]:
assert data["usage"].get("prompt_tokens", 0) >= 0
# ── /v1/embeddings ────────────────────────────────────────────────────────────
class TestOpenAIEmbeddings:
async def test_single_input(self, integration_client, embed_model):
model = embed_model.replace(":latest", "") if ":latest" in embed_model else embed_model
resp = await integration_client.post(
"/v1/embeddings",
json={"model": model, "input": "Test sentence"},
)
assert resp.status_code == 200
data = resp.json()
assert "data" in data
assert len(data["data"]) > 0
embedding = data["data"][0].get("embedding")
assert isinstance(embedding, list)
assert len(embedding) > 0
# ── Token counts (database-backed) ───────────────────────────────────────────
class TestTokenCounts:
async def test_token_counts_endpoint(self, integration_client):
resp = await integration_client.get("/api/token_counts")
assert resp.status_code == 200
data = resp.json()
assert "total_tokens" in data
assert "breakdown" in data
# ── ps_details (extended ps) ─────────────────────────────────────────────────
class TestPsDetails:
async def test_ps_details_returns_models(self, integration_client):
resp = await integration_client.get("/api/ps_details")
assert resp.status_code == 200
data = resp.json()
assert "models" in data
assert isinstance(data["models"], list)

View file

@ -1,230 +0,0 @@
"""
HTTP-level validation and auth middleware tests.
These tests use an in-process httpx client and never reach a real backend:
all requests are rejected at the validation or auth layer before any
endpoint-selection or upstream HTTP calls occur.
"""
import pytest
class TestChatValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post(
"/api/chat",
json={"messages": [{"role": "user", "content": "hello"}]},
)
assert resp.status_code == 400
assert "model" in resp.json()["detail"].lower()
async def test_missing_messages_returns_400(self, client):
resp = await client.post("/api/chat", json={"model": "llama3.2"})
assert resp.status_code == 400
async def test_invalid_json_returns_400(self, client):
resp = await client.post(
"/api/chat",
content=b"not-json",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 400
async def test_messages_not_list_returns_400(self, client):
resp = await client.post(
"/api/chat",
json={"model": "m", "messages": "not-a-list"},
)
assert resp.status_code == 400
async def test_options_not_dict_returns_400(self, client):
resp = await client.post(
"/api/chat",
json={"model": "m", "messages": [{"role": "user", "content": "hi"}], "options": "bad"},
)
assert resp.status_code == 400
class TestGenerateValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post("/api/generate", json={"prompt": "hello"})
assert resp.status_code == 400
assert "model" in resp.json()["detail"].lower()
async def test_missing_prompt_returns_400(self, client):
resp = await client.post("/api/generate", json={"model": "m"})
assert resp.status_code == 400
assert "prompt" in resp.json()["detail"].lower()
async def test_invalid_json_returns_400(self, client):
resp = await client.post(
"/api/generate",
content=b"{bad-json",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 400
class TestEmbedValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post("/api/embed", json={"input": "hello"})
assert resp.status_code == 400
async def test_missing_input_returns_400(self, client):
resp = await client.post("/api/embed", json={"model": "nomic-embed-text"})
assert resp.status_code == 400
class TestEmbeddingsValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post("/api/embeddings", json={"prompt": "hello"})
assert resp.status_code == 400
async def test_missing_prompt_returns_400(self, client):
resp = await client.post("/api/embeddings", json={"model": "nomic-embed-text"})
assert resp.status_code == 400
class TestOpenAIChatValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hello"}]},
)
assert resp.status_code == 400
async def test_missing_messages_returns_400(self, client):
resp = await client.post(
"/v1/chat/completions",
json={"model": "gpt-4o"},
)
assert resp.status_code == 400
async def test_invalid_json_returns_400(self, client):
resp = await client.post(
"/v1/chat/completions",
content=b"}{",
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 400
async def test_svg_image_rejected(self, client):
resp = await client.post(
"/v1/chat/completions",
json={
"model": "vision-model",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": "describe"},
{"type": "image_url", "image_url": {"url": "data:image/svg+xml;base64,abc"}},
],
}],
},
)
assert resp.status_code == 400
assert "svg" in resp.json()["detail"].lower()
class TestOpenAICompletionsValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post("/v1/completions", json={"prompt": "hello"})
assert resp.status_code == 400
async def test_missing_prompt_returns_400(self, client):
resp = await client.post("/v1/completions", json={"model": "m"})
assert resp.status_code == 400
class TestRerankValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post(
"/v1/rerank",
json={"query": "search query", "documents": ["doc1"]},
)
assert resp.status_code == 400
async def test_missing_query_returns_400(self, client):
resp = await client.post(
"/v1/rerank",
json={"model": "reranker", "documents": ["doc1"]},
)
assert resp.status_code == 400
async def test_empty_documents_returns_400(self, client):
resp = await client.post(
"/v1/rerank",
json={"model": "reranker", "query": "search", "documents": []},
)
assert resp.status_code == 400
class TestShowValidation:
async def test_missing_model_returns_400(self, client):
resp = await client.post("/api/show", json={})
assert resp.status_code == 400
class TestCopyValidation:
async def test_missing_source_returns_400(self, client):
resp = await client.post("/api/copy", json={"destination": "dst"})
assert resp.status_code == 400
async def test_missing_destination_returns_400(self, client):
resp = await client.post("/api/copy", json={"source": "src"})
assert resp.status_code == 400
class TestDeleteValidation:
async def test_missing_model_returns_400(self, client):
import json as _json
resp = await client.request(
"DELETE",
"/api/delete",
content=_json.dumps({}).encode(),
headers={"Content-Type": "application/json"},
)
assert resp.status_code == 400
class TestAuthMiddleware:
async def test_no_key_returns_401(self, client_auth):
resp = await client_auth.post(
"/api/chat",
json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
)
assert resp.status_code == 401
assert "Missing" in resp.json()["detail"]
async def test_invalid_key_returns_403(self, client_auth):
resp = await client_auth.post(
"/api/chat",
headers={"Authorization": "Bearer wrong-key"},
json={"model": "m", "messages": [{"role": "user", "content": "hi"}]},
)
assert resp.status_code == 403
assert "Invalid" in resp.json()["detail"]
async def test_valid_key_passes_middleware(self, client_auth):
# /api/usage reads in-memory counters only — no backend call needed
resp = await client_auth.get(
"/api/usage",
headers={"Authorization": "Bearer test-secret-key"},
)
assert resp.status_code == 200
async def test_key_via_query_param(self, client_auth):
resp = await client_auth.get("/api/usage?api_key=test-secret-key")
assert resp.status_code == 200
async def test_options_bypasses_auth(self, client_auth):
resp = await client_auth.options("/api/chat")
assert resp.status_code not in (401, 403)
async def test_root_path_bypasses_auth(self, client_auth):
resp = await client_auth.get("/")
assert resp.status_code not in (401, 403)
async def test_favicon_bypasses_auth(self, client_auth):
resp = await client_auth.get("/favicon.ico")
# Should not be blocked by auth (may 404 in test but not 401/403)
assert resp.status_code not in (401, 403)

View file

@ -1,333 +0,0 @@
"""Unit tests for cache.LLMCache in exact-match mode (no sentence-transformers needed)."""
import tempfile
from pathlib import Path
from types import SimpleNamespace
import orjson
import pytest
import cache as cache_mod
from cache import (
LLMCache,
_bm25_weighted_text,
get_llm_cache,
init_llm_cache,
openai_nonstream_to_sse,
)
_CACHE_DB_PATH = str(Path(tempfile.gettempdir()) / "nomyo_test_cache.db")
def _exact_cfg(backend: str = "memory") -> SimpleNamespace:
"""Config for exact-match mode — similarity=1.0 avoids embedding deps."""
return SimpleNamespace(
cache_enabled=True,
cache_backend=backend,
cache_similarity=1.0,
cache_history_weight=0.3,
cache_ttl=300,
cache_db_path=_CACHE_DB_PATH,
cache_redis_url="redis://localhost:6379",
)
# ──────────────────────────────────────────────────────────────────────────────
# Pure helpers
# ──────────────────────────────────────────────────────────────────────────────
class TestBM25WeightedText:
def test_empty_history(self):
assert _bm25_weighted_text([]) == ""
def test_history_without_content(self):
assert _bm25_weighted_text([{"role": "user"}, {"role": "assistant"}]) == ""
def test_repeats_high_idf_terms(self):
history = [
{"role": "user", "content": "Tell me about quantum entanglement"},
{"role": "assistant", "content": "Quantum entanglement is a phenomenon"},
{"role": "user", "content": "How does entanglement work?"},
]
out = _bm25_weighted_text(history)
# Rare/domain term ("entanglement") should appear; short stopwords (<=2 chars) dropped
assert "entanglement" in out
assert "is" not in out.split()
# ──────────────────────────────────────────────────────────────────────────────
# openai_nonstream_to_sse
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAINonstreamToSSE:
def test_valid_chat_completion(self):
chat = {
"id": "x1",
"created": 123,
"model": "gpt-4o",
"choices": [{"message": {"role": "assistant", "content": "hello"}}],
"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}
out = openai_nonstream_to_sse(orjson.dumps(chat), "gpt-4o")
text = out.decode()
assert text.startswith("data: ")
assert text.endswith("data: [DONE]\n\n")
# First chunk contains the original content
first = text.split("\n\n")[0][len("data: "):]
parsed = orjson.loads(first)
assert parsed["choices"][0]["delta"]["content"] == "hello"
assert parsed["usage"]["total_tokens"] == 3
def test_corrupt_bytes_return_done_only(self):
out = openai_nonstream_to_sse(b"not-json", "m")
assert out == b"data: [DONE]\n\n"
# ──────────────────────────────────────────────────────────────────────────────
# LLMCache internal helpers
# ──────────────────────────────────────────────────────────────────────────────
class TestLLMCacheParsing:
def test_namespace_is_stable_and_isolated(self):
c = LLMCache(_exact_cfg())
a = c._namespace("chat", "m1", "system A")
b = c._namespace("chat", "m1", "system A")
assert a == b
assert c._namespace("chat", "m1", "system B") != a
assert c._namespace("generate", "m1", "system A") != a
assert len(a) == 16
def test_parse_messages_flat_strings(self):
c = LLMCache(_exact_cfg())
sys, hist, last = c._parse_messages([
{"role": "system", "content": "be helpful"},
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
{"role": "user", "content": "what is 2+2?"},
])
assert sys == "be helpful"
assert last == "what is 2+2?"
assert hist == [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
def test_parse_messages_multimodal_content(self):
c = LLMCache(_exact_cfg())
sys, _hist, last = c._parse_messages([
{"role": "system", "content": "sys"},
{"role": "user", "content": [
{"type": "text", "text": "describe"},
{"type": "image_url", "image_url": {"url": "data:..."}},
]},
])
assert sys == "sys"
assert last == "describe"
def test_parse_messages_no_user_message(self):
c = LLMCache(_exact_cfg())
sys, hist, last = c._parse_messages([
{"role": "system", "content": "sys only"},
])
assert sys == "sys only"
assert last == ""
assert hist == []
class TestPersonalTokenExtraction:
def test_email_extracted(self):
c = LLMCache(_exact_cfg())
toks = c._extract_personal_tokens("Reach me at alice@example.com please")
assert "alice@example.com" in toks
def test_numeric_id_after_keyword(self):
c = LLMCache(_exact_cfg())
toks = c._extract_personal_tokens("User id: 123456")
assert "123456" in toks
def test_identity_tag_names_extracted(self):
c = LLMCache(_exact_cfg())
toks = c._extract_personal_tokens(
"[Tags: identity] User's name is Andreas Schwibbe"
)
# Both name tokens should be extracted lowercased; stopwords dropped
assert "andreas" in toks
assert "schwibbe" in toks
assert "name" not in toks # in _IDENTITY_STOPWORDS
assert "user" not in toks
def test_empty_system_returns_empty_set(self):
c = LLMCache(_exact_cfg())
assert c._extract_personal_tokens("") == frozenset()
class TestResponseIsPersonalized:
def _resp(self, content: str) -> bytes:
return orjson.dumps({"choices": [{"message": {"content": content}}]})
def test_email_in_response_is_personalized(self):
c = LLMCache(_exact_cfg())
assert c._response_is_personalized(self._resp("contact bob@x.com"), "")
def test_uuid_in_response_is_personalized(self):
c = LLMCache(_exact_cfg())
uuid = "550e8400-e29b-41d4-a716-446655440000"
assert c._response_is_personalized(self._resp(f"id={uuid}"), "")
def test_long_numeric_id_in_response_is_personalized(self):
c = LLMCache(_exact_cfg())
assert c._response_is_personalized(self._resp("account 12345678"), "")
def test_identity_token_from_system_echoed_in_response(self):
c = LLMCache(_exact_cfg())
system = "[Tags: identity] Andreas works here"
assert c._response_is_personalized(
self._resp("Yes, Andreas is logged in"), system
)
def test_generic_response_not_personalized(self):
c = LLMCache(_exact_cfg())
assert not c._response_is_personalized(
self._resp("The capital of France is Paris."), "be helpful"
)
def test_ollama_message_format_parsed(self):
c = LLMCache(_exact_cfg())
body = orjson.dumps({"message": {"content": "alice@example.com"}})
assert c._response_is_personalized(body, "")
def test_unparseable_body_with_bytes_is_conservative(self):
c = LLMCache(_exact_cfg())
# Can't parse → returns True (err on the side of privacy)
assert c._response_is_personalized(b"binary-junk", "")
def test_empty_response_not_personalized(self):
c = LLMCache(_exact_cfg())
assert not c._response_is_personalized(b"", "anything")
# ──────────────────────────────────────────────────────────────────────────────
# End-to-end exact-match cache with the memory backend
# ──────────────────────────────────────────────────────────────────────────────
@pytest.fixture
async def memcache():
"""LLMCache wired up with the in-memory backend (no external deps)."""
c = LLMCache(_exact_cfg("memory"))
await c.init()
return c
class TestExactMatchCache:
async def test_miss_then_set_then_hit(self, memcache):
msgs = [
{"role": "system", "content": "be helpful"},
{"role": "user", "content": "what is 2+2?"},
]
resp = orjson.dumps({"choices": [{"message": {"content": "4"}}]})
assert await memcache.get_chat("chat", "m1", msgs) is None
await memcache.set_chat("chat", "m1", msgs, resp)
hit = await memcache.get_chat("chat", "m1", msgs)
assert hit == resp
async def test_namespace_isolation_by_system(self, memcache):
resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
msgs_a = [
{"role": "system", "content": "system A"},
{"role": "user", "content": "same question"},
]
msgs_b = [
{"role": "system", "content": "system B"},
{"role": "user", "content": "same question"},
]
await memcache.set_chat("chat", "m", msgs_a, resp)
# Same question + different system prompt = different namespace = miss
assert await memcache.get_chat("chat", "m", msgs_b) is None
async def test_namespace_isolation_by_route(self, memcache):
resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
msgs = [{"role": "user", "content": "ping"}]
await memcache.set_chat("chat", "m", msgs, resp)
assert await memcache.get_chat("openai_chat", "m", msgs) is None
async def test_no_user_message_is_noop(self, memcache):
msgs = [{"role": "system", "content": "sys only"}]
resp = orjson.dumps({"choices": [{"message": {"content": "x"}}]})
# Both get and set should silently no-op
assert await memcache.get_chat("chat", "m", msgs) is None
await memcache.set_chat("chat", "m", msgs, resp)
assert await memcache.get_chat("chat", "m", msgs) is None
async def test_personalized_response_generic_system_not_stored(self, memcache):
msgs = [
{"role": "system", "content": "be helpful"}, # generic
{"role": "user", "content": "give me an email"},
]
# Response contains an email → would leak across users sharing the
# generic namespace → must NOT be stored at all
resp = orjson.dumps({"choices": [{"message": {"content": "bob@x.com"}}]})
await memcache.set_chat("chat", "m", msgs, resp)
assert await memcache.get_chat("chat", "m", msgs) is None
async def test_personalized_response_user_specific_system_stored(self, memcache):
msgs = [
{"role": "system", "content": "User id: 998877 prefers concise answers"},
{"role": "user", "content": "what is my id?"},
]
resp = orjson.dumps({"choices": [{"message": {"content": "Your id is 998877"}}]})
await memcache.set_chat("chat", "m", msgs, resp)
# User-specific namespace → exact-match within this user is OK
assert await memcache.get_chat("chat", "m", msgs) == resp
async def test_generate_convenience_wrappers(self, memcache):
resp = orjson.dumps({"response": "blue"})
await memcache.set_generate("m", "what color is the sky?", "", resp)
assert await memcache.get_generate("m", "what color is the sky?") == resp
class TestStatsAndClear:
async def test_stats_tracks_hits_and_misses(self, memcache):
msgs = [{"role": "user", "content": "hello"}]
await memcache.get_chat("chat", "m", msgs) # miss
resp = orjson.dumps({"choices": [{"message": {"content": "hi"}}]})
await memcache.set_chat("chat", "m", msgs, resp)
await memcache.get_chat("chat", "m", msgs) # hit
s = memcache.stats()
assert s["hits"] == 1
assert s["misses"] == 1
assert s["hit_rate"] == 0.5
assert s["semantic"] is False
assert s["backend"] == "memory"
async def test_clear_resets_counters_and_storage(self, memcache):
msgs = [{"role": "user", "content": "hi"}]
resp = orjson.dumps({"choices": [{"message": {"content": "ok"}}]})
await memcache.set_chat("chat", "m", msgs, resp)
await memcache.get_chat("chat", "m", msgs)
await memcache.clear()
s = memcache.stats()
assert s["hits"] == 0
assert s["misses"] == 0
assert await memcache.get_chat("chat", "m", msgs) is None
# ──────────────────────────────────────────────────────────────────────────────
# Module-level helpers
# ──────────────────────────────────────────────────────────────────────────────
class TestInitLLMCache:
async def test_disabled_returns_none(self):
cfg = _exact_cfg()
cfg.cache_enabled = False
result = await init_llm_cache(cfg)
assert result is None
async def test_enabled_returns_initialized_cache(self):
cfg = _exact_cfg()
try:
result = await init_llm_cache(cfg)
assert result is not None
assert get_llm_cache() is result
finally:
# Reset singleton between tests
cache_mod._cache = None

View file

@ -1,479 +0,0 @@
"""Tests for choose_endpoint routing logic with mocked fetch calls."""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import router
EP1 = "http://ep1:11434"
EP2 = "http://ep2:11434"
EP3 = "http://ep3:11434"
LLAMA_EP = "http://llama:8080/v1"
def _make_cfg(endpoints, llama_eps=None, swap_eps=None, max_conn=2, endpoint_config=None, priority_routing=False):
cfg = MagicMock()
cfg.endpoints = endpoints
cfg.llama_server_endpoints = llama_eps or []
cfg.llama_swap_endpoints = swap_eps or []
cfg.api_keys = {}
cfg.max_concurrent_connections = max_conn
cfg.endpoint_config = endpoint_config or {}
cfg.priority_routing = priority_routing
cfg.router_api_key = None
return cfg
@pytest.fixture(autouse=True)
def reset_usage():
"""Clear usage_counts and error caches between tests to prevent bleed."""
router.usage_counts.clear()
router._loaded_error_cache.clear()
yield
router.usage_counts.clear()
router._loaded_error_cache.clear()
class TestChooseEndpointBasic:
async def test_selects_single_candidate(self):
cfg = _make_cfg([EP1])
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(return_value={"llama3.2:latest"})),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
):
ep, tracking = await router.choose_endpoint("llama3.2:latest")
assert ep == EP1
assert tracking == "llama3.2:latest"
async def test_llama_swap_endpoint_is_a_candidate(self):
swap_ep = "http://swap:8080/v1"
cfg = _make_cfg([EP1], swap_eps=[swap_ep])
async def available(ep, *_):
# Only the llama-swap backend advertises this model
return {"org/model:Q4_K_M"} if ep == swap_ep else set()
async def loaded(ep):
return {"org/model:Q4_K_M"} if ep == swap_ep else set()
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", side_effect=loaded),
):
ep, tracking = await router.choose_endpoint("org/model:Q4_K_M")
assert ep == swap_ep
# llama-swap models are tracked under their normalized name
assert tracking == "model"
async def test_raises_when_no_endpoint_has_model(self):
cfg = _make_cfg([EP1, EP2])
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(return_value=set())),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
):
with pytest.raises(RuntimeError, match="advertise the model"):
await router.choose_endpoint("unknown-model:latest")
async def test_prefers_loaded_endpoint(self):
cfg = _make_cfg([EP1, EP2])
async def available(ep, *_):
return {"llama3.2:latest"}
async def loaded(ep):
return {"llama3.2:latest"} if ep == EP2 else set()
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", side_effect=loaded),
):
ep, _ = await router.choose_endpoint("llama3.2:latest")
assert ep == EP2
async def test_falls_back_to_free_slot(self):
cfg = _make_cfg([EP1, EP2])
async def available(ep, *_):
return {"llama3.2:latest"}
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
):
ep, _ = await router.choose_endpoint("llama3.2:latest")
assert ep in (EP1, EP2)
async def test_cold_model_avoids_backend_busy_with_other_model(self):
# Regression: heterogeneous cluster. A cold model B (loaded nowhere)
# must not be routed to a backend already serving a *different* model
# while other backends sit idle. The step-4 idle check used to look at
# per-model usage (zero everywhere for B) and discard the total-load
# ranking, so B could land on the busy backend at random.
cfg = _make_cfg([EP1, EP2, EP3], max_conn=4)
async def available(ep, *_):
return {"model-a:latest", "model-b:latest"}
# EP3 is busy with model A; EP1 and EP2 are completely idle. Model B
# is loaded nowhere.
router.usage_counts[EP3]["model-a:latest"] = 1
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
):
# Run repeatedly: the busy backend must be excluded every time,
# the idle two share the load at random.
for _ in range(50):
ep, _ = await router.choose_endpoint("model-b:latest", reserve=False)
assert ep in (EP1, EP2)
assert ep != EP3
async def test_two_cold_models_spread_across_backends(self):
# Regression: 3 backends all advertise all models. Two *different*
# cold models requested back-to-back must land on *different*
# backends. Once model-a is resident on the chosen backend (infinite
# keep-alive), its in-flight count drops back to 0 — so only the
# resident-model count distinguishes the backends. Without it, the
# second cold model would randomly re-collide on the busy backend.
cfg = _make_cfg([EP1, EP2, EP3], max_conn=4)
async def available(ep, *_):
return {"model-a:latest", "model-b:latest"}
# model-a finished loading on EP1 and stays resident; its request has
# completed so EP1 has zero in-flight load, same as EP2/EP3.
loaded = {EP1: {"model-a:latest"}, EP2: set(), EP3: set()}
async def loaded_models(ep):
return loaded[ep]
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", side_effect=loaded_models),
):
# A cold model-b must avoid EP1 (which already holds model-a) and
# go to one of the empty backends, every time.
for _ in range(50):
ep, _ = await router.choose_endpoint("model-b:latest", reserve=False)
assert ep in (EP2, EP3)
assert ep != EP1
async def test_saturated_picks_least_busy(self):
cfg = _make_cfg([EP1, EP2])
cfg.max_concurrent_connections = 1
async def available(ep, *_):
return {"llama3.2:latest"}
# Saturate EP1 with 2 active connections, EP2 with 1
router.usage_counts[EP1]["llama3.2:latest"] = 2
router.usage_counts[EP2]["llama3.2:latest"] = 1
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
):
ep, _ = await router.choose_endpoint("llama3.2:latest")
# Least-busy is EP2
assert ep == EP2
async def test_excludes_endpoint_with_recent_loaded_error(self):
# Regression: issue #83 — when /api/ps fails for EP1 but EP1
# still advertises the model via /api/tags, routing must not
# fall back to EP1 just because it has a free slot.
cfg = _make_cfg([EP1, EP2])
async def available(ep, *_):
return {"llama3.2:latest"}
# EP1's /api/ps probe failed recently; EP2 is fine but the model
# is not loaded there. Without the health filter, EP1 would be
# picked by the free-slot fallback (step 4 in choose_endpoint).
router._loaded_error_cache[EP1] = time.time()
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
):
ep, _ = await router.choose_endpoint("llama3.2:latest")
assert ep == EP2
async def test_stale_loaded_error_does_not_exclude(self):
# Errors older than the 300s window must not keep an endpoint
# excluded forever.
cfg = _make_cfg([EP1])
router._loaded_error_cache[EP1] = time.time() - 301
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"m:latest"})),
):
ep, _ = await router.choose_endpoint("m:latest")
assert ep == EP1
async def test_all_unhealthy_still_routes(self):
# If every candidate has a fresh loaded-error we still try one
# (it may have recovered between the cache write and now) rather
# than refusing to route.
cfg = _make_cfg([EP1])
router._loaded_error_cache[EP1] = time.time()
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(return_value={"m:latest"})),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value=set())),
):
ep, _ = await router.choose_endpoint("m:latest")
assert ep == EP1
async def test_reserve_increments_usage(self):
cfg = _make_cfg([EP1])
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(return_value={"model:latest"})),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"model:latest"})),
):
ep, tracking = await router.choose_endpoint("model:latest", reserve=True)
assert router.usage_counts[ep][tracking] == 1
class TestChooseEndpointModelNaming:
async def test_strips_latest_for_openai_endpoints(self):
cfg = _make_cfg(endpoints=[], llama_eps=[LLAMA_EP])
cfg.endpoints = []
async def available(ep, *_):
# llama-server advertises without :latest
return {"gpt-4o"}
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"gpt-4o"})),
):
ep, _ = await router.choose_endpoint("gpt-4o:latest")
assert ep == LLAMA_EP
async def test_adds_latest_for_ollama_when_bare_name(self):
cfg = _make_cfg([EP1])
async def available(ep, *_):
return {"llama3.2:latest"}
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", AsyncMock(return_value={"llama3.2:latest"})),
):
ep, _ = await router.choose_endpoint("llama3.2")
assert ep == EP1
class TestChooseEndpointLoadBalancing:
async def test_random_selection_among_idle(self):
cfg = _make_cfg([EP1, EP2, EP3])
selected = set()
async def available(ep, *_):
return {"model:latest"}
async def loaded(ep):
return {"model:latest"}
for _ in range(20):
router.usage_counts.clear()
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", side_effect=loaded),
):
ep, _ = await router.choose_endpoint("model:latest", reserve=False)
selected.add(ep)
# With 20 draws from 3 idle endpoints, all three should appear
assert len(selected) > 1
async def test_sort_by_load_ascending(self):
cfg = _make_cfg([EP1, EP2])
router.usage_counts[EP1]["model:latest"] = 1
router.usage_counts[EP2]["model:latest"] = 0
async def available(ep, *_):
return {"model:latest"}
async def loaded(ep):
return {"model:latest"}
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", side_effect=available),
patch.object(router.fetch, "loaded_models", side_effect=loaded),
):
ep, _ = await router.choose_endpoint("model:latest", reserve=False)
# EP2 has fewer active connections → should be selected
assert ep == EP2
# ---------------------------------------------------------------------------
# get_max_connections unit tests
# ---------------------------------------------------------------------------
class TestGetMaxConnections:
def test_returns_global_default_when_no_override(self):
cfg = _make_cfg([EP1, EP2], max_conn=3)
with patch.object(router, "config", cfg):
assert router.get_max_connections(EP1) == 3
assert router.get_max_connections(EP2) == 3
def test_returns_per_endpoint_override(self):
cfg = _make_cfg(
[EP1, EP2],
max_conn=2,
endpoint_config={EP1: {"max_concurrent_connections": 5}},
)
with patch.object(router, "config", cfg):
assert router.get_max_connections(EP1) == 5
assert router.get_max_connections(EP2) == 2 # falls back to global
def test_unrecognised_endpoint_falls_back_to_global(self):
cfg = _make_cfg([EP1], max_conn=4, endpoint_config={EP2: {"max_concurrent_connections": 1}})
with patch.object(router, "config", cfg):
assert router.get_max_connections(EP3) == 4
# ---------------------------------------------------------------------------
# Priority / WRR routing tests
# ---------------------------------------------------------------------------
MODEL = "model:latest"
def _all_loaded(ep):
"""Side-effect: every endpoint advertises and has MODEL loaded."""
return {MODEL}
class TestPriorityRouting:
"""Tests for priority_routing=True (WRR + config-order tiebreaking)."""
async def test_idle_picks_first_in_config_order(self):
"""When all endpoints are idle, priority picks the first listed endpoint."""
cfg = _make_cfg([EP1, EP2, EP3], priority_routing=True)
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
):
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
assert ep == EP1
async def test_lower_utilization_preferred_over_priority(self):
"""An endpoint with lower ratio is preferred even if it has lower priority."""
cfg = _make_cfg([EP1, EP2], priority_routing=True)
# EP1 (priority 0) is busier: 1/2 = 0.5; EP2 (priority 1) is idle: 0/2 = 0.0
router.usage_counts[EP1][MODEL] = 1
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
):
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
assert ep == EP2
async def test_wrr_distribution_matches_expected_sequence(self):
"""
Full WRR sequence with heterogeneous capacities, mirroring the issue example:
EP1 max=2, EP2 max=2, EP3 max=1
Expected routing order for 5 sequential requests:
EP1, EP2, EP3, EP1, EP2
"""
cfg = _make_cfg(
[EP1, EP2, EP3],
max_conn=2,
endpoint_config={EP3: {"max_concurrent_connections": 1}},
priority_routing=True,
)
expected = [EP1, EP2, EP3, EP1, EP2]
actual = []
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
):
for _ in expected:
ep, _ = await router.choose_endpoint(MODEL, reserve=True)
actual.append(ep)
assert actual == expected
async def test_saturated_picks_lowest_ratio_then_priority(self):
"""When all endpoints are saturated, pick lowest utilization ratio; break ties by priority."""
cfg = _make_cfg(
[EP1, EP2, EP3],
max_conn=1,
endpoint_config={EP3: {"max_concurrent_connections": 2}},
priority_routing=True,
)
# EP1 usage=1/1=1.0, EP2 usage=1/1=1.0, EP3 usage=1/2=0.5 → EP3 wins
router.usage_counts[EP1][MODEL] = 1
router.usage_counts[EP2][MODEL] = 1
router.usage_counts[EP3][MODEL] = 1
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
):
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
assert ep == EP3
async def test_saturated_ties_broken_by_priority(self):
"""When all are saturated with equal ratio, config order wins."""
cfg = _make_cfg([EP1, EP2, EP3], max_conn=1, priority_routing=True)
router.usage_counts[EP1][MODEL] = 1
router.usage_counts[EP2][MODEL] = 1
router.usage_counts[EP3][MODEL] = 1
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
):
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
assert ep == EP1
class TestPriorityRoutingDisabled:
"""Verify that priority_routing=False keeps the original random behaviour."""
async def test_idle_endpoints_are_randomised(self):
"""Without priority routing, all-idle selection must eventually pick each endpoint."""
cfg = _make_cfg([EP1, EP2, EP3], priority_routing=False)
selected = set()
with (
patch.object(router, "config", cfg),
patch.object(router.fetch, "available_models", AsyncMock(side_effect=_all_loaded)),
patch.object(router.fetch, "loaded_models", AsyncMock(side_effect=_all_loaded)),
):
for _ in range(30):
router.usage_counts.clear()
ep, _ = await router.choose_endpoint(MODEL, reserve=False)
selected.add(ep)
# With 30 draws from 3 equally-idle endpoints, all three must appear
assert selected == {EP1, EP2, EP3}

View file

@ -1,197 +0,0 @@
"""Direct unit tests for db.TokenDatabase — no router/app dependency."""
from datetime import datetime, timezone
import pytest
from db import TokenDatabase
@pytest.fixture
async def db(tmp_path):
inst = TokenDatabase(str(tmp_path / "tokens.db"))
await inst.init_db()
yield inst
await inst.close()
class TestInit:
async def test_init_creates_tables(self, db):
# Re-init must be idempotent
await db.init_db()
# Insert + read confirms tables exist
await db.update_token_counts("http://ep", "m", 1, 2)
rows = [r async for r in db.load_token_counts()]
assert len(rows) == 1
async def test_creates_parent_directory(self, tmp_path):
nested = tmp_path / "nested" / "subdir" / "x.db"
inst = TokenDatabase(str(nested))
await inst.init_db()
try:
assert nested.parent.exists()
finally:
await inst.close()
class TestUpdateTokenCounts:
async def test_insert_then_update_aggregates(self, db):
await db.update_token_counts("http://ep", "m1", 10, 20)
await db.update_token_counts("http://ep", "m1", 5, 7)
rows = [r async for r in db.load_token_counts()]
assert len(rows) == 1
r = rows[0]
assert r["endpoint"] == "http://ep"
assert r["model"] == "m1"
assert r["input_tokens"] == 15
assert r["output_tokens"] == 27
assert r["total_tokens"] == 42
async def test_independent_endpoint_model_pairs(self, db):
await db.update_token_counts("http://ep1", "m1", 1, 1)
await db.update_token_counts("http://ep1", "m2", 2, 2)
await db.update_token_counts("http://ep2", "m1", 3, 3)
rows = [r async for r in db.load_token_counts()]
assert len(rows) == 3
totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
assert totals == {
("http://ep1", "m1"): 2,
("http://ep1", "m2"): 4,
("http://ep2", "m1"): 6,
}
class TestBatchedCounts:
async def test_update_batched_counts(self, db):
counts = {
"http://a": {"m": (4, 6)},
"http://b": {"m": (1, 1), "n": (10, 0)},
}
await db.update_batched_counts(counts)
rows = [r async for r in db.load_token_counts()]
totals = {(r["endpoint"], r["model"]): r["total_tokens"] for r in rows}
assert totals == {
("http://a", "m"): 10,
("http://b", "m"): 2,
("http://b", "n"): 10,
}
async def test_empty_batch_is_noop(self, db):
await db.update_batched_counts({})
rows = [r async for r in db.load_token_counts()]
assert rows == []
class TestTimeSeries:
async def test_add_time_series_entry(self, db):
# The aggregate FK requires the (endpoint,model) row to exist first
await db.update_token_counts("http://ep", "m", 0, 0)
await db.add_time_series_entry("http://ep", "m", 3, 4)
await db.add_time_series_entry("http://ep", "m", 1, 1)
rows = [r async for r in db.get_latest_time_series(limit=10)]
assert len(rows) == 2
# Newest-first ordering; both timestamps are within the same minute,
# so just check totals are present and well-formed
for r in rows:
assert r["endpoint"] == "http://ep"
assert r["model"] == "m"
assert r["total_tokens"] == r["input_tokens"] + r["output_tokens"]
async def test_add_batched_time_series(self, db):
await db.update_token_counts("http://ep", "m", 0, 0)
now = int(datetime.now(tz=timezone.utc).timestamp())
entries = [
{"endpoint": "http://ep", "model": "m", "input_tokens": 1,
"output_tokens": 2, "total_tokens": 3, "timestamp": now - 60},
{"endpoint": "http://ep", "model": "m", "input_tokens": 4,
"output_tokens": 5, "total_tokens": 9, "timestamp": now},
]
await db.add_batched_time_series(entries)
rows = [r async for r in db.get_latest_time_series(limit=10)]
assert len(rows) == 2
assert rows[0]["timestamp"] >= rows[1]["timestamp"]
async def test_get_time_series_for_model_filters(self, db):
await db.update_token_counts("http://ep", "m1", 0, 0)
await db.update_token_counts("http://ep", "m2", 0, 0)
now = int(datetime.now(tz=timezone.utc).timestamp())
await db.add_batched_time_series([
{"endpoint": "http://ep", "model": "m1", "input_tokens": 1,
"output_tokens": 1, "total_tokens": 2, "timestamp": now},
{"endpoint": "http://ep", "model": "m2", "input_tokens": 9,
"output_tokens": 9, "total_tokens": 18, "timestamp": now},
])
rows = [r async for r in db.get_time_series_for_model("m1")]
assert len(rows) == 1
assert rows[0]["total_tokens"] == 2
async def test_endpoint_distribution_for_model(self, db):
await db.update_token_counts("http://a", "m", 0, 0)
await db.update_token_counts("http://b", "m", 0, 0)
now = int(datetime.now(tz=timezone.utc).timestamp())
await db.add_batched_time_series([
{"endpoint": "http://a", "model": "m", "input_tokens": 1,
"output_tokens": 1, "total_tokens": 2, "timestamp": now},
{"endpoint": "http://a", "model": "m", "input_tokens": 1,
"output_tokens": 1, "total_tokens": 2, "timestamp": now},
{"endpoint": "http://b", "model": "m", "input_tokens": 5,
"output_tokens": 5, "total_tokens": 10, "timestamp": now},
])
dist = await db.get_endpoint_distribution_for_model("m")
assert dist == {"http://a": 4, "http://b": 10}
class TestGetTokenCountsForModel:
async def test_aggregates_across_endpoints(self, db):
await db.update_token_counts("http://a", "m", 1, 2)
await db.update_token_counts("http://b", "m", 3, 4)
result = await db.get_token_counts_for_model("m")
assert result is not None
assert result["endpoint"] == "aggregated"
assert result["model"] == "m"
assert result["input_tokens"] == 4
assert result["output_tokens"] == 6
assert result["total_tokens"] == 10
async def test_unknown_model_returns_zero_aggregate(self, db):
# SUM(...) WHERE no-match returns one row with NULLs — exposed as zeros
result = await db.get_token_counts_for_model("nope")
assert result is not None
assert result["input_tokens"] in (0, None)
class TestAggregateTimeSeriesOlderThan:
async def test_aggregates_old_entries_by_day(self, db):
await db.update_token_counts("http://ep", "m", 0, 0)
now = int(datetime.now(tz=timezone.utc).timestamp())
old = now - (40 * 86400) # 40 days ago
await db.add_batched_time_series([
{"endpoint": "http://ep", "model": "m", "input_tokens": 1,
"output_tokens": 1, "total_tokens": 2, "timestamp": old},
{"endpoint": "http://ep", "model": "m", "input_tokens": 3,
"output_tokens": 3, "total_tokens": 6, "timestamp": old + 60},
{"endpoint": "http://ep", "model": "m", "input_tokens": 99,
"output_tokens": 99, "total_tokens": 198, "timestamp": now},
])
n = await db.aggregate_time_series_older_than(30, trim_old=False)
assert n == 1 # one (endpoint, model, day) group rolled up
async def test_invalid_days_falls_back_to_30(self, db):
# Just ensure it doesn't blow up with a bogus value
n = await db.aggregate_time_series_older_than(0)
assert n == 0
async def test_trim_old_removes_aggregated_rows(self, db):
await db.update_token_counts("http://ep", "m", 0, 0)
now = int(datetime.now(tz=timezone.utc).timestamp())
old = now - (40 * 86400)
await db.add_batched_time_series([
{"endpoint": "http://ep", "model": "m", "input_tokens": 1,
"output_tokens": 1, "total_tokens": 2, "timestamp": old},
{"endpoint": "http://ep", "model": "m", "input_tokens": 99,
"output_tokens": 99, "total_tokens": 198, "timestamp": now},
])
await db.aggregate_time_series_older_than(30, trim_old=True)
remaining = [r async for r in db.get_latest_time_series(limit=10)]
# Only the recent (within-cutoff) row should remain
assert len(remaining) == 1
assert remaining[0]["total_tokens"] == 198

View file

@ -1,309 +0,0 @@
"""Tests for fetch.available_models and fetch.loaded_models.
The backend probes obtain their HTTP client via ``backends.probe.get_probe_session``
and only ever call ``async with client.get(url, headers=...) as resp``. We patch that
seam with a tiny fake session instead of mocking aiohttp's internals (aioresponses),
so the suite stays independent of aiohttp's private ClientResponse/ConnectionKey
structure across version bumps.
"""
import time
from contextlib import contextmanager
from unittest.mock import patch, MagicMock
import pytest
import router
import backends.probe as probe
from conftest import TEST_OLLAMA, TEST_LLAMA
MOCK_OLLAMA_EP = "http://mock-ollama:11434"
MOCK_LLAMA_EP = "http://mock-llama:8080/v1"
def _make_cfg(ollama_eps=None, llama_eps=None, swap_eps=None, api_keys=None):
cfg = MagicMock()
cfg.endpoints = ollama_eps or [MOCK_OLLAMA_EP]
cfg.llama_server_endpoints = llama_eps or [MOCK_LLAMA_EP]
cfg.llama_swap_endpoints = swap_eps or []
cfg.api_keys = api_keys or {}
cfg.max_concurrent_connections = 2
cfg.router_api_key = None
return cfg
# ── Fake probe session ────────────────────────────────────────────────────────
class _MockResponse:
"""Minimal stand-in for the aiohttp response used by the probes."""
def __init__(self, *, status=200, payload=None, text=None):
self.status = status
self._payload = payload
self._text = text if text is not None else ""
async def json(self):
return self._payload
async def text(self):
return self._text
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
return False
class _RaisingCtx:
"""``async with client.get(...)`` that raises on entry — mimics a failed connection."""
def __init__(self, exc):
self._exc = exc
async def __aenter__(self):
raise self._exc
async def __aexit__(self, *exc):
return False
class _MockProbeSession:
"""Stand-in for the aiohttp ClientSession returned by ``get_probe_session``.
Routes are registered by exact URL via :meth:`add_get`. A registered exception
is raised when the route is entered; otherwise a :class:`_MockResponse` is yielded.
An unregistered GET fails loudly so tests can't silently pass on a wrong URL.
"""
def __init__(self):
self._routes = {}
def add_get(self, url, *, status=200, payload=None, text=None, exception=None):
self._routes[url] = exception if exception is not None else _MockResponse(
status=status, payload=payload, text=text
)
def get(self, url, **kwargs):
if url not in self._routes:
raise AssertionError(f"unexpected probe GET {url}")
entry = self._routes[url]
return _RaisingCtx(entry) if isinstance(entry, Exception) else entry
@contextmanager
def mock_probe():
"""Patch the probe's session factory to return a fresh :class:`_MockProbeSession`."""
session = _MockProbeSession()
with patch.object(probe, "get_probe_session", lambda endpoint: session):
yield session
@pytest.fixture(autouse=True)
def clear_caches(aio_session):
"""aio_session fixture already clears caches and sets up app_state."""
yield
class TestFetchAvailableModels:
async def test_ollama_tags(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [
{"name": "llama3.2:latest"},
{"name": "qwen2.5:7b"},
]},
)
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
assert models == {"llama3.2:latest", "qwen2.5:7b"}
async def test_openai_compatible_models_endpoint(self):
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_LLAMA_EP}/models",
payload={"data": [{"id": "unsloth/model:Q8_0"}]},
)
models = await router.fetch.available_models(MOCK_LLAMA_EP, api_key="tok")
assert "unsloth/model:Q8_0" in models
async def test_caches_successful_result(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [{"name": "llama3.2:latest"}]},
)
first = await router.fetch.available_models(MOCK_OLLAMA_EP)
second = await router.fetch.available_models(MOCK_OLLAMA_EP)
# second call must be served from cache without a second HTTP request
assert first == second == {"llama3.2:latest"}
async def test_returns_empty_on_http_500(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
assert models == set()
async def test_returns_empty_on_connection_error(self):
import aiohttp
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags",
exception=aiohttp.ClientConnectionError(
"Cannot connect to host mock-ollama:11434 [Connection refused]"
),
)
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
assert models == set()
async def test_stale_cache_returned_while_refresh_runs(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [{"name": "llama3.2:latest"}]},
)
await router.fetch.available_models(MOCK_OLLAMA_EP)
# Manually age cache into stale-but-valid window (300-600s)
async with router._models_cache_lock:
models, _ = router._models_cache[MOCK_OLLAMA_EP]
router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400)
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [{"name": "llama3.2:latest"}]},
)
# Should return stale data immediately
stale = await router.fetch.available_models(MOCK_OLLAMA_EP)
assert "llama3.2:latest" in stale
async def test_error_cache_short_circuits(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
# Seed error cache with a very recent error
async with router._available_error_cache_lock:
router._available_error_cache[MOCK_OLLAMA_EP] = time.time()
with patch.object(router, "config", cfg), mock_probe():
# No route registered — if a call happens it raises AssertionError
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
assert models == set()
class TestFetchLoadedModels:
async def test_ollama_ps(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/ps",
payload={"models": [{"name": "llama3.2:latest"}]},
)
models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
assert models == {"llama3.2:latest"}
async def test_llama_server_filters_loaded(self):
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_LLAMA_EP}/models",
payload={"data": [
{"id": "model-a", "status": {"value": "loaded"}},
{"id": "model-b", "status": {"value": "unloaded"}},
]},
)
models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
assert models == {"model-a"}
async def test_llama_server_no_status_field_always_loaded(self):
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_LLAMA_EP}/models",
payload={"data": [{"id": "always-on-model"}]},
)
models = await router.fetch.loaded_models(MOCK_LLAMA_EP)
assert "always-on-model" in models
async def test_llama_swap_reads_running_state_ready(self):
# llama-swap omits the /v1/models status field, so loaded workers come
# from /running (a root route — the /v1 suffix must be stripped).
swap_ep = "http://mock-swap:8080/v1"
cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
"http://mock-swap:8080/running",
payload={"running": [
{"model": "org/ready-model:Q4_K_M", "state": "ready"},
{"model": "org/starting-model:Q8_0", "state": "starting"},
]},
)
models = await router.fetch.loaded_models(swap_ep)
assert models == {"org/ready-model:Q4_K_M"}
async def test_llama_swap_records_error_on_failure(self):
swap_ep = "http://mock-swap:8080/v1"
cfg = _make_cfg(llama_eps=[], swap_eps=[swap_ep])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get("http://mock-swap:8080/running", status=502, payload={})
await router.fetch.loaded_models(swap_ep)
assert swap_ep in router._loaded_error_cache
async def test_returns_empty_on_error(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
assert models == set()
async def test_ext_openai_always_empty(self):
ext_ep = "https://api.openai.com/v1"
cfg = _make_cfg(ollama_eps=[ext_ep], llama_eps=[])
with patch.object(router, "config", cfg):
models = await router.fetch.loaded_models(ext_ep)
assert models == set()
async def test_caches_result(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/ps",
payload={"models": [{"name": "qwen:7b"}]},
)
first = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
second = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
assert first == second
async def test_records_error_in_loaded_error_cache_on_failure(self):
# Regression: issue #83 — /api/ps failures must be recorded so
# `choose_endpoint` can exclude unhealthy backends from routing.
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={})
await router.fetch.loaded_models(MOCK_OLLAMA_EP)
assert MOCK_OLLAMA_EP in router._loaded_error_cache
async def test_records_error_for_llama_server_on_failure(self):
cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(f"{MOCK_LLAMA_EP}/models", status=502, payload={})
await router.fetch.loaded_models(MOCK_LLAMA_EP)
assert MOCK_LLAMA_EP in router._loaded_error_cache
async def test_clears_error_cache_on_subsequent_success(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
# Pre-seed an old error so loaded_models() falls through to the
# network probe instead of short-circuiting on the error cache.
async with router._loaded_error_cache_lock:
router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301
with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/ps",
payload={"models": [{"name": "qwen:7b"}]},
)
await router.fetch.loaded_models(MOCK_OLLAMA_EP)
assert MOCK_OLLAMA_EP not in router._loaded_error_cache

View file

@ -1,131 +0,0 @@
"""Tests for llama-swap specific behavior: unload dispatch + /upstream resolution."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import router
import backends.control as control
import api.openai as openai_api
import api.ollama as ollama_api
SWAP_EP = "http://swap:8080/v1"
SERVER_EP = "http://server:8080/v1"
def _cfg(*, server=None, swap=None, api_keys=None):
cfg = MagicMock()
cfg.endpoints = []
cfg.llama_server_endpoints = server or []
cfg.llama_swap_endpoints = swap or []
cfg.api_keys = api_keys or {}
return cfg
class _RecordingSession:
"""Captures the most recent ``post`` call and returns a 200 response."""
def __init__(self, status=200):
self.calls = []
self._status = status
def post(self, url, **kwargs):
self.calls.append((url, kwargs))
resp = MagicMock()
resp.status = self._status
class _Ctx:
async def __aenter__(self_):
return resp
async def __aexit__(self_, *exc):
return False
return _Ctx()
class TestUnloadDispatch:
async def test_llama_swap_uses_path_param(self):
sess = _RecordingSession()
cfg = _cfg(swap=[SWAP_EP])
with (
patch.object(router, "config", cfg),
patch.object(control, "get_probe_session", lambda ep: sess),
):
ok = await control.unload_model(SWAP_EP, "org/model:Q4_K_M")
assert ok is True
url, kwargs = sess.calls[0]
# /v1 stripped, model id is a path param, no JSON body
assert url == "http://swap:8080/api/models/unload/org/model:Q4_K_M"
assert kwargs.get("json") is None
async def test_llama_server_uses_body(self):
sess = _RecordingSession()
cfg = _cfg(server=[SERVER_EP])
with (
patch.object(router, "config", cfg),
patch.object(control, "get_probe_session", lambda ep: sess),
):
ok = await control.unload_model(SERVER_EP, "org/model:Q4_K_M")
assert ok is True
url, kwargs = sess.calls[0]
assert url == "http://server:8080/models/unload"
assert kwargs.get("json") == {"model": "org/model:Q4_K_M"}
async def test_unload_failure_returns_false(self):
sess = _RecordingSession(status=500)
cfg = _cfg(swap=[SWAP_EP])
with (
patch.object(router, "config", cfg),
patch.object(control, "get_probe_session", lambda ep: sess),
):
ok = await control.unload_model(SWAP_EP, "m")
assert ok is False
class TestUpstreamResolution:
async def test_resolves_endpoint_that_advertises_model(self):
cfg = _cfg(swap=[SWAP_EP])
with (
patch.object(openai_api, "get_config", lambda: cfg),
patch.object(openai_api.fetch, "available_models",
AsyncMock(return_value={"org/model:Q4_K_M"})),
):
ep = await openai_api._resolve_llama_swap_endpoint("org/model:Q4_K_M")
assert ep == SWAP_EP
async def test_returns_none_when_unserved(self):
cfg = _cfg(swap=[SWAP_EP])
with (
patch.object(openai_api, "get_config", lambda: cfg),
patch.object(openai_api.fetch, "available_models",
AsyncMock(return_value=set())),
):
ep = await openai_api._resolve_llama_swap_endpoint("missing")
assert ep is None
async def test_returns_none_without_swap_endpoints(self):
cfg = _cfg(swap=[])
with patch.object(openai_api, "get_config", lambda: cfg):
ep = await openai_api._resolve_llama_swap_endpoint("any")
assert ep is None
class TestCtxSizeFromCmd:
"""ctx-size parsing from a /running worker's launch `cmd` string."""
def test_parses_long_flag(self):
cmd = ("llama-server --port 5818\n -hf unsloth/gpt-oss-20b-GGUF:F16\n"
" --ctx-size 131072\n --temp 1.0\n")
assert ollama_api._ctx_size_from_cmd(cmd) == 131072
def test_parses_short_flag(self):
assert ollama_api._ctx_size_from_cmd("llama-server -c 8192 --port 1") == 8192
def test_parses_equals_form(self):
assert ollama_api._ctx_size_from_cmd("llama-server --ctx-size=4096") == 4096
def test_returns_none_when_absent(self):
assert ollama_api._ctx_size_from_cmd("llama-server --port 5818") is None
def test_returns_none_for_empty(self):
assert ollama_api._ctx_size_from_cmd("") is None

View file

@ -1,182 +0,0 @@
"""Cache-hit short-circuit tests for the OpenAI-compatible proxy routes.
These tests verify that when the LLM cache reports a hit, the route returns
the cached payload *without* selecting an endpoint or contacting any backend.
"""
from unittest.mock import AsyncMock, patch
import orjson
import pytest
from fastapi import HTTPException
import router
from api import openai as api_openai
_BYPASS = HTTPException(status_code=599, detail="bypassed")
class _FakeCache:
"""Minimal stand-in for cache.LLMCache.get_chat."""
def __init__(self, response_bytes: bytes | None):
self._resp = response_bytes
self.calls: list[tuple] = []
async def get_chat(self, route, model, messages):
self.calls.append((route, model, messages))
return self._resp
@pytest.fixture
def cache_hit_payload():
return orjson.dumps({
"id": "cmpl-xyz",
"created": 1,
"model": "test-model",
"choices": [{"message": {"role": "assistant", "content": "from-cache"}}],
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
})
# ──────────────────────────────────────────────────────────────────────────────
# /v1/chat/completions
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAIChatCompletionsCacheHit:
async def test_nonstream_cache_hit_returns_cached_json(self, client, cache_hit_payload):
fake = _FakeCache(cache_hit_payload)
# Patch the route's references to both helpers — they're imported by name
# into router's namespace at module load time.
with (
patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
):
resp = await client.post(
"/v1/chat/completions",
json={
"model": "test-model",
"messages": [{"role": "user", "content": "ping"}],
"stream": False,
"nomyo": {"cache": True},
},
)
assert resp.status_code == 200
# Body is streamed; collect it
body = resp.content
parsed = orjson.loads(body)
assert parsed["choices"][0]["message"]["content"] == "from-cache"
assert fake.calls and fake.calls[0][0] == "openai_chat"
async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
fake = _FakeCache(cache_hit_payload)
with (
patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
):
resp = await client.post(
"/v1/chat/completions",
json={
"model": "test-model",
"messages": [{"role": "user", "content": "ping"}],
"stream": True,
"nomyo": {"cache": True},
},
)
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
text = resp.content.decode()
# First SSE frame contains the cached content as a delta
first_frame = text.split("\n\n")[0]
assert first_frame.startswith("data: ")
chunk = orjson.loads(first_frame[len("data: "):])
assert chunk["choices"][0]["delta"]["content"] == "from-cache"
# Stream is terminated with [DONE]
assert "data: [DONE]" in text
async def test_cache_disabled_in_payload_bypasses_cache_check(self, client):
"""When nomyo.cache=False, get_chat is never called even if a cache exists."""
fake = _FakeCache(b"") # has a response, but should never be consulted
with (
patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=_BYPASS)),
):
resp = await client.post(
"/v1/chat/completions",
json={
"model": "m",
"messages": [{"role": "user", "content": "hi"}],
"nomyo": {"cache": False},
},
)
# Got past the cache short-circuit → endpoint selection invoked
assert resp.status_code == 599
assert fake.calls == []
async def test_no_cache_configured_bypasses_cache_check(self, client):
"""get_llm_cache() returning None should not break the route."""
with (
patch.object(api_openai, "get_llm_cache", return_value=None),
patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=_BYPASS)),
):
resp = await client.post(
"/v1/chat/completions",
json={
"model": "m",
"messages": [{"role": "user", "content": "hi"}],
"nomyo": {"cache": True},
},
)
assert resp.status_code == 599
# ──────────────────────────────────────────────────────────────────────────────
# /v1/completions
# ──────────────────────────────────────────────────────────────────────────────
class TestOpenAICompletionsCacheHit:
async def test_nonstream_cache_hit(self, client, cache_hit_payload):
fake = _FakeCache(cache_hit_payload)
with (
patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
):
resp = await client.post(
"/v1/completions",
json={
"model": "test-model",
"prompt": "Tell me a joke",
"stream": False,
"nomyo": {"cache": True},
},
)
assert resp.status_code == 200
# Prompt-style cache lookup is namespaced under "openai_completions"
assert fake.calls[0][0] == "openai_completions"
# Cache lookup receives the prompt as a single user message
cached_msgs = fake.calls[0][2]
assert cached_msgs == [{"role": "user", "content": "Tell me a joke"}]
async def test_stream_cache_hit(self, client, cache_hit_payload):
fake = _FakeCache(cache_hit_payload)
with (
patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
):
resp = await client.post(
"/v1/completions",
json={
"model": "test-model",
"prompt": "What is 2+2?",
"stream": True,
"nomyo": {"cache": True},
},
)
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
assert "data: [DONE]" in resp.content.decode()

View file

@ -1,460 +0,0 @@
"""Tests for the OpenAI Responses API support (api/responses.py + requests/responses.py).
Covers the pure translation layer, the translated (Ollama-style) and native
(external-OpenAI) backend paths, conversation storage / chaining, background mode,
and the retrieve / delete / cancel routes.
"""
import asyncio
from contextlib import ExitStack, contextmanager
from types import SimpleNamespace as NS
from unittest.mock import AsyncMock, MagicMock, patch
import orjson
import pytest
import router
from api import responses as api_responses
from requests import responses as rt
# ──────────────────────────────────────────────────────────────────────────────
# Pure translation unit tests (no app / no I/O)
# ──────────────────────────────────────────────────────────────────────────────
class TestTranslationInputToMessages:
def test_string_input(self):
msgs = rt.responses_input_to_messages("hello")
assert msgs == [{"role": "user", "content": "hello"}]
def test_instructions_become_system(self):
msgs = rt.responses_input_to_messages("hi", instructions="be brief")
assert msgs[0] == {"role": "system", "content": "be brief"}
assert msgs[1] == {"role": "user", "content": "hi"}
def test_item_list_text_and_image(self):
items = [{
"type": "message", "role": "user",
"content": [
{"type": "input_text", "text": "describe"},
{"type": "input_image", "image_url": "http://x/y.png"},
],
}]
msgs = rt.responses_input_to_messages(items)
assert msgs[0]["role"] == "user"
assert msgs[0]["content"] == [
{"type": "text", "text": "describe"},
{"type": "image_url", "image_url": {"url": "http://x/y.png"}},
]
def test_single_text_part_collapses_to_string(self):
items = [{"type": "message", "role": "user",
"content": [{"type": "input_text", "text": "yo"}]}]
assert rt.responses_input_to_messages(items)[0]["content"] == "yo"
def test_function_call_roundtrip(self):
items = [
{"type": "function_call", "call_id": "c1", "name": "get", "arguments": "{\"x\":1}"},
{"type": "function_call_output", "call_id": "c1", "output": "42"},
]
msgs = rt.responses_input_to_messages(items)
assert msgs[0]["role"] == "assistant"
assert msgs[0]["tool_calls"][0]["id"] == "c1"
assert msgs[0]["tool_calls"][0]["function"]["name"] == "get"
assert msgs[1] == {"role": "tool", "tool_call_id": "c1", "content": "42"}
class TestTranslationResponseDirection:
def test_chat_message_to_output_items_text(self):
items = rt.chat_message_to_output_items({"role": "assistant", "content": "hi there"})
assert len(items) == 1
assert items[0]["type"] == "message"
assert items[0]["content"][0] == {"type": "output_text", "text": "hi there", "annotations": []}
def test_chat_message_to_output_items_tool_call(self):
items = rt.chat_message_to_output_items({
"role": "assistant", "content": None,
"tool_calls": [{"id": "c9", "function": {"name": "f", "arguments": "{}"}}],
})
assert items[0]["type"] == "function_call"
assert items[0]["call_id"] == "c9"
assert items[0]["name"] == "f"
def test_usage_mapping(self):
u = rt.usage_chat_to_responses({"prompt_tokens": 7, "completion_tokens": 3})
assert u == {"input_tokens": 7, "output_tokens": 3, "total_tokens": 10}
def test_build_response_object_output_text(self):
items = rt.chat_message_to_output_items({"role": "assistant", "content": "abc"})
obj = rt.build_response_object(response_id="resp_1", model="m", output_items=items)
assert obj["object"] == "response"
assert obj["output_text"] == "abc"
assert obj["status"] == "completed"
def test_tools_responses_to_chat(self):
tools = [{"type": "function", "name": "f", "description": "d", "parameters": {"type": "object"}}]
chat_tools = rt.tools_responses_to_chat(tools)
assert chat_tools == [{"type": "function",
"function": {"name": "f", "description": "d",
"parameters": {"type": "object"}}}]
def test_messages_to_responses_input(self):
instr, items = rt.messages_to_responses_input([
{"role": "system", "content": "sys"},
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "yo"},
])
assert instr == "sys"
assert items[0] == {"role": "user", "content": [{"type": "input_text", "text": "hi"}]}
assert items[1] == {"role": "assistant", "content": [{"type": "output_text", "text": "yo"}]}
# ──────────────────────────────────────────────────────────────────────────────
# Fakes for backend generators
# ──────────────────────────────────────────────────────────────────────────────
def _fake_completion(content="hello world", usage=(3, 5)):
msg = MagicMock()
msg.model_dump.return_value = {"role": "assistant", "content": content}
usage_obj = MagicMock()
usage_obj.model_dump.return_value = {
"prompt_tokens": usage[0], "completion_tokens": usage[1], "total_tokens": sum(usage)}
return NS(choices=[NS(message=msg)], usage=usage_obj)
def _chunk(content=None, tool_calls=None):
return NS(choices=[NS(delta=NS(content=content, tool_calls=tool_calls),
finish_reason=None)], usage=None)
def _usage_chunk(p, c):
return NS(choices=[], usage=NS(prompt_tokens=p, completion_tokens=c))
def _text_chunks():
async def _gen():
yield _chunk(content="Hel")
yield _chunk(content="lo")
yield _usage_chunk(3, 5)
return _gen()
def _toolcall_chunks():
tc0 = NS(index=0, id="call_1", function=NS(name="lookup", arguments='{"q":'))
tc1 = NS(index=0, id=None, function=NS(name=None, arguments='"hi"}'))
async def _gen():
yield _chunk(tool_calls=[tc0])
yield _chunk(tool_calls=[tc1])
yield _usage_chunk(4, 2)
return _gen()
class _FakeEvent:
def __init__(self, data):
self._data = data
def model_dump(self):
return self._data
def _native_event_stream():
async def _gen():
yield _FakeEvent({"type": "response.created",
"response": {"id": "resp_openai", "status": "in_progress", "output": []}})
yield _FakeEvent({"type": "response.output_text.delta",
"item_id": "msg_1", "output_index": 0, "delta": "hi"})
yield _FakeEvent({"type": "response.completed", "response": {
"id": "resp_openai", "status": "completed",
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "hi"}]}],
"usage": {"input_tokens": 2, "output_tokens": 1, "total_tokens": 3}}})
return _gen()
def _sse_events(text):
"""Split an SSE body into a list of (event_type, data_dict)."""
out = []
for frame in text.strip().split("\n\n"):
if not frame.strip():
continue
etype = data = None
for line in frame.splitlines():
if line.startswith("event: "):
etype = line[len("event: "):]
elif line.startswith("data: "):
data = orjson.loads(line[len("data: "):])
out.append((etype, data))
return out
@contextmanager
def _enter(*cms):
"""Enter a variable number of context managers (works with *unpacked tuples)."""
with ExitStack() as stack:
for cm in cms:
stack.enter_context(cm)
yield
def _patch_backend(native=False, endpoint="http://ollama:11434"):
"""Context managers patching endpoint selection + client construction."""
return (
patch.object(api_responses, "choose_endpoint",
AsyncMock(return_value=(endpoint, "test-model:latest"))),
patch.object(api_responses, "decrement_usage", AsyncMock()),
patch.object(api_responses, "is_ext_openai_endpoint", return_value=native),
patch.object(api_responses, "_make_openai_client", return_value=MagicMock()),
patch.object(api_responses, "get_llm_cache", return_value=None),
)
# ──────────────────────────────────────────────────────────────────────────────
# Translated path (Ollama-style backend)
# ──────────────────────────────────────────────────────────────────────────────
class TestTranslatedPath:
async def test_nonstream(self, client):
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries",
AsyncMock(return_value=_fake_completion("hello world")))):
resp = await client.post("/v1/responses",
json={"model": "test-model", "input": "hi", "store": False})
assert resp.status_code == 200
body = resp.json()
assert body["object"] == "response"
assert body["output_text"] == "hello world"
assert body["usage"] == {"input_tokens": 3, "output_tokens": 5, "total_tokens": 8}
assert body["id"].startswith("resp_")
async def test_stream_event_sequence(self, client):
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries",
AsyncMock(return_value=_text_chunks()))):
resp = await client.post("/v1/responses",
json={"model": "test-model", "input": "hi",
"stream": True, "store": False})
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
events = _sse_events(resp.content.decode())
types = [e[0] for e in events]
assert types[0] == "response.created"
assert "response.output_text.delta" in types
assert types[-1] == "response.completed"
# concatenated deltas reconstruct the content
deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta")
assert deltas == "Hello"
# completed event carries usage
completed = [d for t, d in events if t == "response.completed"][0]
assert completed["response"]["usage"]["input_tokens"] == 3
async def test_stream_tool_calls(self, client):
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries",
AsyncMock(return_value=_toolcall_chunks()))):
resp = await client.post("/v1/responses",
json={"model": "test-model", "input": "lookup hi",
"stream": True, "store": False})
events = _sse_events(resp.content.decode())
types = [e[0] for e in events]
assert "response.function_call_arguments.delta" in types
assert "response.function_call_arguments.done" in types
args = "".join(d["delta"] for t, d in events
if t == "response.function_call_arguments.delta")
assert args == '{"q":"hi"}'
completed = [d for t, d in events if t == "response.completed"][0]
fc = [i for i in completed["response"]["output"] if i["type"] == "function_call"][0]
assert fc["name"] == "lookup"
assert fc["arguments"] == '{"q":"hi"}'
# ──────────────────────────────────────────────────────────────────────────────
# Native path (external OpenAI backend)
# ──────────────────────────────────────────────────────────────────────────────
class TestNativePath:
async def test_nonstream_passthrough_rewrites_id(self, client):
oclient = MagicMock()
resp_obj = MagicMock()
resp_obj.model_dump.return_value = {
"id": "resp_openai", "status": "completed",
"output": [{"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": "native hi"}]}],
"usage": {"input_tokens": 2, "output_tokens": 3, "total_tokens": 5}}
oclient.responses.create = AsyncMock(return_value=resp_obj)
with (patch.object(api_responses, "choose_endpoint",
AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))),
patch.object(api_responses, "decrement_usage", AsyncMock()),
patch.object(api_responses, "is_ext_openai_endpoint", return_value=True),
patch.object(api_responses, "_make_openai_client", return_value=oclient),
patch.object(api_responses, "get_llm_cache", return_value=None)):
resp = await client.post("/v1/responses",
json={"model": "gpt", "input": "hi", "store": False})
body = resp.json()
assert body["output_text"] == "native hi"
assert body["id"].startswith("resp_") and body["id"] != "resp_openai"
# native call must not delegate state upstream
assert oclient.responses.create.call_args.kwargs["store"] is False
async def test_stream_passthrough(self, client):
oclient = MagicMock()
oclient.responses.create = AsyncMock(return_value=_native_event_stream())
with (patch.object(api_responses, "choose_endpoint",
AsyncMock(return_value=("https://api.openai.com/v1", "gpt"))),
patch.object(api_responses, "decrement_usage", AsyncMock()),
patch.object(api_responses, "is_ext_openai_endpoint", return_value=True),
patch.object(api_responses, "_make_openai_client", return_value=oclient),
patch.object(api_responses, "get_llm_cache", return_value=None)):
resp = await client.post("/v1/responses",
json={"model": "gpt", "input": "hi",
"stream": True, "store": False})
events = _sse_events(resp.content.decode())
# the completed event's response id is rewritten to the router id
completed = [d for t, d in events if t == "response.completed"][0]
assert completed["response"]["id"].startswith("resp_")
assert completed["response"]["id"] != "resp_openai"
# ──────────────────────────────────────────────────────────────────────────────
# Storage + chaining + retrieve/delete
# ──────────────────────────────────────────────────────────────────────────────
class TestStorageAndChaining:
async def test_store_and_retrieve(self, client):
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries",
AsyncMock(return_value=_fake_completion("remembered")))):
created = await client.post("/v1/responses",
json={"model": "test-model", "input": "hi", "store": True})
rid = created.json()["id"]
got = await client.get(f"/v1/responses/{rid}")
assert got.status_code == 200
assert got.json()["output_text"] == "remembered"
async def test_previous_response_id_rehydrates_history(self, client):
# First turn
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries",
AsyncMock(return_value=_fake_completion("turn-one")))):
first = await client.post("/v1/responses",
json={"model": "test-model", "input": "first?", "store": True})
rid = first.json()["id"]
# Second turn references the first — capture the messages sent to the backend
capture = AsyncMock(return_value=_fake_completion("turn-two"))
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries", capture)):
await client.post("/v1/responses",
json={"model": "test-model", "input": "second?",
"previous_response_id": rid, "store": True})
sent_messages = capture.call_args.args[1]["messages"]
contents = [m.get("content") for m in sent_messages]
assert "first?" in contents # prior user turn replayed
assert "turn-one" in contents # prior assistant turn replayed
assert "second?" in contents # current turn appended
async def test_delete(self, client):
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries",
AsyncMock(return_value=_fake_completion("bye")))):
created = await client.post("/v1/responses",
json={"model": "test-model", "input": "hi", "store": True})
rid = created.json()["id"]
deleted = await client.delete(f"/v1/responses/{rid}")
assert deleted.status_code == 200
assert deleted.json()["deleted"] is True
assert (await client.get(f"/v1/responses/{rid}")).status_code == 404
async def test_retrieve_missing_404(self, client):
assert (await client.get("/v1/responses/resp_missing")).status_code == 404
# ──────────────────────────────────────────────────────────────────────────────
# Background mode
# ──────────────────────────────────────────────────────────────────────────────
class TestBackgroundMode:
async def test_background_requires_store(self, client):
resp = await client.post("/v1/responses",
json={"model": "test-model", "input": "hi",
"background": True, "store": False})
assert resp.status_code == 400
async def test_background_lifecycle(self, client):
with _enter(*_patch_backend(native=False),
patch.object(api_responses, "create_chat_with_retries",
AsyncMock(return_value=_fake_completion("bg-done")))):
created = await client.post("/v1/responses",
json={"model": "test-model", "input": "hi",
"background": True, "store": True})
assert created.status_code == 200
assert created.json()["status"] == "queued"
rid = created.json()["id"]
# poll until terminal
status = None
for _ in range(100):
await asyncio.sleep(0.01)
got = await client.get(f"/v1/responses/{rid}")
status = got.json()["status"]
if status in ("completed", "failed", "cancelled"):
break
assert status == "completed"
assert got.json()["output_text"] == "bg-done"
async def test_fail_orphaned_responses(self, client):
db = router.db
await db.store_response("resp_orphan", previous_response_id=None, model="m",
status="in_progress", created_at=0, input_messages=[])
n = await db.fail_orphaned_responses()
assert n >= 1
row = await db.get_response("resp_orphan")
assert row["status"] == "failed"
# ──────────────────────────────────────────────────────────────────────────────
# Cache parity
# ──────────────────────────────────────────────────────────────────────────────
class _FakeCache:
def __init__(self, response_bytes):
self._resp = response_bytes
self.calls = []
async def get_chat(self, route, model, messages):
self.calls.append((route, model, messages))
return self._resp
class TestCacheParity:
async def test_cache_hit_served_as_response(self, client):
cached = orjson.dumps(rt.build_response_object(
response_id="resp_cached", model="test-model",
output_items=rt.chat_message_to_output_items(
{"role": "assistant", "content": "from-cache"})))
fake = _FakeCache(cached)
with (patch.object(api_responses, "get_llm_cache", return_value=fake),
patch.object(api_responses, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached")))):
resp = await client.post("/v1/responses",
json={"model": "test-model", "input": "ping",
"store": False, "nomyo": {"cache": True}})
assert resp.status_code == 200
assert resp.json()["output_text"] == "from-cache"
assert fake.calls and fake.calls[0][0] == "openai_responses"
async def test_cache_hit_served_as_sse(self, client):
cached = orjson.dumps(rt.build_response_object(
response_id="resp_cached", model="test-model",
output_items=rt.chat_message_to_output_items(
{"role": "assistant", "content": "from-cache"})))
fake = _FakeCache(cached)
with (patch.object(api_responses, "get_llm_cache", return_value=fake),
patch.object(api_responses, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached")))):
resp = await client.post("/v1/responses",
json={"model": "test-model", "input": "ping",
"stream": True, "store": False,
"nomyo": {"cache": True}})
assert resp.headers["content-type"].startswith("text/event-stream")
events = _sse_events(resp.content.decode())
deltas = "".join(d["delta"] for t, d in events if t == "response.output_text.delta")
assert deltas == "from-cache"

View file

@ -1,151 +0,0 @@
"""
Unit tests for transitive backend-error handling in the four Ollama-native
streaming generators (``/api/generate``, ``/api/chat``, ``/api/embeddings``,
``/api/embed``).
These reproduce the reported failure mode: a backend (nginx in front of ollama)
returns a 504 Gateway Time-out *while the response is being streamed*, so the
``ollama`` client raises ``ResponseError`` from inside the StreamingResponse
generator. Before the fix this escaped as an opaque "Exception in ASGI
application" traceback; now ``_handle_stream_error`` logs the endpoint/model and
emits a terminal Ollama-format ``{"error": ..., "status_code": ...}`` line.
No real backend required the ollama client and routing are mocked.
"""
import json
from contextlib import ExitStack
from unittest.mock import AsyncMock, patch
import httpx
import ollama
import openai
import pytest
from conftest import TEST_OLLAMA
pytestmark = pytest.mark.asyncio
# ── Fakes ─────────────────────────────────────────────────────────────────────
class _Chunk:
"""Minimal Ollama-native streaming chunk the generators can consume."""
prompt_eval_count = 0
eval_count = 0
done = False
message = None
response = None
done_reason = None
def model_dump_json(self):
return '{"model": "fake", "done": false}'
def _one_then_raise(exc):
"""Async generator: yield one valid chunk, then fail mid-stream."""
async def _gen():
yield _Chunk()
raise exc
return _gen()
class _FakeAsyncClient:
"""Stand-in for ``ollama.AsyncClient`` that fails with ``exc``.
Streaming methods (chat/generate) fail *after* one chunk to mimic a
mid-stream 504; the embedding methods fail on the initial await.
"""
def __init__(self, exc, *args, **kwargs):
self._exc = exc
async def chat(self, **kwargs):
return _one_then_raise(self._exc)
async def generate(self, **kwargs):
return _one_then_raise(self._exc)
async def embeddings(self, **kwargs):
raise self._exc
async def embed(self, **kwargs):
raise self._exc
def _patches(exc, mark_unhealthy):
"""Patch routing + the ollama client so the native path hits ``exc``."""
stack = ExitStack()
stack.enter_context(
patch("api.ollama.choose_endpoint", AsyncMock(return_value=(TEST_OLLAMA, "fake")))
)
stack.enter_context(patch("api.ollama.is_openai_compatible", lambda ep: False))
stack.enter_context(patch("api.ollama.decrement_usage", AsyncMock()))
stack.enter_context(patch("api.ollama._mark_backend_unhealthy", mark_unhealthy))
# The native path now fetches a cached client via get_ollama_client() rather
# than constructing ollama.AsyncClient inline, so patch that seam.
stack.enter_context(
patch("api.ollama.get_ollama_client", lambda *a, **k: _FakeAsyncClient(exc))
)
return stack
# Route → request payload. stream=True only matters for chat/generate.
_ROUTES = {
"/api/chat": {"model": "fake", "stream": True, "messages": [{"role": "user", "content": "hi"}]},
"/api/generate": {"model": "fake", "stream": True, "prompt": "hi"},
"/api/embeddings": {"model": "fake", "prompt": "hi"},
"/api/embed": {"model": "fake", "input": "hi"},
}
def _last_json_line(text):
lines = [l for l in text.strip().split("\n") if l.strip()]
assert lines, "expected at least one ndjson line in the response body"
return json.loads(lines[-1])
# ── Tests ─────────────────────────────────────────────────────────────────────
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_504_surfaces_as_error_line(client, route, payload):
"""A 504 ResponseError becomes a terminal {"error", "status_code"} line."""
exc = ollama.ResponseError("<html>504 Gateway Time-out</html>", 504)
mark = AsyncMock()
with _patches(exc, mark):
resp = await client.post(route, json=payload)
# Streaming already started (or single-shot) → HTTP status is 200, the
# error is delivered in-band rather than as a 5xx crash.
assert resp.status_code == 200
err = _last_json_line(resp.text)
assert "error" in err
assert "504" in err["error"]
assert err["status_code"] == 504
# A plain 504 is not a connection-class failure → endpoint stays healthy.
mark.assert_not_called()
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_no_asgi_500_on_backend_failure(client, route, payload):
"""The generator must never let the backend error escape as a 500."""
exc = ollama.ResponseError("boom", 502)
with _patches(exc, AsyncMock()):
resp = await client.post(route, json=payload)
assert resp.status_code == 200
assert resp.status_code != 500
async def test_connection_error_marks_backend_unhealthy(client):
"""A connection-class failure mid-stream marks (endpoint, model) unhealthy."""
exc = openai.APIConnectionError(request=httpx.Request("POST", "http://x"))
mark = AsyncMock()
with _patches(exc, mark):
resp = await client.post("/api/chat", json=_ROUTES["/api/chat"])
assert resp.status_code == 200
err = _last_json_line(resp.text)
assert "error" in err
mark.assert_awaited_once()
# Called with the routed endpoint + model.
called_ep, called_model = mark.await_args.args[0], mark.await_args.args[1]
assert called_ep == TEST_OLLAMA
assert called_model == "fake"

View file

@ -1,116 +0,0 @@
"""Unit tests for context-window trimming logic."""
import pytest
import router
def _msgs(roles_contents):
return [{"role": r, "content": c} for r, c in roles_contents]
class TestCountMessageTokens:
def test_returns_int(self):
msgs = _msgs([("user", "hello")])
assert isinstance(router._count_message_tokens(msgs), int)
def test_empty_list(self):
assert router._count_message_tokens([]) >= 0
def test_longer_content_more_tokens(self):
short = _msgs([("user", "hi")])
long_ = _msgs([("user", "a " * 500)])
assert router._count_message_tokens(long_) > router._count_message_tokens(short)
def test_list_content(self):
msgs = [{"role": "user", "content": [
{"type": "text", "text": "what do you see?"},
]}]
tokens = router._count_message_tokens(msgs)
assert tokens > 0
def test_multiple_messages(self):
msgs = _msgs([("system", "you are helpful"), ("user", "hello"), ("assistant", "hi!")])
assert router._count_message_tokens(msgs) > 10
class TestTrimMessagesForContext:
def test_short_history_unchanged(self):
msgs = _msgs([("user", "hello"), ("assistant", "hi"), ("user", "bye")])
result = router._trim_messages_for_context(msgs, n_ctx=4096)
assert result == msgs
def test_system_messages_always_kept(self):
msgs = (
_msgs([("system", "you are helpful")])
+ _msgs([("user", f"msg {i}") for i in range(50)])
+ _msgs([("user", "final question")])
)
result = router._trim_messages_for_context(msgs, n_ctx=512)
system_msgs = [m for m in result if m["role"] == "system"]
assert len(system_msgs) == 1
assert system_msgs[0]["content"] == "you are helpful"
def test_last_user_message_always_kept(self):
msgs = _msgs([("user", f"old msg {i}") for i in range(100)] + [("user", "very important last question")])
result = router._trim_messages_for_context(msgs, n_ctx=256)
assert result[-1]["content"] == "very important last question"
def test_oldest_dropped_first(self):
msgs = _msgs([
("user", "oldest msg"),
("assistant", "oldest reply"),
("user", "newer msg"),
("assistant", "newer reply"),
("user", "newest"),
])
# Use very small target to force trimming
result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=10)
contents = [m["content"] for m in result]
# "oldest msg" should be dropped before "newest"
if "oldest msg" in contents:
assert "newest" in contents
else:
assert "newest" in contents
def test_result_starts_with_user(self):
msgs = _msgs([
("assistant", "leftover assistant"),
("user", "question"),
])
result = router._trim_messages_for_context(msgs, n_ctx=256, target_tokens=20)
if result:
assert result[0]["role"] == "user"
def test_target_tokens_overrides_safety_margin(self):
msgs = _msgs([("user", "a " * 200)])
result_small = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=10)
result_large = router._trim_messages_for_context(msgs, n_ctx=8192, target_tokens=5000)
# Both should return at least the last message
assert len(result_small) >= 1
assert len(result_large) >= 1
class TestCalibratedTrimTarget:
def test_returns_positive_int(self):
msgs = [{"role": "user", "content": "hello " * 100}]
result = router._calibrated_trim_target(msgs, n_ctx=4096, actual_tokens=3000)
assert isinstance(result, int)
assert result >= 1
def test_over_limit_reduces_target(self):
msgs = [{"role": "user", "content": "a " * 500}]
# actual_tokens > n_ctx means we need to shed more
target = router._calibrated_trim_target(msgs, n_ctx=2048, actual_tokens=2500)
assert target < router._count_message_tokens(msgs)
def test_well_within_limit_returns_current(self):
msgs = [{"role": "user", "content": "hi"}]
# actual_tokens << n_ctx means nothing to shed
target = router._calibrated_trim_target(msgs, n_ctx=16384, actual_tokens=50)
# Should return cur_tiktoken since to_shed == 0
assert target == max(1, router._count_message_tokens(msgs))
def test_minimum_is_one(self):
# Even if we need to shed everything, result is at least 1
msgs = [{"role": "user", "content": "hello"}]
target = router._calibrated_trim_target(msgs, n_ctx=100, actual_tokens=99999)
assert target >= 1

View file

@ -1,325 +0,0 @@
"""Unit tests for pure helper functions in router.py (no network, no app)."""
import time
import asyncio
from unittest.mock import MagicMock, patch
import aiohttp
import pytest
import router
class TestMaskSecrets:
def test_masks_openai_key(self):
text = "Authorization: Bearer sk-abcd1234XYZabcd1234XYZabcd1234XYZ"
result = router._mask_secrets(text)
assert "sk-***redacted***" in result
assert "sk-abcd1234" not in result
def test_masks_api_key_assignment(self):
result = router._mask_secrets("api_key: supersecretvalue123")
assert "supersecretvalue123" not in result
assert "***redacted***" in result
def test_masks_api_key_with_colon(self):
result = router._mask_secrets("api-key: mykey")
assert "mykey" not in result
def test_empty_string_returns_empty(self):
assert router._mask_secrets("") == ""
def test_none_returns_none(self):
assert router._mask_secrets(None) is None
def test_no_secrets_unchanged(self):
text = "this is a normal log line"
assert router._mask_secrets(text) == text
class TestIsFresh:
def test_fresh_within_ttl(self):
cached_at = time.time() - 10
assert router._is_fresh(cached_at, 300) is True
def test_expired_beyond_ttl(self):
cached_at = time.time() - 400
assert router._is_fresh(cached_at, 300) is False
def test_exactly_at_boundary(self):
cached_at = time.time() - 300
# May be True or False depending on timing, just verify it runs
result = router._is_fresh(cached_at, 300)
assert isinstance(result, bool)
def test_just_cached(self):
assert router._is_fresh(time.time(), 1) is True
class TestNormalizeLlamaModelName:
def test_strips_hf_prefix(self):
assert router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF") == "gpt-oss-20b-GGUF"
def test_strips_quant_suffix(self):
assert router._normalize_llama_model_name("model:Q8_0") == "model"
def test_strips_both(self):
result = router._normalize_llama_model_name("unsloth/gpt-oss-20b-GGUF:F16")
assert result == "gpt-oss-20b-GGUF"
def test_no_prefix_no_suffix(self):
assert router._normalize_llama_model_name("plain-model") == "plain-model"
def test_multiple_slashes(self):
result = router._normalize_llama_model_name("org/user/model-name:Q4_K_M")
assert result == "model-name"
class TestExtractLlamaQuant:
def test_extracts_quant(self):
assert router._extract_llama_quant("unsloth/model:Q8_0") == "Q8_0"
def test_no_quant_returns_empty(self):
assert router._extract_llama_quant("plain-model") == ""
def test_f16(self):
assert router._extract_llama_quant("model:F16") == "F16"
def test_q4_k_m(self):
assert router._extract_llama_quant("model:Q4_K_M") == "Q4_K_M"
class TestIsUnixSocketEndpoint:
def test_sock_endpoint_detected(self):
assert router._is_unix_socket_endpoint("http://192.168.0.52.sock/v1") is True
def test_regular_http_not_sock(self):
assert router._is_unix_socket_endpoint("http://192.168.0.52:8080/v1") is False
def test_ollama_not_sock(self):
assert router._is_unix_socket_endpoint("http://localhost:11434") is False
def test_dot_sock_in_host_detected(self):
assert router._is_unix_socket_endpoint("http://llama.sock/v1") is True
class TestGetSocketPath:
def test_returns_run_user_path(self):
import os
path = router._get_socket_path("http://192.168.0.52.sock/v1")
uid = os.getuid()
assert path == f"/run/user/{uid}/192.168.0.52.sock"
class TestIsBase64:
def test_valid_base64(self):
import base64
data = base64.b64encode(b"hello world").decode()
assert router.is_base64(data) is True
def test_invalid_base64(self):
assert router.is_base64("not-base64!@#$") is False
def test_empty_string(self):
# Empty string is valid base64 (decodes to empty bytes)
assert router.is_base64("") is True
def test_non_string(self):
# Non-strings fall through without returning True (returns None)
assert not router.is_base64(12345)
class TestIsLlamaModelLoaded:
def test_status_dict_loaded(self):
assert router._is_llama_model_loaded({"id": "m", "status": {"value": "loaded"}}) is True
def test_status_dict_unloaded(self):
assert router._is_llama_model_loaded({"id": "m", "status": {"value": "unloaded"}}) is False
def test_status_string_loaded(self):
assert router._is_llama_model_loaded({"id": "m", "status": "loaded"}) is True
def test_status_string_unloaded(self):
assert router._is_llama_model_loaded({"id": "m", "status": "unloaded"}) is False
def test_no_status_field_always_loaded(self):
# No status field → always available (single-model server)
assert router._is_llama_model_loaded({"id": "m"}) is True
def test_status_none_always_loaded(self):
assert router._is_llama_model_loaded({"id": "m", "status": None}) is True
class TestEp2Base:
def test_adds_v1_to_ollama(self):
assert router.ep2base("http://localhost:11434") == "http://localhost:11434/v1"
def test_keeps_v1_if_present(self):
assert router.ep2base("http://host/v1") == "http://host/v1"
def test_llama_server_endpoint_unchanged(self):
ep = "http://192.168.0.50:8889/v1"
assert router.ep2base(ep) == ep
class TestDedupeOnKeys:
def test_removes_duplicate_by_single_key(self):
items = [{"name": "a", "x": 1}, {"name": "b", "x": 2}, {"name": "a", "x": 3}]
result = router.dedupe_on_keys(items, ["name"])
assert len(result) == 2
assert result[0]["name"] == "a"
assert result[1]["name"] == "b"
def test_removes_duplicate_by_two_keys(self):
items = [
{"digest": "abc", "name": "m1"},
{"digest": "abc", "name": "m1"},
{"digest": "def", "name": "m2"},
]
result = router.dedupe_on_keys(items, ["digest", "name"])
assert len(result) == 2
def test_empty_list(self):
assert router.dedupe_on_keys([], ["name"]) == []
def test_no_duplicates(self):
items = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
assert len(router.dedupe_on_keys(items, ["name"])) == 3
class TestFormatConnectionIssue:
def test_connector_error_message(self):
err = aiohttp.ClientConnectorError(
connection_key=MagicMock(host="localhost", port=11434),
os_error=OSError(111, "Connection refused"),
)
msg = router._format_connection_issue("http://localhost:11434", err)
assert "localhost" in msg
assert "Connection refused" in msg or "111" in msg
def test_timeout_error_message(self):
msg = router._format_connection_issue("http://host:1234", asyncio.TimeoutError())
assert "Timed out" in msg
assert "host:1234" in msg
def test_generic_error(self):
msg = router._format_connection_issue("http://host:1234", ValueError("boom"))
assert "host:1234" in msg
assert "boom" in msg
class TestIsExtOpenaiEndpoint:
def test_openai_com_is_ext(self):
cfg = MagicMock()
cfg.endpoints = []
cfg.llama_server_endpoints = []
with patch.object(router, "config", cfg):
assert router.is_ext_openai_endpoint("https://api.openai.com/v1") is True
def test_ollama_default_port_not_ext(self):
cfg = MagicMock()
cfg.endpoints = ["http://host:11434"]
cfg.llama_server_endpoints = []
with patch.object(router, "config", cfg):
assert router.is_ext_openai_endpoint("http://host:11434") is False
def test_llama_server_not_ext(self):
cfg = MagicMock()
cfg.endpoints = []
cfg.llama_server_endpoints = ["http://host:8080/v1"]
with patch.object(router, "config", cfg):
assert router.is_ext_openai_endpoint("http://host:8080/v1") is False
def test_no_v1_not_ext(self):
cfg = MagicMock()
cfg.endpoints = ["http://host:11434"]
cfg.llama_server_endpoints = []
with patch.object(router, "config", cfg):
assert router.is_ext_openai_endpoint("http://host:11434") is False
class TestIsOpenaiCompatible:
def test_v1_endpoint_compatible(self):
cfg = MagicMock()
cfg.llama_server_endpoints = []
with patch.object(router, "config", cfg):
assert router.is_openai_compatible("http://host/v1") is True
def test_ollama_not_compatible(self):
cfg = MagicMock()
cfg.llama_server_endpoints = []
with patch.object(router, "config", cfg):
assert router.is_openai_compatible("http://localhost:11434") is False
def test_llama_server_in_list_compatible(self):
cfg = MagicMock()
cfg.llama_server_endpoints = ["http://host:8080"]
with patch.object(router, "config", cfg):
assert router.is_openai_compatible("http://host:8080") is True
class TestGetTrackingModel:
def test_ollama_adds_latest(self):
cfg = MagicMock()
cfg.llama_server_endpoints = []
with patch.object(router, "config", cfg):
assert router.get_tracking_model("http://ollama:11434", "llama3.2") == "llama3.2:latest"
def test_ollama_keeps_existing_tag(self):
cfg = MagicMock()
cfg.llama_server_endpoints = []
with patch.object(router, "config", cfg):
assert router.get_tracking_model("http://ollama:11434", "llama3.2:7b") == "llama3.2:7b"
def test_llama_server_normalizes(self):
ep = "http://host:8080/v1"
cfg = MagicMock()
cfg.llama_server_endpoints = [ep]
with patch.object(router, "config", cfg):
result = router.get_tracking_model(ep, "unsloth/model:Q8_0")
assert result == "model"
class TestLlamaSwapClassification:
def _cfg(self, *, server=None, swap=None):
cfg = MagicMock()
cfg.endpoints = []
cfg.llama_server_endpoints = server or []
cfg.llama_swap_endpoints = swap or []
return cfg
def test_is_llama_swap_only_for_swap_list(self):
from backends.normalize import is_llama_swap
swap_ep = "http://host:8890/v1"
server_ep = "http://host:8889/v1"
cfg = self._cfg(server=[server_ep], swap=[swap_ep])
with patch.object(router, "config", cfg):
assert is_llama_swap(swap_ep) is True
assert is_llama_swap(server_ep) is False
def test_is_llama_server_covers_both(self):
from backends.normalize import is_llama_server
swap_ep = "http://host:8890/v1"
server_ep = "http://host:8889/v1"
cfg = self._cfg(server=[server_ep], swap=[swap_ep])
with patch.object(router, "config", cfg):
assert is_llama_server(swap_ep) is True
assert is_llama_server(server_ep) is True
assert is_llama_server("http://host:11434") is False
def test_swap_is_openai_compatible_not_ext(self):
swap_ep = "http://host:8890/v1"
cfg = self._cfg(swap=[swap_ep])
with patch.object(router, "config", cfg):
assert router.is_openai_compatible(swap_ep) is True
assert router.is_ext_openai_endpoint(swap_ep) is False
def test_swap_tracking_model_normalized(self):
swap_ep = "http://host:8890/v1"
cfg = self._cfg(swap=[swap_ep])
with patch.object(router, "config", cfg):
assert router.get_tracking_model(swap_ep, "unsloth/model:Q8_0") == "model"
def test_llama_endpoints_dedupes_and_orders(self):
from backends.normalize import llama_endpoints
cfg = self._cfg(server=["a", "b"], swap=["b", "c"])
assert llama_endpoints(cfg) == ["a", "b", "c"]

View file

@ -1,173 +0,0 @@
"""Unit tests for router.rechunk — OpenAI ↔ Ollama chunk shape conversion."""
import time
from types import SimpleNamespace
import ollama
import router
def _ns(**kw):
return SimpleNamespace(**kw)
def _stream_chunk(content="hi", role="assistant", finish_reason=None,
usage=None, model="m"):
"""Build a SimpleNamespace mimicking a streaming OpenAI chunk."""
delta = _ns(content=content, role=role, reasoning=None, reasoning_content=None,
tool_calls=None)
choice = _ns(delta=delta, finish_reason=finish_reason, logprobs=None)
return _ns(model=model, choices=[choice], usage=usage)
def _nonstream_chunk(content="hi", role="assistant", finish_reason="stop",
usage=None, model="m", tool_calls=None):
"""Build a SimpleNamespace mimicking a non-streaming OpenAI ChatCompletion."""
message = _ns(content=content, role=role, reasoning=None, reasoning_content=None,
tool_calls=tool_calls)
choice = _ns(message=message, finish_reason=finish_reason, logprobs=None)
return _ns(model=model, choices=[choice], usage=usage)
# ──────────────────────────────────────────────────────────────────────────────
# openai_chat_completion2ollama
# ──────────────────────────────────────────────────────────────────────────────
class TestChatCompletionToOllama:
def test_streaming_content_chunk(self):
chunk = _stream_chunk(content="hello", finish_reason=None, usage=None)
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
assert isinstance(out, ollama.ChatResponse)
assert out.message.role == "assistant"
assert out.message.content == "hello"
assert out.done is False # usage is None → not done yet
assert out.model == "m"
def test_streaming_empty_content_defaults(self):
# Some chunks have content=None — should coerce to empty string
chunk = _stream_chunk(content=None, role=None)
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
assert out.message.role == "assistant" # role defaulted
assert out.message.content == ""
def test_final_usage_only_chunk_marks_done(self):
usage = _ns(prompt_tokens=10, completion_tokens=5, total_tokens=15)
chunk = _ns(model="m", choices=[], usage=usage)
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
assert out.done is True
assert out.done_reason == "stop"
assert out.prompt_eval_count == 10
assert out.eval_count == 5
assert out.message.content == ""
def test_nonstreaming_with_content(self):
usage = _ns(prompt_tokens=2, completion_tokens=3, total_tokens=5)
chunk = _nonstream_chunk(content="response text", finish_reason="stop", usage=usage)
out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
assert out.done is True
assert out.message.content == "response text"
assert out.prompt_eval_count == 2
assert out.eval_count == 3
def test_nonstreaming_tool_calls_converted(self):
"""Tool calls with JSON string arguments are parsed into dicts."""
tc = _ns(function=_ns(name="get_weather", arguments='{"city": "Paris"}'))
usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
chunk = _nonstream_chunk(
content="", finish_reason="tool_calls", usage=usage, tool_calls=[tc]
)
out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
assert out.message.tool_calls is not None
assert len(out.message.tool_calls) == 1
first = out.message.tool_calls[0]
assert first.function.name == "get_weather"
assert first.function.arguments == {"city": "Paris"}
def test_nonstreaming_tool_calls_with_invalid_json_fall_back_to_empty(self):
tc = _ns(function=_ns(name="f", arguments="not-json"))
usage = _ns(prompt_tokens=1, completion_tokens=1, total_tokens=2)
chunk = _nonstream_chunk(content="", usage=usage, tool_calls=[tc])
out = router.rechunk.openai_chat_completion2ollama(chunk, False, time.perf_counter())
assert out.message.tool_calls[0].function.arguments == {}
def test_streaming_tool_calls_in_delta_are_skipped(self):
"""Streaming mode must not assemble tool calls (caller handles it)."""
chunk = _stream_chunk(content="x", finish_reason=None)
# Even if a chunk somehow carried tool_calls in the delta, streaming
# mode should ignore them.
out = router.rechunk.openai_chat_completion2ollama(chunk, True, time.perf_counter())
assert out.message.tool_calls is None
# ──────────────────────────────────────────────────────────────────────────────
# openai_completion2ollama
# ──────────────────────────────────────────────────────────────────────────────
class TestCompletionToOllama:
def test_streaming_text_chunk(self):
choice = _ns(text="word", finish_reason=None, reasoning=None)
chunk = _ns(model="m", choices=[choice], usage=None)
out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter())
assert isinstance(out, ollama.GenerateResponse)
assert out.response == "word"
assert out.done is False
def test_final_chunk_with_usage(self):
usage = _ns(prompt_tokens=4, completion_tokens=6, total_tokens=10)
choice = _ns(text="end", finish_reason="stop", reasoning=None)
chunk = _ns(model="m", choices=[choice], usage=usage)
out = router.rechunk.openai_completion2ollama(chunk, True, time.perf_counter())
assert out.done is True
assert out.prompt_eval_count == 4
assert out.eval_count == 6
# ──────────────────────────────────────────────────────────────────────────────
# embeddings / embed
# ──────────────────────────────────────────────────────────────────────────────
class TestEmbeddingConversions:
def test_openai_embeddings2ollama(self):
chunk = _ns(data=[_ns(embedding=[0.1, 0.2, 0.3])])
out = router.rechunk.openai_embeddings2ollama(chunk)
assert isinstance(out, ollama.EmbeddingsResponse)
assert list(out.embedding) == [0.1, 0.2, 0.3]
def test_openai_embed2ollama(self):
chunk = _ns(data=[_ns(embedding=[0.5, 0.6])])
out = router.rechunk.openai_embed2ollama(chunk, "my-embed-model")
assert isinstance(out, ollama.EmbedResponse)
assert out.model == "my-embed-model"
assert list(out.embeddings[0]) == [0.5, 0.6]
# ──────────────────────────────────────────────────────────────────────────────
# extract_usage_from_llama_timings
# ──────────────────────────────────────────────────────────────────────────────
class TestExtractUsageFromLlamaTimings:
def test_none_when_no_timings_attr(self):
obj = _ns()
assert router.rechunk.extract_usage_from_llama_timings(obj) is None
def test_prompt_plus_cache_sums(self):
obj = _ns(timings={"prompt_n": 1, "cache_n": 236, "predicted_n": 35})
prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
assert prompt == 237
assert completion == 35
def test_missing_keys_default_to_zero(self):
obj = _ns(timings={"predicted_n": 12})
prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
assert prompt == 0
assert completion == 12
def test_null_values_treated_as_zero(self):
obj = _ns(timings={"prompt_n": None, "cache_n": None, "predicted_n": None})
prompt, completion = router.rechunk.extract_usage_from_llama_timings(obj)
assert prompt == 0
assert completion == 0
def test_non_dict_timings_returns_none(self):
obj = _ns(timings="not-a-dict")
assert router.rechunk.extract_usage_from_llama_timings(obj) is None

View file

@ -1,200 +0,0 @@
"""Unit tests for message transformation functions."""
from unittest.mock import MagicMock
import pytest
import router
class TestStripAssistantPrefill:
def test_removes_trailing_assistant(self):
msgs = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "prefill"},
]
result = router._strip_assistant_prefill(msgs)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_keeps_non_trailing_assistant(self):
msgs = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "response"},
{"role": "user", "content": "follow-up"},
]
result = router._strip_assistant_prefill(msgs)
assert len(result) == 3
def test_empty_list_unchanged(self):
assert router._strip_assistant_prefill([]) == []
def test_single_user_message_unchanged(self):
msgs = [{"role": "user", "content": "hi"}]
assert router._strip_assistant_prefill(msgs) == msgs
class TestTransformToolCallsToOpenAI:
def test_adds_type_function(self):
msgs = [{"role": "assistant", "tool_calls": [
{"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
]}]
result = router.transform_tool_calls_to_openai(msgs)
tc = result[0]["tool_calls"][0]
assert tc["type"] == "function"
def test_adds_id_when_missing(self):
msgs = [{"role": "assistant", "tool_calls": [
{"function": {"name": "fn", "arguments": {}}}
]}]
result = router.transform_tool_calls_to_openai(msgs)
assert "id" in result[0]["tool_calls"][0]
def test_converts_dict_arguments_to_string(self):
msgs = [{"role": "assistant", "tool_calls": [
{"function": {"name": "fn", "arguments": {"key": "val"}}}
]}]
result = router.transform_tool_calls_to_openai(msgs)
args = result[0]["tool_calls"][0]["function"]["arguments"]
assert isinstance(args, str)
import orjson
parsed = orjson.loads(args)
assert parsed == {"key": "val"}
def test_keeps_string_arguments_unchanged(self):
msgs = [{"role": "assistant", "tool_calls": [
{"function": {"name": "fn", "arguments": '{"key": "val"}'}}
]}]
result = router.transform_tool_calls_to_openai(msgs)
args = result[0]["tool_calls"][0]["function"]["arguments"]
assert args == '{"key": "val"}'
def test_links_tool_call_id_to_tool_response(self):
msgs = [
{"role": "assistant", "tool_calls": [
{"function": {"name": "get_weather", "arguments": {}}}
]},
{"role": "tool", "name": "get_weather", "content": "sunny"},
]
result = router.transform_tool_calls_to_openai(msgs)
tc_id = result[0]["tool_calls"][0]["id"]
assert result[1].get("tool_call_id") == tc_id
def test_non_tool_messages_unchanged(self):
msgs = [{"role": "user", "content": "hello"}]
result = router.transform_tool_calls_to_openai(msgs)
assert result == msgs
class TestStripImagesFromMessages:
def test_removes_image_url_parts(self):
msgs = [{"role": "user", "content": [
{"type": "text", "text": "what is this?"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
]}]
result = router._strip_images_from_messages(msgs)
content = result[0]["content"]
assert content == "what is this?"
def test_keeps_text_only_messages(self):
msgs = [{"role": "user", "content": "plain text"}]
result = router._strip_images_from_messages(msgs)
assert result[0]["content"] == "plain text"
def test_multiple_text_parts_kept_as_list(self):
msgs = [{"role": "user", "content": [
{"type": "text", "text": "part one"},
{"type": "text", "text": "part two"},
{"type": "image_url", "image_url": {"url": "data:..."}},
]}]
result = router._strip_images_from_messages(msgs)
content = result[0]["content"]
assert isinstance(content, list)
assert len(content) == 2
def test_all_images_removed_empty_list(self):
msgs = [{"role": "user", "content": [
{"type": "image_url", "image_url": {"url": "data:..."}},
]}]
result = router._strip_images_from_messages(msgs)
# Image-only content becomes empty list
content = result[0]["content"]
assert content == []
class TestAccumulateOpenAITcDelta:
def _make_chunk(self, index, name=None, args_fragment="", tc_id=None):
delta = MagicMock()
tc = MagicMock()
tc.index = index
tc.id = tc_id
tc.function = MagicMock()
tc.function.name = name
tc.function.arguments = args_fragment
delta.tool_calls = [tc]
chunk = MagicMock()
chunk.choices = [MagicMock(delta=delta)]
return chunk
def test_first_delta_creates_entry(self):
acc = {}
chunk = self._make_chunk(0, name="my_fn", args_fragment='{"k"')
router._accumulate_openai_tc_delta(chunk, acc)
assert 0 in acc
assert acc[0]["name"] == "my_fn"
assert acc[0]["arguments"] == '{"k"'
def test_subsequent_deltas_concatenate_args(self):
acc = {}
router._accumulate_openai_tc_delta(self._make_chunk(0, name="fn", args_fragment='{"k"'), acc)
router._accumulate_openai_tc_delta(self._make_chunk(0, args_fragment=': "v"}'), acc)
assert acc[0]["arguments"] == '{"k": "v"}'
def test_multiple_tool_calls_tracked_separately(self):
acc = {}
c1 = self._make_chunk(0, name="fn1", args_fragment="{}")
c2 = self._make_chunk(1, name="fn2", args_fragment="{}")
chunk = MagicMock()
tc1 = MagicMock()
tc1.index = 0
tc1.id = "id1"
tc1.function = MagicMock(name="fn1", arguments="{}")
tc2 = MagicMock()
tc2.index = 1
tc2.id = "id2"
tc2.function = MagicMock(name="fn2", arguments="{}")
chunk.choices = [MagicMock(delta=MagicMock(tool_calls=[tc1, tc2]))]
router._accumulate_openai_tc_delta(chunk, acc)
assert 0 in acc and 1 in acc
def test_no_choices_is_noop(self):
acc = {}
chunk = MagicMock(choices=[])
router._accumulate_openai_tc_delta(chunk, acc)
assert acc == {}
class TestBuildOllamaToolCalls:
def test_builds_from_accumulator(self):
acc = {0: {"id": "call_abc", "name": "get_weather", "arguments": '{"city": "Berlin"}'}}
result = router._build_ollama_tool_calls(acc)
assert result is not None
assert len(result) == 1
assert result[0].function.name == "get_weather"
assert result[0].function.arguments == {"city": "Berlin"}
def test_invalid_json_args_becomes_empty_dict(self):
acc = {0: {"id": "c1", "name": "fn", "arguments": "not-json"}}
result = router._build_ollama_tool_calls(acc)
assert result[0].function.arguments == {}
def test_empty_accumulator_returns_none(self):
assert router._build_ollama_tool_calls({}) is None
def test_preserves_order_by_index(self):
acc = {
1: {"id": "c2", "name": "fn2", "arguments": "{}"},
0: {"id": "c1", "name": "fn1", "arguments": "{}"},
}
result = router._build_ollama_tool_calls(acc)
assert result[0].function.name == "fn1"
assert result[1].function.name == "fn2"

177
tokens.py
View file

@ -1,177 +0,0 @@
"""Token-count write-behind pipeline.
``token_worker`` drains ``token_queue`` into the in-memory buffer (and into
``token_usage_counts`` for immediate SSE reporting). ``flush_buffer``
periodically persists the buffer to SQLite via ``TokenDatabase``.
``flush_remaining_buffers`` is invoked on shutdown to drain whatever is left.
The lock order is ``buffer_lock`` then ``token_usage_lock`` see
choose_endpoint for why we never combine them with usage_lock.
"""
import asyncio
from datetime import datetime, timezone
from state import (
token_queue,
token_buffer,
time_series_buffer,
buffer_lock,
token_usage_counts,
token_usage_lock,
FLUSH_INTERVAL,
)
from sse import _capture_snapshot, _distribute_snapshot
from db import get_db
async def token_worker() -> None:
try:
while True:
endpoint, model, prompt, comp = await token_queue.get()
# Calculate timestamp once before acquiring lock
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
# Accumulate counts in memory buffer (protected by lock)
async with buffer_lock:
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
# Add to time series buffer with timestamp (UTC)
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
'input_tokens': prompt,
'output_tokens': comp,
'total_tokens': prompt + comp,
'timestamp': timestamp
})
# Update in-memory counts for immediate reporting
async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.CancelledError:
# Gracefully handle task cancellation during shutdown
print("[token_worker] Task cancelled, processing remaining queue items...")
# Process any remaining items in the queue before exiting
while not token_queue.empty():
try:
endpoint, model, prompt, comp = token_queue.get_nowait()
# Calculate timestamp once before acquiring lock
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
async with buffer_lock:
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
'input_tokens': prompt,
'output_tokens': comp,
'total_tokens': prompt + comp,
'timestamp': timestamp
})
async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.QueueEmpty:
break
print("[token_worker] Task cancelled, remaining items processed.")
raise
async def flush_buffer() -> None:
"""Periodically flush accumulated token counts to the database."""
try:
while True:
await asyncio.sleep(FLUSH_INTERVAL)
# Flush token counts and time series (protected by lock)
async with buffer_lock:
if token_buffer:
# Copy buffer before releasing lock for DB operation
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
token_buffer.clear()
else:
buffer_copy = None
if time_series_buffer:
ts_copy = list(time_series_buffer)
time_series_buffer.clear()
else:
ts_copy = None
# Perform DB operations outside the lock to avoid blocking
db = get_db()
if buffer_copy:
await db.update_batched_counts(buffer_copy)
if ts_copy:
await db.add_batched_time_series(ts_copy)
except asyncio.CancelledError:
# Gracefully handle task cancellation during shutdown
print("[flush_buffer] Task cancelled, flushing remaining buffers...")
# Flush any remaining data before exiting
try:
async with buffer_lock:
if token_buffer:
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
token_buffer.clear()
else:
buffer_copy = None
if time_series_buffer:
ts_copy = list(time_series_buffer)
time_series_buffer.clear()
else:
ts_copy = None
db = get_db()
if buffer_copy:
await db.update_batched_counts(buffer_copy)
if ts_copy:
await db.add_batched_time_series(ts_copy)
print("[flush_buffer] Task cancelled, remaining buffers flushed.")
except Exception as e:
print(f"[flush_buffer] Error during shutdown flush: {e}")
raise
async def flush_remaining_buffers() -> None:
"""
Flush any in-memory buffers to the database on shutdown.
This is designed to be safely invoked during shutdown and should not raise.
"""
try:
flushed_entries = 0
async with buffer_lock:
if token_buffer:
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
flushed_entries += sum(len(v) for v in token_buffer.values())
token_buffer.clear()
else:
buffer_copy = None
if time_series_buffer:
ts_copy = list(time_series_buffer)
flushed_entries += len(time_series_buffer)
time_series_buffer.clear()
else:
ts_copy = None
# Perform DB operations outside the lock
db = get_db()
if buffer_copy:
await db.update_batched_counts(buffer_copy)
if ts_copy:
await db.add_batched_time_series(ts_copy)
if flushed_entries:
print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
else:
print("[shutdown] No in-memory entries to flush on shutdown.")
except Exception as e:
# Do not raise during shutdown log and continue teardown
print(f"[shutdown] Error flushing remaining buffers: {e}")

File diff suppressed because it is too large Load diff