Compare commits

...
Sign in to create a new pull request.

19 commits

Author SHA1 Message Date
d163fea154
fix: remove aioresponses
sec: bumb aiohttp 3.14

fix: tiktoken test issue by pre-cache the vocab file
2026-06-07 13:23:35 +02:00
3cd530586c
feat: cache backend clients per endpoint instead of building one (with a fresh SSL context) per request
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 3m59s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m25s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 12m46s
Build and Publish Docker Image / merge (push) Successful in 33s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 19m56s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 33s
2026-06-07 09:55:54 +02:00
1ce792c48b
feat: new load test added 2026-06-07 09:38:14 +02:00
75d204e7f3
feat: use SSE reconnect to prevent API Key modal to pop up in dashboard if no API Key is configured 2026-06-07 09:29:06 +02:00
497c87b02e
refac: code deduplication for error handling and call sites
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 4m2s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m37s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 17m43s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 34s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 12m47s
Build and Publish Docker Image / merge (push) Successful in 33s
2026-06-04 10:57:33 +02:00
2dceece0d6
feat: add test for ollama stream errors 2026-06-04 10:42:18 +02:00
d3b2ee3047
feat: surface an upstream ollama backend error transitively from a streaming generator 2026-06-04 10:33:47 +02:00
b754daf1af
feat: after closing the probe session, reset
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 3m52s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m23s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 15m16s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 34s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 11m59s
Build and Publish Docker Image / merge (push) Successful in 33s
2026-05-28 10:16:54 +02:00
820e217da6
fix: Lightweight health/introspection probes no longer compete with long-lived streaming completions for the proxy pool's per-host connection slots 2026-05-28 09:54:53 +02:00
13d796817f
feat: add authorization header to llama model endpoint fetch 2026-05-28 09:32:20 +02:00
4b5a70e787
refac: modularize apis VII 2026-05-19 14:57:39 +02:00
e74f5d1ba6
refac: request handling VI 2026-05-19 14:09:52 +02:00
8355bf9a1e
refac: modularize sse, routing, db and token handling V 2026-05-19 12:48:55 +02:00
3a9854c5db
refac: modularize backend IV 2026-05-19 12:05:51 +02:00
c88ba1e5a4
refac: modularize global states III 2026-05-19 11:18:06 +02:00
d2b31b6c7b
refac: modularize config II 2026-05-19 11:00:50 +02:00
90b6868f5a
refac: split into modules I 2026-05-19 10:05:27 +02:00
078855ba9a Merge pull request 'feat: completion errors on an endpoint:model key a caught, cached and rerouted (openai compatible endpoints)' (#87) from dev-0.9.x-completion-error-cache into dev-0.9.x
Reviewed-on: https://bitfreedom.net/code/code/nomyo-ai/nomyo-router/pulls/87
2026-05-19 07:40:40 +02:00
079b677e23
feat: completion errors on an endpoint:model key a caught, cached and rerouted (openai compatible endpoints)
All checks were successful
PR Tests / test (pull_request) Successful in 57s
2026-05-18 18:14:28 +02:00
34 changed files with 106321 additions and 4055 deletions

0
api/__init__.py Normal file
View file

278
api/management.py Normal file
View file

@ -0,0 +1,278 @@
"""Management / observability routes.
Read-only endpoints used by the dashboard and external monitoring:
* usage counters and token-counts breakdown,
* conversation-affinity introspection,
* endpoint health summary,
* LLM-response cache stats and invalidation,
* SSE live-stream of usage updates,
* hostname and ``/health`` probe.
"""
import asyncio
import socket
import time
from typing import Optional
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache
from config import get_config
from db import get_db
from state import (
usage_counts,
token_usage_counts,
_affinity_map,
_affinity_lock,
)
from sse import subscribe, unsubscribe
from backends.normalize import _normalize_llama_model_name
from backends.probe import _endpoint_health
router = APIRouter()
@router.get("/api/token_counts")
async def token_counts_proxy():
breakdown = []
total = 0
async for entry in get_db().load_token_counts():
total += entry['total_tokens']
breakdown.append({
"endpoint": entry["endpoint"],
"model": entry["model"],
"input_tokens": entry["input_tokens"],
"output_tokens": entry["output_tokens"],
"total_tokens": entry["total_tokens"],
})
return {"total_tokens": total, "breakdown": breakdown}
@router.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
"""
Aggregate time_series entries older than days into daily aggregates by endpoint/model/date.
"""
try:
body_bytes = await request.body()
if not body_bytes:
days = 30
trim_old = False
else:
payload = orjson.loads(body_bytes.decode("utf-8"))
days = int(payload.get("days", 30))
trim_old = bool(payload.get("trim_old", False))
except Exception:
days = 30
trim_old = False
aggregated = await get_db().aggregate_time_series_older_than(days, trim_old=trim_old)
return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
@router.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
"""
Return token usage statistics for a specific model.
"""
try:
body_bytes = await request.body()
if not model:
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
db = get_db()
token_data = await db.get_token_counts_for_model(model)
if not token_data:
raise HTTPException(
status_code=404, detail="No token data found for this model"
)
time_series = [
entry async for entry in db.get_time_series_for_model(model)
]
endpoint_distribution = await db.get_endpoint_distribution_for_model(model)
return {
'model': model,
'input_tokens': token_data['input_tokens'],
'output_tokens': token_data['output_tokens'],
'total_tokens': token_data['total_tokens'],
'time_series': time_series,
'endpoint_distribution': endpoint_distribution,
}
@router.get("/api/affinity_stats")
async def affinity_stats(request: Request):
"""
Aggregate live conversation-affinity pins, one entry per pinned conversation.
Each entry exposes only the endpoint, model, and remaining TTL in seconds
no fingerprints or content. When conversation_affinity is disabled the
`entries` list is always empty.
"""
config = get_config()
if not config.conversation_affinity:
return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []}
now = time.monotonic()
entries: list[dict] = []
llama_eps = set(config.llama_server_endpoints)
async with _affinity_lock:
for fp, (ep, mdl, expires_at) in list(_affinity_map.items()):
remaining = expires_at - now
if remaining <= 0:
_affinity_map.pop(fp, None)
continue
# Mirror the normalisation used by /api/ps_details so the dashboard
# can join affinity entries to PS rows by (endpoint, model).
display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl
entries.append({
"endpoint": ep,
"model": display_model,
"remaining": round(remaining, 2),
})
return {
"enabled": True,
"ttl": config.conversation_affinity_ttl,
"entries": entries,
}
@router.get("/api/usage")
async def usage_proxy(request: Request):
"""
Return a snapshot of the usage counter for each endpoint.
Useful for debugging / monitoring.
"""
return {"usage_counts": usage_counts,
"token_usage_counts": token_usage_counts}
@router.get("/api/config")
async def config_proxy(request: Request):
"""
Return a simple JSON object that contains the configured
Ollama endpoints and llama_server_endpoints. The frontend uses this
to display which endpoints are being proxied and their health.
Status is "error" when either liveness (/api/version) or routing
health (/api/ps) fails see issue #83.
"""
config = get_config()
async def check(url: str) -> dict:
return {"url": url, **(await _endpoint_health(url, timeout=5))}
ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
llama_results = []
if config.llama_server_endpoints:
llama_results = await asyncio.gather(
*[check(ep) for ep in config.llama_server_endpoints]
)
return {
"endpoints": ollama_results,
"llama_server_endpoints": llama_results,
"require_router_api_key": bool(config.router_api_key),
}
@router.get("/api/cache/stats")
async def cache_stats():
"""Return hit/miss counters and configuration for the LLM response cache."""
c = get_llm_cache()
if c is None:
return {"enabled": False}
return {"enabled": True, **c.stats()}
@router.post("/api/cache/invalidate")
async def cache_invalidate():
"""Clear all entries from the LLM response cache and reset counters."""
c = get_llm_cache()
if c is None:
return {"enabled": False, "cleared": False}
await c.clear()
return {"enabled": True, "cleared": True}
@router.get("/health")
async def health_proxy(request: Request):
"""
Healthcheck endpoint for monitoring the proxy.
* Queries each configured endpoint for both liveness and routing health:
Ollama endpoints are probed at `/api/version` AND `/api/ps`,
OpenAI-compatible endpoints at `/models`.
* Returns a JSON object containing:
- `status`: "ok" if every endpoint replied to every probe, otherwise "error".
- `endpoints`: a mapping of endpoint URL `{status, version|detail}`.
* The HTTP status code is 200 when everything is healthy, 503 otherwise.
"""
config = get_config()
# Run all health checks in parallel.
# Ollama endpoints expose /api/version (liveness) and /api/ps (routing
# health — required by `choose_endpoint`). OpenAI-compatible endpoints
# (vLLM, llama-server, external) expose /models, which serves both
# purposes. Probing /api/version alone would miss the case where the
# Ollama process is up but /api/ps is failing — see issue #83.
all_endpoints = list(config.endpoints)
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
all_endpoints += llama_eps_extra
probe_results = await asyncio.gather(
*(_endpoint_health(ep) for ep in all_endpoints),
)
health_summary = dict(zip(all_endpoints, probe_results))
overall_ok = all(entry.get("status") == "ok" for entry in probe_results)
response_payload = {
"status": "ok" if overall_ok else "error",
"endpoints": health_summary,
}
http_status = 200 if overall_ok else 503
return JSONResponse(content=response_payload, status_code=http_status)
@router.get("/api/hostname")
async def get_hostname():
"""Return the hostname of the machine running the router."""
return JSONResponse(content={"hostname": socket.gethostname()})
@router.get("/api/usage-stream")
async def usage_stream(request: Request):
"""
ServerSentEvents that emits a JSON payload every time the
global `usage_counts` dictionary changes.
"""
async def event_generator():
# The queue that receives *every* new snapshot
queue = await subscribe()
try:
while True:
# If the client disconnects, cancel the loop
if await request.is_disconnected():
break
data = await queue.get()
if data is None:
break
# Send the data as a single SSE message
yield f"data: {data}\n\n"
finally:
# Cleanup: unsubscribe from the broadcast channel
await unsubscribe(queue)
return StreamingResponse(event_generator(), media_type="text/event-stream")

1134
api/ollama.py Normal file

File diff suppressed because it is too large Load diff

804
api/openai.py Normal file
View file

@ -0,0 +1,804 @@
"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).
The chat-completions and completions handlers carry the full reactive-trim
logic for ``exceed_context_size_error`` plus connection-failure rerouting
(``_mark_backend_unhealthy``). The streaming branches assemble cached
responses on the fly so caching works for both streaming and non-streaming
clients.
"""
import asyncio
import base64
import math
import aiohttp
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache, openai_nonstream_to_sse
from config import get_config
from context_window import (
_count_message_tokens,
_trim_messages_for_context,
_calibrated_trim_target,
_endpoint_nctx,
_CTX_TRIM_SMALL_LIMIT,
)
from fingerprint import _conversation_fingerprint
from security import _mask_secrets
from state import token_queue, app_state, default_headers
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
from backends.normalize import (
dedupe_on_keys,
ep2base,
is_ext_openai_endpoint,
is_openai_compatible,
_normalize_llama_model_name,
)
from backends.probe import fetch
from backends.sessions import _make_openai_client, get_session
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
from requests.rechunk import rechunk
from routing import choose_endpoint, decrement_usage
router = APIRouter()
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
"""
Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.
"""
config = get_config()
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
doc = payload.get("input")
# Normalize multimodal input: extract only text parts for embedding models
if isinstance(doc, list):
normalized = []
for item in doc:
if isinstance(item, dict):
# Multimodal content part - extract text only, skip images
if item.get("type") == "text":
normalized.append(item.get("text", ""))
# Skip image_url and other non-text types
else:
normalized.append(item)
doc = normalized if len(normalized) != 1 else normalized[0]
elif isinstance(doc, dict) and doc.get("type") == "text":
doc = doc.get("text", "")
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not doc:
raise HTTPException(
status_code=400, detail="Missing required field 'input'"
)
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# 2. Endpoint logic
endpoint, tracking_model = await choose_endpoint(model)
if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key")
else:
api_key = "ollama"
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)
try:
async_gen = await oclient.embeddings.create(input=doc, model=model)
result = async_gen.model_dump()
for item in result.get("data", []):
emb = item.get("embedding")
if emb:
item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb]
return JSONResponse(content=result)
finally:
await decrement_usage(endpoint, tracking_model)
@router.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
"""
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
"""
config = get_config()
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
messages = payload.get("messages")
frequency_penalty = payload.get("frequency_penalty")
presence_penalty = payload.get("presence_penalty")
response_format = payload.get("response_format")
seed = payload.get("seed")
stop = payload.get("stop")
stream = payload.get("stream")
stream_options = payload.get("stream_options")
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
tools = payload.get("tools")
logprobs = payload.get("logprobs")
top_logprobs = payload.get("top_logprobs")
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not isinstance(messages, list):
raise HTTPException(
status_code=400, detail="Missing required field 'messages' (must be a list)"
)
if ":latest" in model:
model = model.split(":latest")
model = model[0]
messages = _strip_assistant_prefill(messages)
params = {
"messages": messages,
"model": model,
}
optional_params = {
"tools": tools,
"response_format": response_format,
"stream_options": stream_options or {"include_usage": True },
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"seed": seed,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"stop": stop,
"stream": stream,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
}
params.update({k: v for k, v in optional_params.items() if v is not None})
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# Reject unsupported image formats (SVG) before doing any work
for _msg in messages:
for _item in (_msg.get("content") or []) if isinstance(_msg.get("content"), list) else []:
if _item.get("type") == "image_url":
_url = (_item.get("image_url") or {}).get("url", "")
if _url.startswith("data:image/svg") or _url.lower().endswith(".svg"):
raise HTTPException(
status_code=400,
detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
)
# Cache lookup — before endpoint selection
_cache = get_llm_cache()
if _cache is not None and _cache_enabled:
_cached = await _cache.get_chat("openai_chat", model, messages)
if _cached is not None:
if stream:
_sse = openai_nonstream_to_sse(_cached, model)
async def _serve_cached_ochat_stream():
yield _sse
return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")
else:
async def _serve_cached_ochat_json():
yield _cached
return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")
# 2. Endpoint logic
_affinity_key = _conversation_fingerprint(model, messages, None)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Helpers and API call — done in handler scope so try/except works reliably
async def _normalize_images_in_messages(msgs: list) -> list:
"""Fetch remote image URLs and convert them to base64 data URLs so
Ollama/llama-server can handle them without making outbound HTTP requests."""
resolved = []
for msg in msgs:
content = msg.get("content")
if not isinstance(content, list):
resolved.append(msg)
continue
new_content = []
for item in content:
if item.get("type") == "image_url":
url = (item.get("image_url") or {}).get("url", "")
if url and not url.startswith("data:"):
try:
http: aiohttp.ClientSession = app_state["session"]
async with http.get(url) as resp:
ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
img_bytes = await resp.read()
b64 = base64.b64encode(img_bytes).decode("utf-8")
new_content.append({
"type": "image_url",
"image_url": {"url": f"data:{ctype};base64,{b64}"}
})
except Exception as _ie:
print(f"[image] Failed to fetch image URL: {_ie}")
new_content.append(item)
else:
new_content.append(item)
else:
new_content.append(item)
resolved.append({**msg, "content": new_content})
return resolved
# Make the API call in handler scope — try/except inside async generators is unreliable
# with Starlette's streaming machinery, so we resolve errors here before the generator starts.
send_params = params
if not is_ext_openai_endpoint(endpoint):
resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
send_params = {**params, "messages": resolved_msgs}
# Proactive trim: only for small-ctx models we've already seen run out of space
_lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
_known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
_pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
_pre_est = _count_message_tokens(send_params.get("messages", []))
if _pre_est > _pre_target:
_pre_msgs = send_params.get("messages", [])
_pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
_dropped = len(_pre_msgs) - len(_pre_trimmed)
print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
send_params = {**send_params, "messages": _pre_trimmed}
try:
async_gen = await oclient.chat.completions.create(**send_params)
except Exception as e:
_e_str = str(e)
_is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str
print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
if "does not support tools" in _e_str:
# Model doesn't support tools — retry without them
print(f"[ochat] retry: no tools", flush=True)
try:
params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
async_gen = await oclient.chat.completions.create(**params_without_tools)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_ctx_err:
# Backend context limit hit — apply sliding-window trim (context-shift at message level)
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0)
actual_tokens = err_detail.get("n_prompt_tokens", 0)
# Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors)
if not n_ctx_limit:
import re as _re
_m = _re.search(r"'n_ctx':\s*(\d+)", _e_str)
if _m:
n_ctx_limit = int(_m.group(1))
_m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
if _m:
actual_tokens = int(_m.group(1))
print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
if not n_ctx_limit:
await decrement_usage(endpoint, tracking_model)
raise
if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = n_ctx_limit
msgs_to_trim = send_params.get("messages", [])
try:
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
except Exception as _helper_exc:
print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
await decrement_usage(endpoint, tracking_model)
raise
dropped = len(msgs_to_trim) - len(trimmed_messages)
print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages})
print(f"[ctx-trim] retry-1 ok", flush=True)
except Exception as e2:
_e2_str = str(e2)
if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
# Still too large — tool definitions likely consuming too many tokens, strip them too
print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
try:
async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages})
print(f"[ctx-trim] retry-2 ok", flush=True)
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
elif _is_backend_connection_error(e):
# Upstream connection failed (e.g. llama-server in router mode
# whose delegated worker died). Mark (endpoint, model) so the
# next request reroutes; the client will retry this one.
print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
await _mark_backend_unhealthy(endpoint, model, _e_str)
await decrement_usage(endpoint, tracking_model)
raise
elif "image input is not supported" in _e_str:
# Model doesn't support images — strip and retry
print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
try:
async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
except Exception:
await decrement_usage(endpoint, tracking_model)
raise
else:
await decrement_usage(endpoint, tracking_model)
raise
# 4. Async generator — only streams the already-established async_gen
async def stream_ochat_response():
try:
if stream == True:
content_parts: list[str] = []
usage_snapshot: dict = {}
async for chunk in async_gen:
data = (
chunk.model_dump_json()
if hasattr(chunk, "model_dump_json")
else orjson.dumps(chunk)
)
if chunk.choices:
delta = chunk.choices[0].delta
has_content = delta.content is not None
has_reasoning = (
getattr(delta, "reasoning_content", None) is not None
or getattr(delta, "reasoning", None) is not None
)
has_tool_calls = getattr(delta, "tool_calls", None) is not None
if has_content or has_reasoning or has_tool_calls:
yield f"data: {data}\n\n".encode("utf-8")
if has_content and delta.content:
content_parts.append(delta.content)
elif chunk.usage is not None:
# Forward the usage-only final chunk (e.g. from llama-server)
yield f"data: {data}\n\n".encode("utf-8")
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Detect context exhaustion mid-generation for small-ctx models.
# Guard: skip if max_tokens was set in the request — finish_reason=length
# could just mean the caller's token budget was exhausted, not the context window.
_req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
_inferred_nctx = (prompt_tok + comp_tok) or 0
if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
_endpoint_nctx[(endpoint, model)] = _inferred_nctx
print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
# Cache assembled streaming response — before [DONE] so it always runs
if _cache is not None and _cache_enabled and content_parts:
assembled = orjson.dumps({
"model": model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
**({"usage": usage_snapshot} if usage_snapshot else {}),
}) + b"\n"
try:
await _cache.set_chat("openai_chat", model, messages, assembled)
except Exception as _ce:
print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
yield b"data: [DONE]\n\n"
else:
prompt_tok = 0
comp_tok = 0
if async_gen.usage is not None:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
else orjson.dumps(async_gen)
)
cache_bytes = json_line.encode("utf-8") + b"\n"
yield cache_bytes
# Cache non-streaming response
if _cache is not None and _cache_enabled:
try:
await _cache.set_chat("openai_chat", model, messages, cache_bytes)
except Exception as _ce:
print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
stream_ochat_response(),
media_type="text/event-stream" if stream else "application/json",
)
@router.post("/v1/completions")
async def openai_completions_proxy(request: Request):
"""
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
"""
config = get_config()
# 1. Parse and validate request
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
prompt = payload.get("prompt")
frequency_penalty = payload.get("frequency_penalty")
presence_penalty = payload.get("presence_penalty")
seed = payload.get("seed")
stop = payload.get("stop")
stream = payload.get("stream")
stream_options = payload.get("stream_options")
temperature = payload.get("temperature")
top_p = payload.get("top_p")
max_tokens = payload.get("max_tokens")
max_completion_tokens = payload.get("max_completion_tokens")
suffix = payload.get("suffix")
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
if not model:
raise HTTPException(
status_code=400, detail="Missing required field 'model'"
)
if not prompt:
raise HTTPException(
status_code=400, detail="Missing required field 'prompt'"
)
if ":latest" in model:
model = model.split(":latest")
model = model[0]
params = {
"prompt": prompt,
"model": model,
}
optional_params = {
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"seed": seed,
"stop": stop,
"stream": stream,
"stream_options": stream_options or {"include_usage": True },
"temperature": temperature,
"top_p": top_p,
"max_tokens": max_tokens,
"max_completion_tokens": max_completion_tokens,
"suffix": suffix
}
params.update({k: v for k, v in optional_params.items() if v is not None})
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# Cache lookup — completions prompt mapped to a single-turn messages list
_cache = get_llm_cache()
_compl_messages = [{"role": "user", "content": prompt}]
if _cache is not None and _cache_enabled:
_cached = await _cache.get_chat("openai_completions", model, _compl_messages)
if _cached is not None:
if stream:
_sse = openai_nonstream_to_sse(_cached, model)
async def _serve_cached_ocompl_stream():
yield _sse
return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")
else:
async def _serve_cached_ocompl_json():
yield _cached
return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")
# 2. Endpoint logic
_affinity_key = _conversation_fingerprint(model, None, prompt)
endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
# Make the API call in handler scope (try/except inside async generators is unreliable)
try:
async_gen = await oclient.completions.create(**params)
except Exception as e:
if _is_backend_connection_error(e):
print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
await _mark_backend_unhealthy(endpoint, model, str(e))
await decrement_usage(endpoint, tracking_model)
raise
async def stream_ocompletions_response(model=model):
try:
if stream == True:
text_parts: list[str] = []
usage_snapshot: dict = {}
async for chunk in async_gen:
data = (
chunk.model_dump_json()
if hasattr(chunk, "model_dump_json")
else orjson.dumps(chunk)
)
if chunk.choices:
choice = chunk.choices[0]
has_text = getattr(choice, "text", None) is not None
has_reasoning = (
getattr(choice, "reasoning_content", None) is not None
or getattr(choice, "reasoning", None) is not None
)
if has_text or has_reasoning or choice.finish_reason is not None:
yield f"data: {data}\n\n".encode("utf-8")
if has_text and choice.text:
text_parts.append(choice.text)
elif chunk.usage is not None:
# Forward the usage-only final chunk (e.g. from llama-server)
yield f"data: {data}\n\n".encode("utf-8")
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Cache assembled streaming response — before [DONE] so it always runs
if _cache is not None and _cache_enabled and text_parts:
assembled = orjson.dumps({
"model": model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
**({"usage": usage_snapshot} if usage_snapshot else {}),
}) + b"\n"
try:
await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
except Exception as _ce:
print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
# Final DONE event
yield b"data: [DONE]\n\n"
else:
prompt_tok = 0
comp_tok = 0
if async_gen.usage is not None:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
else orjson.dumps(async_gen)
)
cache_bytes = json_line.encode("utf-8") + b"\n"
yield cache_bytes
# Cache non-streaming response
if _cache is not None and _cache_enabled:
try:
await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
except Exception as _ce:
print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
stream_ocompletions_response(),
media_type="text/event-stream" if stream else "application/json",
)
@router.get("/v1/models")
async def openai_models_proxy(request: Request):
"""
Proxy an OpenAI API models request to Ollama and llama-server endpoints and reply with a unique list of models.
For Ollama endpoints: queries /api/tags (all models)
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
"""
config = get_config()
# 1. Query Ollama endpoints for all models via /api/tags
ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
# 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models
ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)]
# 3. Query llama-server endpoints for loaded models via /v1/models
# Also query endpoints from llama_server_endpoints that may not be in config.endpoints
all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints)
llama_tasks = [
fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
for ep in all_llama_endpoints
]
ollama_models = await asyncio.gather(*ollama_tasks) if ollama_tasks else []
ext_openai_models = await asyncio.gather(*ext_openai_tasks) if ext_openai_tasks else []
llama_models = await asyncio.gather(*llama_tasks) if llama_tasks else []
models = {'data': []}
# Add Ollama models (if any)
if ollama_models:
for modellist in ollama_models:
for model in modellist:
if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name
model['id'] = model.get('name', model.get('id', ''))
else:
model['name'] = model['id']
models['data'].append(model)
# Add external OpenAI models (if any)
if ext_openai_models:
for modellist in ext_openai_models:
for model in modellist:
if not "id" in model.keys():
model['id'] = model.get('name', model.get('id', ''))
else:
model['name'] = model['id']
models['data'].append(model)
# Add llama-server models (all available, not just loaded)
if llama_models:
for modellist in llama_models:
for model in modellist:
if not "id" in model.keys():
model['id'] = model.get('name', model.get('id', ''))
else:
model['name'] = model['id']
models['data'].append(model)
# 2. Return a JSONResponse with a deduplicated list of unique models for inference
return JSONResponse(
content={"data": dedupe_on_keys(models['data'], ['name'])},
status_code=200,
)
@router.post("/v1/rerank")
@router.post("/rerank")
async def rerank_proxy(request: Request):
"""
Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.
Compatible with the Jina/Cohere rerank API convention used by llama-server,
vLLM, and services such as Cohere and Jina AI.
Ollama does not natively support reranking; requests routed to a plain Ollama
endpoint will receive a 501 Not Implemented response.
Request body:
model (str, required) reranker model name
query (str, required) search query
documents (list[str], required) candidate documents to rank
top_n (int, optional) limit returned results (default: all)
return_documents (bool, optional) include document text in results
max_tokens_per_doc (int, optional) truncation limit per document
Response (Jina/Cohere-compatible):
{
"id": "...",
"model": "...",
"usage": {"prompt_tokens": N, "total_tokens": N},
"results": [{"index": 0, "relevance_score": 0.95}, ...]
}
"""
config = get_config()
try:
body_bytes = await request.body()
payload = orjson.loads(body_bytes.decode("utf-8"))
model = payload.get("model")
query = payload.get("query")
documents = payload.get("documents")
if not model:
raise HTTPException(status_code=400, detail="Missing required field 'model'")
if not query:
raise HTTPException(status_code=400, detail="Missing required field 'query'")
if not isinstance(documents, list) or not documents:
raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")
except orjson.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
# Determine which endpoint serves this model
try:
endpoint, tracking_model = await choose_endpoint(model)
except RuntimeError as e:
raise HTTPException(status_code=404, detail=str(e))
# Ollama endpoints have no native rerank support
if not is_openai_compatible(endpoint):
await decrement_usage(endpoint, tracking_model)
raise HTTPException(
status_code=501,
detail=(
f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
"reranking. Use a llama-server or OpenAI-compatible endpoint with a "
"dedicated reranker model."
),
)
if ":latest" in model:
model = model.split(":latest")[0]
# Build upstream rerank request body forward only recognised fields
upstream_payload: dict = {"model": model, "query": query, "documents": documents}
for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
if optional_key in payload:
upstream_payload[optional_key] = payload[optional_key]
# Determine upstream URL:
# llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
# External OpenAI endpoints expose /rerank under their /v1 base
if endpoint in config.llama_server_endpoints:
# llama-server: endpoint may or may not already contain /v1
if "/v1" in endpoint:
rerank_url = f"{endpoint}/rerank"
else:
rerank_url = f"{endpoint}/v1/rerank"
else:
# External OpenAI-compatible: ep2base gives us the /v1 base
rerank_url = f"{ep2base(endpoint)}/rerank"
api_key = config.api_keys.get(endpoint, "no-key")
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}
client: aiohttp.ClientSession = get_session(endpoint)
try:
async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
response_bytes = await resp.read()
if resp.status >= 400:
raise HTTPException(
status_code=resp.status,
detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
)
data = orjson.loads(response_bytes)
# Record token usage if the upstream returned a usage object
usage = data.get("usage") or {}
prompt_tok = usage.get("prompt_tokens") or 0
total_tok = usage.get("total_tokens") or 0
# For reranking there are no completion tokens; we record prompt tokens only
if prompt_tok or total_tok:
await token_queue.put((endpoint, tracking_model, prompt_tok, 0))
return JSONResponse(content=data)
finally:
await decrement_usage(endpoint, tracking_model)

30
api/static.py Normal file
View file

@ -0,0 +1,30 @@
"""Static-asset and dashboard routes."""
from pathlib import Path
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import HTMLResponse, RedirectResponse
# Directory containing static files (resolved relative to project root).
STATIC_DIR = Path(__file__).resolve().parent.parent / "static"
router = APIRouter()
@router.get("/favicon.ico")
async def redirect_favicon():
return RedirectResponse(url="/static/favicon.ico")
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""
Render the dynamic NOMYO Router dashboard listing the configured endpoints
and the models details, availability & task status.
"""
index_path = STATIC_DIR / "index.html"
try:
return HTMLResponse(content=index_path.read_text(encoding="utf-8"), status_code=200)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="Page not found")
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

0
backends/__init__.py Normal file
View file

136
backends/health.py Normal file
View file

@ -0,0 +1,136 @@
"""Backend health probes and error classification helpers.
Contains:
* cache-freshness check (``_is_fresh``)
* aiohttp response success assertion (``_ensure_success``)
* human-readable connection-issue formatter
* upstream-error detection that distinguishes connection failures from
legitimate 4xx responses (``_is_backend_connection_error``)
* per-(endpoint, model) unhealthy marker that feeds ``choose_endpoint``
* llama-server status interpretation (``_is_llama_model_loaded`` etc.)
"""
import asyncio
import time
from urllib.parse import urlparse
import aiohttp
import openai
from fastapi import HTTPException
from security import _mask_secrets
from state import _completion_error_cache, _completion_error_cache_lock
def _is_fresh(cached_at: float, ttl: int) -> bool:
return (time.time() - cached_at) < ttl
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
if resp.status >= 400:
text = await resp.text()
raise HTTPException(status_code=resp.status, detail=_mask_secrets(text))
def _format_connection_issue(url: str, error: Exception) -> str:
"""
Provide a human-friendly error string for connection failures so operators
know which endpoint and address failed from inside the container.
"""
parsed = urlparse(url)
host_hint = parsed.hostname or ""
port_hint = parsed.port or ""
if isinstance(error, aiohttp.ClientConnectorError):
resolved_host = getattr(error, "host", host_hint) or host_hint or "?"
resolved_port = getattr(error, "port", port_hint) or port_hint or "?"
parts = [
f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).",
"Ensure the endpoint address is reachable from within the container.",
]
if resolved_host in {"localhost", "127.0.0.1"}:
parts.append(
"Inside Docker, 'localhost' refers to the container itself; use "
"'host.docker.internal' or a Docker network alias if the service "
"runs on the host machine."
)
os_error = getattr(error, "os_error", None)
if isinstance(os_error, OSError):
errno = getattr(os_error, "errno", None)
strerror = os_error.strerror or str(os_error)
if errno is not None or strerror:
parts.append(f"OS error [{errno}]: {strerror}.")
elif os_error:
parts.append(f"OS error: {os_error}.")
parts.append(f"Original error: {error}.")
return " ".join(parts)
if isinstance(error, asyncio.TimeoutError):
return (
f"Timed out waiting for {url}. "
"The remote endpoint may be offline or slow to respond."
)
return f"Error while contacting {url}: {error}"
def _is_backend_connection_error(exc: Exception) -> bool:
"""True for upstream connection-class failures observed via the OpenAI client.
Targets the case where a llama-server in router mode keeps answering
/v1/models but its delegated worker for a specific model is dead, so
chat/completions calls return 5xx with 'proxy error: Could not establish
connection' (or the SDK raises APIConnectionError outright).
Excludes BadRequestError with exceed_context_size_error by design those
must stay on the reactive-trim path.
"""
if isinstance(exc, openai.APIConnectionError):
return True
if isinstance(exc, openai.InternalServerError):
msg = str(exc).lower()
return (
"proxy error" in msg
or "could not establish connection" in msg
or "connection refused" in msg
)
return False
async def _mark_backend_unhealthy(endpoint: str, model: str, reason: str = "") -> None:
"""Record (endpoint, model) as broken so choose_endpoint avoids it.
Cleared only by TTL the dead-worker failure mode is invisible to the
/v1/models / /api/ps probes that clear _loaded_error_cache, so we cannot
rely on a successful probe as a recovery signal.
"""
async with _completion_error_cache_lock:
_completion_error_cache[(endpoint, model)] = time.time()
print(f"[health] marked unhealthy ep={endpoint} model={model} reason={reason[:120]}", flush=True)
def _is_llama_model_loaded(item: dict) -> bool:
"""Return True if a llama-server /v1/models item has status 'loaded'.
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
If no status field is present, the model is always-loaded (not dynamically managed)."""
status = item.get("status")
if status is None:
return True # No status field: model is always loaded (e.g. single-model servers)
if isinstance(status, dict):
return status.get("value") == "loaded"
if isinstance(status, str):
return status == "loaded"
return False
def _is_llama_model_loaded_or_sleeping(item: dict) -> bool:
"""Return True if status is 'loaded' or 'sleeping'.
Newer llama-server versions report 'sleeping' in /v1/models when a model is idle;
ps_details needs to include these so _fetch_llama_props can detect and unload them."""
status = item.get("status")
if status is None:
return True
if isinstance(status, dict):
return status.get("value") in ("loaded", "sleeping")
if isinstance(status, str):
return status in ("loaded", "sleeping")
return False

113
backends/normalize.py Normal file
View file

@ -0,0 +1,113 @@
"""Endpoint URL, model-name, and endpoint-classification helpers.
The endpoint classifiers read live config via ``get_config()`` so that the
startup-time rebind of ``config`` in router.py is picked up at call time.
"""
from config import get_config
def _normalize_llama_model_name(name: str) -> str:
"""Extract the model name from a huggingface-style identifier.
e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF'
"""
if "/" in name:
name = name.rsplit("/", 1)[1]
if ":" in name:
name = name.split(":")[0]
return name
def _extract_llama_quant(name: str) -> str:
"""Extract the quantization level from a huggingface-style identifier.
e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0'
Returns empty string if no quant suffix is present.
"""
if ":" in name:
return name.rsplit(":", 1)[1]
return ""
def ep2base(ep):
if "/v1" in ep:
base_url = ep
else:
base_url = ep + "/v1"
return base_url
def dedupe_on_keys(dicts, key_fields):
"""
Helper function to deduplicate endpoint details based on given dict keys.
"""
seen = set()
out = []
for d in dicts:
# Build a tuple of the values for the chosen keys
key = tuple(d.get(k) for k in key_fields)
if key not in seen:
seen.add(key)
out.append(d)
return out
def is_ext_openai_endpoint(endpoint: str) -> bool:
"""
Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server).
Returns True for:
- External services like OpenAI.com, Groq, etc.
Returns False for:
- Ollama endpoints (without /v1, or with /v1 but default port 11434)
- llama-server endpoints (explicitly configured in llama_server_endpoints)
"""
cfg = get_config()
# Check if it's a llama-server endpoint (has /v1 and is in the configured list)
if endpoint in cfg.llama_server_endpoints:
return False
if "/v1" not in endpoint:
return False
base_endpoint = endpoint.replace('/v1', '')
if base_endpoint in cfg.endpoints:
return False # It's Ollama's /v1
# Check for default Ollama port
if ':11434' in endpoint:
return False # It's Ollama
return True # It's an external OpenAI endpoint
def is_openai_compatible(endpoint: str) -> bool:
"""
Return True if the endpoint speaks the OpenAI API (not native Ollama).
This includes external OpenAI endpoints AND llama-server endpoints.
"""
return "/v1" in endpoint or endpoint in get_config().llama_server_endpoints
def get_tracking_model(endpoint: str, model: str) -> str:
"""
Normalize model name for tracking purposes so it matches the PS table key.
- For llama-server endpoints: strips HF prefix and quantization suffix
- For Ollama endpoints: appends ":latest" if no version suffix is present
- For external OpenAI endpoints: returns as-is (not shown in PS)
This ensures consistent model naming across all routes for usage tracking.
"""
# External OpenAI endpoints are not shown in PS, keep as-is
if is_ext_openai_endpoint(endpoint):
return model
# llama-server endpoints use normalized names in PS
if endpoint in get_config().llama_server_endpoints:
return _normalize_llama_model_name(model)
# Ollama endpoints: append ":latest" if no version suffix
if ":" not in model:
return model + ":latest"
return model

456
backends/probe.py Normal file
View file

@ -0,0 +1,456 @@
"""Backend probe / discovery primitives.
The ``fetch`` class wraps the three discovery paths the router uses:
* ``available_models`` what the endpoint advertises (Ollama ``/api/tags``
or OpenAI-style ``/v1/models``)
* ``loaded_models`` what is currently resident (Ollama ``/api/ps`` or
llama-server ``/v1/models`` filtered on ``status == "loaded"``)
* ``endpoint_details`` arbitrary detail fetch used by management routes
Each path goes through three layers of cache: success cache, error cache,
and an in-flight request map. Stale-while-revalidate refreshes happen in
background tasks tracked by the ``_bg_refresh_*`` maps in ``state``.
``_raw_probe`` and ``_endpoint_health`` are the lower-level dual probes
used by ``/health`` and ``/api/config`` to distinguish a healthy daemon
with a broken model-introspection path from a dead daemon.
"""
import asyncio
import time
from typing import List, Optional, Set
import aiohttp
from config import get_config
from state import (
_models_cache,
_models_cache_lock,
_loaded_models_cache,
_loaded_models_cache_lock,
_available_error_cache,
_available_error_cache_lock,
_loaded_error_cache,
_loaded_error_cache_lock,
_inflight_available_models,
_inflight_loaded_models,
_inflight_lock,
_bg_refresh_available,
_bg_refresh_loaded,
_bg_refresh_lock,
default_headers,
)
from backends.sessions import get_probe_session
from backends.health import (
_is_fresh,
_ensure_success,
_format_connection_issue,
_is_llama_model_loaded,
)
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible
class fetch:
async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Internal function that performs the actual HTTP request to fetch available models.
This is called by available_models() after checking caches and in-flight requests.
"""
cfg = get_config()
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
ep_base = endpoint.rstrip("/")
if endpoint in cfg.llama_server_endpoints and "/v1" not in endpoint:
endpoint_url = f"{ep_base}/v1/models"
key = "data"
elif "/v1" in endpoint or endpoint in cfg.llama_server_endpoints:
endpoint_url = f"{ep_base}/models"
key = "data"
else:
endpoint_url = f"{ep_base}/api/tags"
key = "models"
client: aiohttp.ClientSession = get_probe_session(endpoint)
try:
async with client.get(endpoint_url, headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
items = data.get(key, [])
models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}
async with _models_cache_lock:
_models_cache[endpoint] = (models, time.time())
return models
except Exception as e:
# Treat any error as if the endpoint offers no models
message = _format_connection_issue(endpoint_url, e)
print(f"[fetch.available_models] {message}")
# Update error cache with lock protection
async with _available_error_cache_lock:
_available_error_cache[endpoint] = time.time()
return set()
async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None:
"""
Background task to refresh available models cache without blocking the caller.
Used for stale-while-revalidate pattern.
Deduplicates: only one background refresh runs per endpoint at a time.
"""
async with _bg_refresh_lock:
if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
return # A refresh is already running for this endpoint
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
_bg_refresh_available[endpoint] = task
try:
await task
except Exception as e:
# Silently fail - cache will remain stale but functional
print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
finally:
async with _bg_refresh_lock:
if _bg_refresh_available.get(endpoint) is task:
_bg_refresh_available.pop(endpoint, None)
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
Query <endpoint>/api/tags and return a set of all model names that the
endpoint *advertises* (i.e. is capable of serving). This endpoint lists
every model that is installed on the Ollama instance, regardless of
whether the model is currently loaded into memory.
Uses request coalescing to prevent cache stampede: if multiple requests
arrive when cache is expired, only one actual HTTP request is made.
Uses stale-while-revalidate: when the cache is between 300-600s old,
the stale data is returned immediately while a background refresh runs.
This prevents model blackouts caused by transient timeouts.
If the request fails (e.g. timeout, 5xx, or malformed response), an empty
set is returned.
"""
# Check models cache with lock protection
async with _models_cache_lock:
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
# FRESH: <= 300s old - return immediately
if _is_fresh(cached_at, 300):
return models
# STALE: 300-600s old - return stale data and refresh in background
if _is_fresh(cached_at, 600):
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
return models # Return stale data immediately
# EXPIRED: > 600s old - too stale, must refresh synchronously
del _models_cache[endpoint]
# Check error cache with lock protection
async with _available_error_cache_lock:
if endpoint in _available_error_cache:
err_age = time.time() - _available_error_cache[endpoint]
if err_age < 30:
# Very fresh error (<30s) endpoint likely still down, bail fast
return set()
elif err_age < 300:
# Stale error (30-300s) endpoint may have recovered, probe in background
asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
return set()
# Error expired (>300s) remove and fall through to fresh fetch
del _available_error_cache[endpoint]
# Request coalescing: check if another request is already fetching this endpoint
async with _inflight_lock:
if endpoint in _inflight_available_models:
# Another request is already fetching - wait for it
task = _inflight_available_models[endpoint]
else:
# Create new fetch task
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
_inflight_available_models[endpoint] = task
try:
# Wait for the fetch to complete (either ours or another request's)
result = await task
return result
finally:
# Clean up in-flight tracking (only if we created it)
async with _inflight_lock:
if _inflight_available_models.get(endpoint) == task:
_inflight_available_models.pop(endpoint, None)
async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
"""
Internal function that performs the actual HTTP request to fetch loaded models.
This is called by loaded_models() after checking caches and in-flight requests.
For Ollama endpoints: queries /api/ps and returns model names
For llama-server endpoints: queries /v1/models and filters for status.value == "loaded"
"""
client: aiohttp.ClientSession = get_probe_session(endpoint)
cfg = get_config()
# Check if this is a llama-server endpoint
if endpoint in cfg.llama_server_endpoints:
# Query /v1/models for llama-server. Send the configured key as a
# Bearer token — current llama.cpp leaves /models public, but a
# build/config that protects it would otherwise 401 this probe.
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
api_key = cfg.api_keys.get(endpoint)
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
try:
async with client.get(f"{endpoint}/models", headers=headers) as resp:
await _ensure_success(resp)
data = await resp.json()
# Filter for loaded models only
items = data.get("data", [])
models = {
item.get("id")
for item in items
if item.get("id") and _is_llama_model_loaded(item)
}
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
# Probe succeeded — clear any stale error so the endpoint
# becomes routable again.
async with _loaded_error_cache_lock:
_loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/models", e)
print(f"[fetch.loaded_models] {message}")
# Record the failure so `choose_endpoint` can avoid routing
# to an unhealthy backend and repeated probes short-circuit.
async with _loaded_error_cache_lock:
_loaded_error_cache[endpoint] = time.time()
return set()
else:
# Original Ollama /api/ps logic
try:
async with client.get(f"{endpoint}/api/ps") as resp:
await _ensure_success(resp)
data = await resp.json()
# The response format is:
# {"models": [{"name": "model1"}, {"name": "model2"}]}
models = {m.get("name") for m in data.get("models", []) if m.get("name")}
# Update cache with lock protection
async with _loaded_models_cache_lock:
_loaded_models_cache[endpoint] = (models, time.time())
async with _loaded_error_cache_lock:
_loaded_error_cache.pop(endpoint, None)
return models
except Exception as e:
# If anything goes wrong we simply assume the endpoint has no models
message = _format_connection_issue(f"{endpoint}/api/ps", e)
print(f"[fetch.loaded_models] {message}")
async with _loaded_error_cache_lock:
_loaded_error_cache[endpoint] = time.time()
return set()
async def _refresh_loaded_models(endpoint: str) -> None:
"""
Background task to refresh loaded models cache without blocking the caller.
Used for stale-while-revalidate pattern.
Deduplicates: only one background refresh runs per endpoint at a time.
"""
async with _bg_refresh_lock:
if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
return # A refresh is already running for this endpoint
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
_bg_refresh_loaded[endpoint] = task
try:
await task
except Exception as e:
# Silently fail - cache will remain stale but functional
print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
finally:
async with _bg_refresh_lock:
if _bg_refresh_loaded.get(endpoint) is task:
_bg_refresh_loaded.pop(endpoint, None)
async def loaded_models(endpoint: str) -> Set[str]:
"""
Query <endpoint>/api/ps and return a set of model names that are currently
loaded on that endpoint. If the request fails (e.g. timeout, 5xx), an empty
set is returned.
Uses request coalescing to prevent cache stampede and stale-while-revalidate
to serve requests immediately even when cache is stale (refreshing in background).
"""
if is_ext_openai_endpoint(endpoint):
return set()
# Check loaded models cache with lock protection
async with _loaded_models_cache_lock:
if endpoint in _loaded_models_cache:
models, cached_at = _loaded_models_cache[endpoint]
# FRESH: < 10s old - return immediately
if _is_fresh(cached_at, 10):
return models
# STALE: 10-60s old - return stale data and refresh in background
if _is_fresh(cached_at, 60):
# Kick off background refresh (fire-and-forget)
asyncio.create_task(fetch._refresh_loaded_models(endpoint))
return models # Return stale data immediately
# EXPIRED: > 60s old - too stale, must refresh synchronously
del _loaded_models_cache[endpoint]
# Check error cache with lock protection
async with _loaded_error_cache_lock:
if endpoint in _loaded_error_cache:
if _is_fresh(_loaded_error_cache[endpoint], 300):
return set()
# Error expired - remove it
del _loaded_error_cache[endpoint]
# Request coalescing: check if another request is already fetching this endpoint
async with _inflight_lock:
if endpoint in _inflight_loaded_models:
# Another request is already fetching - wait for it
task = _inflight_loaded_models[endpoint]
else:
# Create new fetch task
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
_inflight_loaded_models[endpoint] = task
try:
# Wait for the fetch to complete (either ours or another request's)
result = await task
return result
finally:
# Clean up in-flight tracking (only if we created it)
async with _inflight_lock:
if _inflight_loaded_models.get(endpoint) == task:
_inflight_loaded_models.pop(endpoint, None)
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None, skip_error_cache: bool = False, timeout: float = None) -> List[dict]:
"""
Query <endpoint>/<route> to fetch <detail> and return a List of dicts with details
for the corresponding Ollama endpoint. If the request fails we respond with "N/A" for detail.
When ``skip_error_cache`` is False (the default), the call is short-circuited
if the endpoint recently failed (recorded in ``_available_error_cache``).
Pass ``skip_error_cache=True`` from health-check routes that must always probe.
``timeout`` overrides the session default for this single request (seconds, total).
"""
# Fast-fail if the endpoint is known to be down (unless caller opts out)
if not skip_error_cache:
async with _available_error_cache_lock:
if endpoint in _available_error_cache:
if _is_fresh(_available_error_cache[endpoint], 300):
return []
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}"
client: aiohttp.ClientSession = get_probe_session(endpoint)
req_kwargs = {}
if timeout is not None:
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
try:
async with client.get(request_url, headers=headers, **req_kwargs) as resp:
await _ensure_success(resp)
data = await resp.json()
detail = data.get(detail, [])
return detail
except Exception as e:
# If anything goes wrong we cannot reply details
message = _format_connection_issue(request_url, e)
print(f"[fetch.endpoint_details] {message}")
if not skip_error_cache:
async with _available_error_cache_lock:
_available_error_cache[endpoint] = time.time()
return []
# -------------------------------------------------------------
# Endpoint health probes (shared by /api/config and /health)
# -------------------------------------------------------------
async def _raw_probe(
ep: str,
route: str,
api_key: Optional[str] = None,
timeout: Optional[float] = None,
) -> tuple[bool, object]:
"""Direct HTTP probe that distinguishes success from failure
(unlike `fetch.endpoint_details`, which returns [] on either).
Returns `(ok, payload_or_error_message)`.
"""
headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
if api_key is not None:
headers["Authorization"] = "Bearer " + api_key
url = f"{ep.rstrip('/')}/{route.lstrip('/')}"
req_kwargs = {}
if timeout is not None:
req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)
try:
client: aiohttp.ClientSession = get_probe_session(ep)
async with client.get(url, headers=headers, **req_kwargs) as resp:
await _ensure_success(resp)
data = await resp.json()
return True, data
except Exception as exc:
return False, _format_connection_issue(url, exc)
async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
"""Probe an endpoint and return `{status, version?, detail?}`.
Ollama endpoints get a dual probe of `/api/version` and `/api/ps` so
that a daemon which is reachable but has a broken model-introspection
path (issue #83) is reported as `error` rather than `ok`.
OpenAI-compatible endpoints use a single `/models` probe.
"""
if is_openai_compatible(ep):
ok, payload = await _raw_probe(
ep, "/models", get_config().api_keys.get(ep), timeout=timeout,
)
if ok:
return {"status": "ok", "version": "latest"}
return {"status": "error", "detail": str(payload)}
(version_ok, version_payload), (ps_ok, ps_payload) = await asyncio.gather(
_raw_probe(ep, "/api/version", timeout=timeout),
_raw_probe(ep, "/api/ps", timeout=timeout),
)
version_value = (
version_payload.get("version")
if version_ok and isinstance(version_payload, dict)
else None
)
if version_ok and ps_ok:
return {"status": "ok", "version": version_value}
if not version_ok and not ps_ok:
return {"status": "error", "detail": str(version_payload)}
# Partial failure — daemon reachable but one probe failed. Report
# as "error" so callers can surface the issue; include `version` so
# the operator knows the daemon itself is alive.
if not ps_ok:
return {
"status": "error",
"version": version_value,
"detail": f"/api/ps: {ps_payload}",
}
return {
"status": "error",
"detail": f"/api/version: {version_payload}",
}

121
backends/sessions.py Normal file
View file

@ -0,0 +1,121 @@
"""aiohttp / OpenAI client factories aware of Unix-socket endpoints.
Unix socket endpoints follow the ``.sock`` hostname convention (e.g.
``http://192.168.0.52.sock/v1``) and resolve to ``/run/user/<uid>/<host>``.
Their sessions/clients live in ``state.app_state`` so that startup can
populate them once and routes can reuse them.
"""
import os
import aiohttp
import ollama
import openai
from state import app_state
from backends.normalize import ep2base
def _is_unix_socket_endpoint(endpoint: str) -> bool:
"""Return True if endpoint uses Unix socket (.sock hostname convention).
Detects URLs like http://192.168.0.52.sock/v1 where the host ends with
.sock, indicating the connection should use a Unix domain socket at
/tmp/<host> instead of TCP.
"""
try:
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
return host.endswith(".sock")
except IndexError:
return False
def _get_socket_path(endpoint: str) -> str:
"""Derive Unix socket file path from a .sock endpoint URL.
http://192.168.0.52.sock/v1 -> /run/user/<uid>/192.168.0.52.sock
"""
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
return f"/run/user/{os.getuid()}/{host}"
def get_session(endpoint: str) -> aiohttp.ClientSession:
"""Return the appropriate aiohttp session for the given endpoint.
Unix socket endpoints (.sock) get their own UnixConnector session.
All other endpoints share the main TCP session.
"""
if _is_unix_socket_endpoint(endpoint):
sess = app_state["socket_sessions"].get(endpoint)
if sess is not None:
return sess
return app_state["session"]
def get_probe_session(endpoint: str) -> aiohttp.ClientSession:
"""Return the session used for lightweight health/introspection probes.
Probes (available/loaded models, endpoint health) run on a connection
pool kept separate from the proxy/streaming session, so a burst of
long-lived completion requests cannot starve them otherwise a probe
would queue waiting for a connection, hit its deadline, and mark a
perfectly healthy endpoint as unavailable under load.
Unix socket endpoints keep their dedicated per-endpoint session. TCP
endpoints use the shared probe session, falling back to the main
session when the probe pool has not been initialised (e.g. in tests).
"""
if _is_unix_socket_endpoint(endpoint):
sess = app_state["socket_sessions"].get(endpoint)
if sess is not None:
return sess
return app_state.get("probe_session") or app_state["session"]
def get_ollama_client(endpoint: str) -> ollama.AsyncClient:
"""Return a cached ``ollama.AsyncClient`` for the endpoint, creating it once.
``ollama.AsyncClient`` wraps an ``httpx.AsyncClient`` whose construction
builds an SSL context and reloads the OS trust store (~40 ms). It is safe to
reuse concurrently, so we keep one per endpoint instead of building a fresh
one on every request otherwise that 40 ms of CPU runs on the event loop
per request and caps single-worker throughput at ~25 req/s.
"""
cache = app_state["ollama_clients"]
client = cache.get(endpoint)
if client is None:
client = ollama.AsyncClient(host=endpoint)
cache[endpoint] = client
return client
def _make_openai_client(
endpoint: str,
default_headers: dict | None = None,
api_key: str = "no-key",
) -> openai.AsyncOpenAI:
"""Return a cached AsyncOpenAI client configured for the given endpoint.
Clients are cached per ``(endpoint, api_key)`` and reused across requests:
constructing one builds an SSL context and reloads the OS trust store
(~40 ms), which serializes the event loop if done per request. For Unix
socket endpoints, injects the pre-created httpx UDS transport so the OpenAI
SDK connects via the socket instead of TCP.
"""
cache = app_state["openai_clients"]
cache_key = (endpoint, api_key)
client = cache.get(cache_key)
if client is not None:
return client
base_url = ep2base(endpoint)
kwargs: dict = {"api_key": api_key}
if default_headers is not None:
kwargs["default_headers"] = default_headers
if _is_unix_socket_endpoint(endpoint):
http_client = app_state["httpx_clients"].get(endpoint)
if http_client is not None:
kwargs["http_client"] = http_client
base_url = "http://localhost/v1"
client = openai.AsyncOpenAI(base_url=base_url, **kwargs)
cache[cache_key] = client
return client

139
config.py Normal file
View file

@ -0,0 +1,139 @@
"""Router configuration loader.
Pydantic ``BaseSettings`` model populated from YAML (path resolved via
``_config_path_from_env``) with ``${VAR}`` expansion, plus env-var overrides
under the ``NOMYO_ROUTER_`` prefix.
"""
import os
import re
from pathlib import Path
from typing import Dict, List, Optional
import yaml
from pydantic import Field
from pydantic_settings import BaseSettings
class Config(BaseSettings):
# List of Ollama endpoints
endpoints: list[str] = Field(
default_factory=lambda: [
"http://localhost:11434",
]
)
# List of llama-server endpoints (OpenAI-compatible with /v1/models status info)
llama_server_endpoints: List[str] = Field(default_factory=list)
# Max concurrent connections per endpointmodel pair, see OLLAMA_NUM_PARALLEL
max_concurrent_connections: int = 1
# Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}}
endpoint_config: Dict[str, Dict] = Field(default_factory=dict)
# When True, config order = priority; routes by utilization ratio + config index (WRR)
priority_routing: bool = Field(default=False)
# Conversation affinity: route the same conversation back to the endpoint that
# previously served it, to keep the llama.cpp / Ollama prompt cache (KV cache) warm.
# Soft preference — falls back to the standard algorithm when the affine endpoint
# is saturated or no longer has the model loaded.
conversation_affinity: bool = Field(default=False)
# TTL (seconds) for affinity entries. Defaults to Ollama's default keep_alive (5 min):
# if the backend has already evicted the model, the KV cache is cold anyway.
conversation_affinity_ttl: int = Field(default=300)
api_keys: Dict[str, str] = Field(default_factory=dict)
# Optional router-level API key used to gate access to this service and dashboard
router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY")
# Database configuration
db_path: str = Field(default=os.getenv("NOMYO_ROUTER_DB_PATH", "token_counts.db"))
# Semantic LLM Cache configuration
cache_enabled: bool = Field(default=False)
# Backend: "memory" (default, in-process), "sqlite" (persistent), "redis" (distributed)
cache_backend: str = Field(default="memory")
# Cosine similarity threshold: 1.0 = exact match only, <1.0 = semantic (requires :semantic image)
cache_similarity: float = Field(default=1.0)
# TTL in seconds; None = cache forever
cache_ttl: Optional[int] = Field(default=3600)
# SQLite backend: path to cache database file
cache_db_path: str = Field(default="llm_cache.db")
# Redis backend: connection URL
cache_redis_url: str = Field(default="redis://localhost:6379/0")
# Weight of BM25-weighted chat-history embedding vs last-user-message embedding
# 0.3 = 30% history context signal, 70% question signal
cache_history_weight: float = Field(default=0.3)
class Config:
# YAML loading is handled manually via Config.from_yaml(); env vars use this prefix.
env_prefix = "NOMYO_ROUTER_"
@classmethod
def _expand_env_refs(cls, obj):
"""Recursively replace `${VAR}` with os.getenv('VAR')."""
if isinstance(obj, dict):
return {k: cls._expand_env_refs(v) for k, v in obj.items()}
if isinstance(obj, list):
return [cls._expand_env_refs(v) for v in obj]
if isinstance(obj, str):
# Only expand if it is exactly ${VAR}
m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
if m:
return os.getenv(m.group(1), "")
return obj
@classmethod
def from_yaml(cls, path: Path) -> "Config":
"""Load the YAML file and create the Config instance."""
if path.exists():
with path.open("r", encoding="utf-8") as fp:
data = yaml.safe_load(fp) or {}
cleaned = cls._expand_env_refs(data)
if isinstance(cleaned, dict):
# Accept hyphenated config key and map it to the field name
key_aliases = [
# canonical field name
"router_api_key",
# lowercase, hyphen/underscore variants
"nomyo-router-api-key",
"nomyo_router_api_key",
"nomyo-router_api_key",
"nomyo_router-api_key",
# uppercase env-style variants
"NOMYO-ROUTER_API_KEY",
"NOMYO_ROUTER_API_KEY",
]
for alias in key_aliases:
if alias in cleaned:
cleaned["router_api_key"] = cleaned.get("router_api_key", cleaned.pop(alias))
break
# If not present in YAML (or empty), fall back to env var explicitly
if not cleaned.get("router_api_key"):
env_key = os.getenv("NOMYO_ROUTER_API_KEY")
if env_key:
cleaned["router_api_key"] = env_key
return cls(**cleaned)
return cls()
def _config_path_from_env() -> Path:
"""
Resolve the configuration file path. Defaults to `config.yaml`
in the current working directory unless NOMYO_ROUTER_CONFIG_PATH
is set.
"""
candidate = os.getenv("NOMYO_ROUTER_CONFIG_PATH")
if candidate:
return Path(candidate).expanduser()
return Path("config.yaml")
# ------------------------------------------------------------------
# Shared config accessor
# ------------------------------------------------------------------
# Submodules read config at call time via get_config() instead of importing
# a bound name. The single source of truth is ``router.config`` — the lazy
# import below resolves it after router.py has finished loading, and lets
# tests that ``patch.object(router, "config", cfg)`` flow through.
def get_config() -> "Config":
"""Return the currently active Config from router.py."""
import router # lazy to avoid module-load circular import
return router.config

120
context_window.py Normal file
View file

@ -0,0 +1,120 @@
"""Sliding-window context-trim helpers.
Mirrors what llama.cpp's context-shift used to do: count tokens with tiktoken
(cl100k_base) when available, drop oldest non-system messages until the prompt
fits inside (n_ctx - safety_margin).
Also owns the per-(endpoint, model) n_ctx cache that the routes populate from
exceed_context_size_error bodies and from finish_reason=="length" signals.
"""
import os
# Point tiktoken at the vendored cl100k_base vocab so the encoding loads offline,
# without a network download. The download would otherwise fail anyway: this repo
# has a top-level `requests` package that shadows the pip `requests` tiktoken's
# downloader imports, so get_encoding() would silently fall back to char/4. See
# vendor/tiktoken/. setdefault lets an explicit env override win.
os.environ.setdefault(
"TIKTOKEN_CACHE_DIR",
os.path.join(os.path.dirname(os.path.abspath(__file__)), "vendor", "tiktoken"),
)
try:
import tiktoken as _tiktoken
_tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
except Exception:
_tiktoken_enc = None
def _count_message_tokens(messages: list) -> int:
"""Approximate token count for a message list.
Uses tiktoken cl100k_base when available (within ~5-15% of llama tokenizers).
Falls back to char/4 heuristic if tiktoken is unavailable.
Formula follows OpenAI's per-message overhead: 4 tokens/message + content + 2 priming.
"""
if _tiktoken_enc is None:
return sum(len(str(m.get("content", ""))) for m in messages) // 4
total = 2 # priming tokens
for msg in messages:
total += 4 # per-message role/separator overhead
content = msg.get("content", "")
if isinstance(content, str):
total += len(_tiktoken_enc.encode(content))
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
total += len(_tiktoken_enc.encode(part.get("text", "")))
return total
def _trim_messages_for_context(
messages: list,
n_ctx: int,
safety_margin: int = None,
target_tokens: int = None,
) -> list:
"""Sliding-window trim — mirrors what llama.cpp context-shift used to do.
Keeps all system messages and the most recent non-system messages that fit
within (n_ctx - safety_margin) tokens. Oldest non-system messages are dropped
first (FIFO). The last message is always preserved.
safety_margin defaults to 1/4 of n_ctx to leave headroom for the generated
response, including RAG tool results and tool call JSON synthesis.
target_tokens: if provided, overrides the (n_ctx - safety_margin) target.
Pass a calibrated value when actual n_prompt_tokens is known from the error
body so that tiktoken underestimation vs the backend tokenizer is corrected.
"""
if target_tokens is not None:
target = target_tokens
else:
if safety_margin is None:
safety_margin = n_ctx // 4
target = n_ctx - safety_margin
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
while len(non_system) > 1:
if _count_message_tokens(system_msgs + non_system) <= target:
break
non_system.pop(0) # drop oldest non-system message
# Ensure the first non-system message is a user message (chat templates require it).
# Drop any leading assistant/tool messages that were left after trimming.
while non_system and non_system[0].get("role") != "user":
non_system.pop(0)
return system_msgs + non_system
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
"""Return a tiktoken-scale trim target based on how much backend tokens must be shed.
actual_tokens includes messages + tool schemas + overhead as counted by the backend.
_count_message_tokens only counts message text, so we cannot derive an accurate
per-token scale from the ratio. Instead we compute the *delta* we need to remove
in backend space, then convert just that delta to tiktoken scale (×1.2 buffer).
Example: actual=17993, n_ctx=16384, headroom=4096 need to shed 5705 backend
tokens shed 6846 tiktoken tokens from messages.
"""
cur_tiktoken = _count_message_tokens(msgs)
headroom = n_ctx // 4 # reserve for generated output
max_prompt = n_ctx - headroom # desired max backend tokens in prompt
to_shed = max(0, actual_tokens - max_prompt) # backend tokens we must drop
# Convert to tiktoken scale with 20% buffer (tiktoken underestimates llama by ~15-20%)
tiktoken_to_shed = int(to_shed * 1.2)
return max(1, cur_tiktoken - tiktoken_to_shed)
# Per-(endpoint, model) n_ctx cache.
# Populated from two sources:
# 1. 400 exceed_context_size_error body → n_ctx field
# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens
# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT,
# so large-context models (200k+ for coding) are never touched.
_endpoint_nctx: dict[tuple[str, str], int] = {}
_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this

11
db.py
View file

@ -4,6 +4,17 @@ from pathlib import Path
from datetime import datetime, timezone from datetime import datetime, timezone
from collections import defaultdict from collections import defaultdict
def get_db() -> "TokenDatabase":
"""Return the live TokenDatabase instance held by router.py.
Resolved lazily so submodules can access it without import cycles, and
so test patches of ``router.db`` flow through to all callers.
"""
import router # lazy to avoid module-load circular import
return router.db
class TokenDatabase: class TokenDatabase:
def __init__(self, db_path: str = "token_counts.db"): def __init__(self, db_path: str = "token_counts.db"):
self.db_path = db_path self.db_path = db_path

35
fingerprint.py Normal file
View file

@ -0,0 +1,35 @@
"""Conversation fingerprinting for prompt-cache-aware routing."""
import hashlib
from typing import Optional
def _conversation_fingerprint(model: str, messages: Optional[list],
prompt: Optional[str]) -> Optional[str]:
"""
Stable hash over (model, first system + first user turn). That prefix
determines whether the backend's prompt cache is reusable; later turns
don't influence the routing decision because they extend the same prefix.
Returns None when there is no usable prefix.
"""
parts: list[str] = [model or "_"]
if messages:
for m in messages:
role = m.get("role") if isinstance(m, dict) else None
if role not in ("system", "user"):
continue
content = m.get("content")
if isinstance(content, list): # OpenAI multimodal parts
content = "".join(
p.get("text", "") for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
if not isinstance(content, str):
continue
parts.append(f"{role}:{content}")
if role == "user":
break
elif prompt:
parts.append(f"user:{prompt}")
else:
return None
return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest()

66
images.py Normal file
View file

@ -0,0 +1,66 @@
"""Image and timestamp helpers used by the Ollama/OpenAI request pipeline."""
import base64
import io
import time
from datetime import datetime, timezone
from PIL import Image
def iso8601_ns():
ns = time.time_ns()
sec, ns_rem = divmod(ns, 1_000_000_000)
dt = datetime.fromtimestamp(sec, tz=timezone.utc)
return (
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}T"
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}."
f"{ns_rem:09d}Z"
)
def is_base64(image_string):
try:
if isinstance(image_string, str) and base64.b64encode(base64.b64decode(image_string)) == image_string.encode():
return True
except Exception:
return False
def resize_image_if_needed(image_data):
try:
# Check if already data-url
if image_data.startswith("data:"):
try:
header, image_data = image_data.split(",", 1)
except ValueError:
pass
# Decode the base64 image data
image_bytes = base64.b64decode(image_data)
with Image.open(io.BytesIO(image_bytes)) as image:
if image.mode not in ("RGB", "L"):
image = image.convert("RGB")
# Get current size
width, height = image.size
# Calculate the new dimensions while maintaining aspect ratio
if width > 512 or height > 512:
aspect_ratio = width / height
if aspect_ratio > 1: # Width is larger
new_width = 512
new_height = int(512 / aspect_ratio)
else: # Height is larger
new_height = 512
new_width = int(512 * aspect_ratio)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Encode the resized image back to base64
buffered = io.BytesIO()
image.save(buffered, format="PNG")
resized_image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
return resized_image_data
except Exception as e:
print(f"Error processing image: {e}")
return None

0
requests/__init__.py Normal file
View file

218
requests/chat.py Normal file
View file

@ -0,0 +1,218 @@
"""High-level chat request orchestrator.
``_make_chat_request`` is the shared core that:
* picks an endpoint via ``choose_endpoint`` (which atomically reserves a slot),
* dispatches to either the native Ollama client or an OpenAI-compatible
client based on endpoint type,
* applies reactive context trimming when the backend rejects with
``exceed_context_size_error``,
* counts tokens for billing/SSE,
* always releases the reservation via ``decrement_usage`` in ``finally``.
``_make_moe_requests`` builds on it to implement the
"3 responses + 3 critiques + 1 final" mixture-of-experts dance.
"""
import asyncio
import re
import time
import ollama
import enhance
from config import get_config
from state import default_headers, token_queue
from context_window import _trim_messages_for_context, _calibrated_trim_target
from backends.normalize import is_openai_compatible
from backends.sessions import _make_openai_client
from routing import choose_endpoint, decrement_usage
from requests.messages import (
get_last_user_content,
transform_tool_calls_to_openai,
transform_images_to_data_urls,
_strip_assistant_prefill,
_strip_images_from_messages,
_accumulate_openai_tc_delta,
_build_ollama_tool_calls,
)
from requests.rechunk import rechunk
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make a chat request to a specific endpoint.
Handles endpoint selection, client creation, usage tracking, and request execution.
"""
config = get_config()
endpoint, tracking_model = await choose_endpoint(model) # selects and atomically reserves
use_openai = is_openai_compatible(endpoint)
if use_openai:
if ":latest" in model:
model = model.split(":latest")[0]
if messages:
if any("images" in m for m in messages):
messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
messages = transform_tool_calls_to_openai(messages)
messages = _strip_assistant_prefill(messages)
params = {
"messages": messages,
"model": model,
}
optional_params = {
"tools": tools,
"stream": stream,
"stream_options": {"include_usage": True} if stream else None,
"max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
"frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
"presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
"seed": options.get("seed") if options and "seed" in options else None,
"stop": options.get("stop") if options and "stop" in options else None,
"top_p": options.get("top_p") if options and "top_p" in options else None,
"temperature": options.get("temperature") if options and "temperature" in options else None,
"response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
}
params.update({k: v for k, v in optional_params.items() if v is not None})
oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
try:
if use_openai:
start_ts = time.perf_counter()
try:
response = await oclient.chat.completions.create(**params)
except Exception as e:
_e_str = str(e)
print(f"[_make_chat_request] caught {type(e).__name__}: {_e_str[:200]}")
if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
err_body = getattr(e, "body", {}) or {}
err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
n_ctx_limit = err_detail.get("n_ctx", 0)
actual_tokens = err_detail.get("n_prompt_tokens", 0)
if not n_ctx_limit:
_m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
if _m:
n_ctx_limit = int(_m.group(1))
_m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
if _m:
actual_tokens = int(_m.group(1))
if not n_ctx_limit:
raise
msgs_to_trim = params.get("messages", [])
cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
try:
response = await oclient.chat.completions.create(**{**params, "messages": trimmed})
except Exception as e2:
if "exceed_context_size_error" in str(e2) or "exceeds the available context size" in str(e2):
print(f"[_make_chat_request] Context still exceeded after trimming, also stripping tools")
params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
response = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
else:
raise
elif "image input is not supported" in _e_str:
print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
response = await oclient.chat.completions.create(**params)
else:
raise
if stream:
# For streaming, we need to collect all chunks
chunks = []
tc_acc = {} # accumulate tool-call deltas
async for chunk in response:
chunks.append(chunk)
_accumulate_openai_tc_delta(chunk, tc_acc)
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Convert to Ollama format
if chunks:
response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
# Inject fully-accumulated tool calls into the final response
if tc_acc and response.message:
response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
else:
prompt_tok = 0
comp_tok = 0
if response.usage is not None:
prompt_tok = response.usage.prompt_tokens or 0
comp_tok = response.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(response)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
else:
response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
if stream:
# For streaming, collect all chunks
chunks = []
async for chunk in response:
chunks.append(chunk)
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if chunks:
response = chunks[-1]
else:
prompt_tok = response.prompt_eval_count or 0
comp_tok = response.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
return response
finally:
await decrement_usage(endpoint, tracking_model)
async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
"""
Helper function to make MOE (Multiple Opinions Ensemble) requests.
Generates 3 responses, 3 critiques, and returns the final selected response.
"""
query = get_last_user_content(messages)
if not query:
raise ValueError("No user query found in messages")
if options is None:
options = {}
options["temperature"] = 1
moe_reqs = []
# Generate 3 responses — choose_endpoint is called inside _make_chat_request and
# atomically reserves a slot, so all 3 tasks see each other's load immediately.
response1_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response2_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
response3_task = asyncio.create_task(_make_chat_request(model, messages, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
responses = await asyncio.gather(response1_task, response2_task, response3_task)
for n, r in enumerate(responses):
moe_req = enhance.moe(query, n, r.message.content)
moe_reqs.append(moe_req)
# Generate 3 critiques
critique1_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[0]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique2_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[1]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critique3_task = asyncio.create_task(_make_chat_request(model, [{"role": "user", "content": moe_reqs[2]}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive))
critiques = await asyncio.gather(critique1_task, critique2_task, critique3_task)
# Select final response
m = enhance.moe_select_candidate(query, critiques)
# Generate final response
return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)

187
requests/messages.py Normal file
View file

@ -0,0 +1,187 @@
"""Message-shape transforms used across the chat/completions paths.
Covers the directions between Ollama's native message format and the
OpenAI Chat Completions format:
* tool-call normalization (Ollama OpenAI),
* images encoded as base64 lists OpenAI multimodal ``image_url`` parts,
* trailing-assistant prefill strip (rejected by Claude/OpenAI),
* streaming tool_calls accumulation across deltas,
* logprobs translation (OpenAI choice Ollama ``Logprob``).
"""
import secrets
import ollama
import orjson
from ollama._types import TokenLogprob, Logprob
from images import is_base64, resize_image_if_needed
def get_last_user_content(messages):
"""
Given a list of dicts (e.g., messages from an API),
return the 'content' of the last dict whose 'role' is 'user'.
If no such dict exists, return None.
"""
# Reverse iterate so we stop at the first match
for msg in reversed(messages):
if msg.get("role") == "user":
return msg.get("content")
return None
def _strip_assistant_prefill(messages: list) -> list:
"""Remove a trailing assistant message used as prefill.
OpenAI-compatible endpoints (including Claude) do not support prefill and
will reject requests where the last message has role 'assistant'."""
if messages and messages[-1].get("role") == "assistant":
return messages[:-1]
return messages
def transform_tool_calls_to_openai(message_list):
"""
Ensure tool_calls in assistant messages conform to the OpenAI format:
- Each tool call must have "type": "function"
- Each tool call must have an "id"
- arguments must be a JSON string, not a dict
Also ensure tool-role messages have a tool_call_id.
"""
# Track generated IDs so tool-role messages can reference them
last_tool_call_ids = {}
for msg in message_list:
role = msg.get("role")
if role == "assistant" and "tool_calls" in msg:
for tc in msg["tool_calls"]:
if "type" not in tc:
tc["type"] = "function"
if "id" not in tc:
tc["id"] = f"call_{secrets.token_hex(16)}"
func = tc.get("function", {})
if isinstance(func.get("arguments"), dict):
func["arguments"] = orjson.dumps(func["arguments"]).decode("utf-8")
# Remember the id for the following tool-role message
name = func.get("name")
if name:
last_tool_call_ids[name] = tc["id"]
elif role == "tool":
if "tool_call_id" not in msg:
# Try to match by name from a preceding assistant tool_call
name = msg.get("name") or msg.get("tool_name")
if name and name in last_tool_call_ids:
msg["tool_call_id"] = last_tool_call_ids.pop(name)
return message_list
def transform_images_to_data_urls(message_list):
for message in message_list:
if "images" in message:
images = message.pop("images")
if not isinstance(images, list):
continue
new_content = []
for image in images: #TODO: quality downsize if images are too big to fit into model context window size
if not is_base64(image):
raise ValueError(f"Image string is not a valid base64 encoded string.")
resized_image = resize_image_if_needed(image)
if resized_image:
data_url = f"data:image/png;base64,{resized_image}"
#new_content.append({
# "type": "text",
# "text": ""
#})
new_content.append({
"type": "image_url",
"image_url": {
"url": data_url
}
})
message["content"] = new_content
return message_list
def _strip_images_from_messages(messages: list) -> list:
"""Remove image_url parts from message content, keeping only text."""
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_only = [p for p in content if p.get("type") != "image_url"]
if len(text_only) == 1 and text_only[0].get("type") == "text":
content = text_only[0]["text"]
else:
content = text_only
result.append({**msg, "content": content})
else:
result.append(msg)
return result
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
``accumulator`` is a dict mapping tool-call *index* to
``{"id": str, "name": str, "arguments": str}`` where ``arguments``
is the concatenation of all JSON fragments seen so far.
"""
if not chunk.choices:
return
delta = chunk.choices[0].delta
tc_deltas = getattr(delta, "tool_calls", None)
if not tc_deltas:
return
for tc in tc_deltas:
idx = tc.index
if idx not in accumulator:
accumulator[idx] = {
"id": getattr(tc, "id", None) or f"call_{secrets.token_hex(16)}",
"name": tc.function.name if tc.function else None,
"arguments": "",
}
else:
if getattr(tc, "id", None):
accumulator[idx]["id"] = tc.id
if tc.function and tc.function.name:
accumulator[idx]["name"] = tc.function.name
if tc.function and tc.function.arguments:
accumulator[idx]["arguments"] += tc.function.arguments
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
"""Convert accumulated tool-call data into Ollama-format tool_calls list."""
if not accumulator:
return None
result = []
for idx in sorted(accumulator.keys()):
tc = accumulator[idx]
try:
args = orjson.loads(tc["arguments"]) if tc["arguments"] else {}
except (orjson.JSONDecodeError, TypeError):
args = {}
result.append(ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=args)
))
return result
def _convert_openai_logprobs(choice) -> list | None:
"""Convert OpenAI logprobs from a choice into Ollama Logprob objects."""
lp = getattr(choice, "logprobs", None)
if lp is None:
return None
content = getattr(lp, "content", None)
if not content:
return None
result = []
for entry in content:
top = [
TokenLogprob(token=alt.token, logprob=alt.logprob)
for alt in (entry.top_logprobs or [])
]
result.append(Logprob(
token=entry.token,
logprob=entry.logprob,
top_logprobs=top or None,
))
return result

151
requests/rechunk.py Normal file
View file

@ -0,0 +1,151 @@
"""OpenAI → Ollama response shape converters.
Methods on the ``rechunk`` class are called as bare functions
(``rechunk.openai_chat_completion2ollama(...)``) there is no instance
state. The class is just a namespace.
``extract_usage_from_llama_timings`` reads the ``timings`` field that
llama-server returns in place of OpenAI's ``usage`` so the router can still
count tokens for those backends.
"""
import time
import ollama
import orjson
from images import iso8601_ns
from requests.messages import _convert_openai_logprobs
class rechunk:
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
now = time.perf_counter()
if chunk.choices == [] and chunk.usage is not None:
return ollama.ChatResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True,
done_reason='stop',
total_duration=int((now - start_ts) * 1_000_000_000),
load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens),
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
eval_count=int(chunk.usage.completion_tokens),
eval_duration=int((now - start_ts) * 1_000_000_000),
message=ollama.Message(role="assistant", content=""),
)
with_thinking = chunk.choices[0] if chunk.choices[0] else None
if stream == True:
thinking = (getattr(with_thinking.delta, "reasoning_content", None) or getattr(with_thinking.delta, "reasoning", None)) if with_thinking else None
role = chunk.choices[0].delta.role or "assistant"
content = chunk.choices[0].delta.content or ''
else:
thinking = (getattr(with_thinking.message, "reasoning_content", None) or getattr(with_thinking.message, "reasoning", None)) if with_thinking else None
role = chunk.choices[0].message.role or "assistant"
content = chunk.choices[0].message.content or ''
# Convert OpenAI tool_calls to Ollama format
# In streaming mode, tool_calls arrive as partial deltas across multiple chunks
# (name only in first delta, arguments as incremental JSON fragments).
# Callers must accumulate deltas and inject the final result; skip here.
ollama_tool_calls = None
if not stream:
raw_tool_calls = getattr(with_thinking.message, "tool_calls", None) if with_thinking else None
if raw_tool_calls:
ollama_tool_calls = []
for tc in raw_tool_calls:
try:
args = orjson.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else (tc.function.arguments or {})
except (orjson.JSONDecodeError, TypeError):
args = {}
ollama_tool_calls.append(ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(name=tc.function.name, arguments=args)
))
# Convert OpenAI logprobs to Ollama format
ollama_logprobs = _convert_openai_logprobs(with_thinking) if with_thinking else None
assistant_msg = ollama.Message(
role=role,
content=content,
thinking=thinking,
images=None,
tool_name=None,
tool_calls=ollama_tool_calls)
rechunk = ollama.ChatResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=100000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
message=assistant_msg,
logprobs=ollama_logprobs)
return rechunk
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
now = time.perf_counter()
with_thinking = chunk.choices[0] if chunk.choices[0] else None
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
rechunk = ollama.GenerateResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=10000,
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
response=chunk.choices[0].text or '',
thinking=thinking)
return rechunk
def openai_embeddings2ollama(chunk: dict) -> ollama.EmbeddingsResponse:
rechunk = ollama.EmbeddingsResponse(embedding=chunk.data[0].embedding)
return rechunk
def openai_embed2ollama(chunk: dict, model: str) -> ollama.EmbedResponse:
rechunk = ollama.EmbedResponse(
model=model,
created_at=iso8601_ns(),
done=None,
done_reason=None,
total_duration=None,
load_duration=None,
prompt_eval_count=None,
prompt_eval_duration=None,
eval_count=None,
eval_duration=None,
embeddings=[chunk.data[0].embedding])
return rechunk
def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
"""Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
llama-server returns a ``timings`` dict instead of the standard OpenAI
``usage`` field::
"timings": {
"cache_n": 236, // prompt tokens reused from cache
"prompt_n": 1, // prompt tokens processed
"predicted_n": 35 // predicted (completion) tokens
}
prompt_tokens = prompt_n + cache_n
completion_tokens = predicted_n
Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
timings are found.
"""
timings = getattr(obj, "timings", None)
if timings is None:
return None
if isinstance(timings, dict):
prompt_n = timings.get("prompt_n", 0) or 0
cache_n = timings.get("cache_n", 0) or 0
predicted_n = timings.get("predicted_n", 0) or 0
return (prompt_n + cache_n, predicted_n)
return None

View file

@ -1,5 +1,5 @@
aiohappyeyeballs==2.6.1 aiohappyeyeballs==2.6.1
aiohttp==3.13.5 aiohttp==3.14.0
aiosignal==1.4.0 aiosignal==1.4.0
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.13.0 anyio==4.13.0

4174
router.py

File diff suppressed because it is too large Load diff

294
routing.py Normal file
View file

@ -0,0 +1,294 @@
"""Endpoint selection (load-balancing + conversation affinity).
``choose_endpoint`` is the heart of routing it picks an endpoint that
advertises the model, prefers ones with the model already loaded and a free
slot, applies the conversation-affinity hint when available, and honors
config-order priority routing when ``priority_routing`` is set.
``increment_usage`` / ``decrement_usage`` keep the per-(endpoint, model)
counter that drives utilization-based selection; they fan out an SSE
snapshot on every change.
"""
import asyncio
import random
import time
from typing import Optional
from config import get_config
from state import (
usage_counts,
usage_lock,
_loaded_error_cache,
_loaded_error_cache_lock,
_completion_error_cache,
_completion_error_cache_lock,
_COMPLETION_ERROR_TTL,
_affinity_map,
_affinity_lock,
_AFFINITY_MAX_ENTRIES,
)
from sse import _capture_snapshot, _distribute_snapshot
from backends.health import _is_fresh
from backends.normalize import (
is_ext_openai_endpoint,
is_openai_compatible,
get_tracking_model,
)
from backends.probe import fetch
async def increment_usage(endpoint: str, model: str) -> None:
async with usage_lock:
usage_counts[endpoint][model] += 1
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
async def decrement_usage(endpoint: str, model: str) -> None:
async with usage_lock:
# Avoid negative counts
current = usage_counts[endpoint].get(model, 0)
if current > 0:
usage_counts[endpoint][model] = current - 1
# Optionally, clean up zero entries
if usage_counts[endpoint].get(model, 0) == 0:
usage_counts[endpoint].pop(model, None)
#if not usage_counts[endpoint]:
# usage_counts.pop(endpoint, None)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
def get_max_connections(ep: str) -> int:
"""Per-endpoint max_concurrent_connections, falling back to the global value."""
cfg = get_config()
return cfg.endpoint_config.get(ep, {}).get(
"max_concurrent_connections", cfg.max_concurrent_connections
)
async def choose_endpoint(model: str, reserve: bool = True,
affinity_key: Optional[str] = None) -> tuple[str, str]:
"""
Determine which endpoint to use for the given model while respecting
the `max_concurrent_connections` per endpointmodel pair **and**
ensuring that the chosen endpoint actually *advertises* the model.
The selection algorithm:
1 Query every endpoint for its advertised models (`/api/tags`).
2 Build a list of endpoints that contain the requested model.
2.5 If conversation affinity is enabled and the caller passes
``affinity_key``, prefer the endpoint that previously served the
same conversation but only when it still has the model loaded
and a free slot. Otherwise fall through to the standard logic.
3 For those endpoints, find those that have the model loaded
(`/api/ps`) *and* still have a free slot.
4 If none are both loaded and free, fall back to any endpoint
from the filtered list that simply has a free slot and randomly
select one.
5 If all are saturated, pick any endpoint from the filtered list
(the request will queue on that endpoint).
6 If no endpoint advertises the model at all, raise an error.
"""
config = get_config()
# 1⃣ Gather advertisedmodel sets for all endpoints concurrently
# Include both config.endpoints and config.llama_server_endpoints
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
all_endpoints = config.endpoints + llama_eps_extra
tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)]
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in config.endpoints if is_openai_compatible(ep)]
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra]
advertised_sets = await asyncio.gather(*tag_tasks)
# 2⃣ Filter endpoints that advertise the requested model
candidate_endpoints = [
ep for ep, models in zip(all_endpoints, advertised_sets)
if model in models
]
# 6
if not candidate_endpoints:
if ":latest" in model: #ollama naming convention not applicable to openai/llama-server
model_without_latest = model.split(":latest")[0]
candidate_endpoints = [
ep for ep, models in zip(all_endpoints, advertised_sets)
if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints)
]
if not candidate_endpoints:
# Only add :latest suffix if model doesn't already have a version suffix
if ":" not in model:
model = model + ":latest"
candidate_endpoints = [
ep for ep, models in zip(all_endpoints, advertised_sets)
if model in models
]
if not candidate_endpoints:
raise RuntimeError(
f"None of the configured endpoints ({', '.join(all_endpoints)}) "
f"advertise the model '{model}'."
)
# 3⃣ Among the candidates, find those that have the model *loaded*
# (concurrently, but only for the filtered list)
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
loaded_sets = await asyncio.gather(*load_tasks)
# 3⃣.5 Exclude endpoints whose loaded-model probe has been failing
# recently. Without this filter, an endpoint where `/api/ps` returns 5xx
# would appear with an empty loaded set but pass through to the
# free-slot fallback (step 4) — sending completion calls to an
# unhealthy backend. See issue #83.
async with _loaded_error_cache_lock:
unhealthy = {
ep for ep, ts in _loaded_error_cache.items()
if _is_fresh(ts, 300)
}
if unhealthy:
filtered = [
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
if ep not in unhealthy
]
if filtered:
candidate_endpoints = [ep for ep, _ in filtered]
loaded_sets = [models for _, models in filtered]
# If *every* candidate is unhealthy we still fall through with the
# original list — refusing to route is worse than retrying a
# possibly-recovered backend.
# 3⃣.6 Exclude (endpoint, model) pairs whose completion path has recently
# failed with a backend connection error (e.g. llama-server in router mode
# whose delegated worker for *this* model died). /v1/models keeps reporting
# OK in that case, so the probe-level filter above cannot catch it.
async with _completion_error_cache_lock:
completion_broken = {
ep for (ep, m), ts in _completion_error_cache.items()
if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL)
}
if completion_broken:
filtered = [
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
if ep not in completion_broken
]
if filtered:
candidate_endpoints = [ep for ep, _ in filtered]
loaded_sets = [models for _, models in filtered]
# Same fallback: if every candidate is broken for this model, fall
# through and let the upstream retry — possibly the operator restarted
# the dead worker.
# Look up a possible affinity hint *before* taking usage_lock. The two
# locks are never held together to avoid lock-ordering issues.
affine_ep: Optional[str] = None
if config.conversation_affinity and affinity_key:
async with _affinity_lock:
entry = _affinity_map.get(affinity_key)
if entry is not None:
ep, _stored_model, expires_at = entry
if expires_at < time.monotonic():
_affinity_map.pop(affinity_key, None)
else:
affine_ep = ep
# Protect all reads/writes of usage_counts with the lock so that selection
# and reservation are atomic — concurrent callers see each other's pending load.
async with usage_lock:
# Helper: current usage for (endpoint, model) using the same normalized key
# that increment_usage/decrement_usage store — raw model names differ from
# tracking names for llama-server (HF prefix / quant suffix stripped).
def tracking_usage(ep: str) -> int:
return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)
def utilization_ratio(ep: str) -> float:
return tracking_usage(ep) / get_max_connections(ep)
# Priority map: position in all_endpoints list (lower = higher priority)
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
selected: Optional[str] = None
# 2⃣.5 Conversation affinity preference — only honour the hint when
# the affine endpoint still advertises the model loaded *and* has a
# free slot. Otherwise fall back to the standard algorithm.
if affine_ep:
ep_loaded = {
ep: set(models)
for ep, models in zip(candidate_endpoints, loaded_sets)
}
if (affine_ep in candidate_endpoints
and model in ep_loaded.get(affine_ep, set())
and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
selected = affine_ep
if selected is None:
# 3⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and tracking_usage(ep) < get_max_connections(ep)
]
if loaded_and_free:
if config.priority_routing:
# WRR: sort by config order first (stable), then by utilization ratio.
# Stable sort preserves priority for equal-ratio endpoints.
loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
loaded_and_free.sort(key=utilization_ratio)
selected = loaded_and_free[0]
else:
# Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage)
# When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
selected = random.choice(loaded_and_free)
else:
selected = loaded_and_free[0]
else:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if tracking_usage(ep) < get_max_connections(ep)
]
if endpoints_with_free_slot:
if config.priority_routing:
endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
endpoints_with_free_slot.sort(key=utilization_ratio)
selected = endpoints_with_free_slot[0]
else:
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
selected = random.choice(endpoints_with_free_slot)
else:
selected = endpoints_with_free_slot[0]
else:
# 5⃣ All candidate endpoints are saturated pick the least-busy one (will queue)
if config.priority_routing:
selected = min(
candidate_endpoints,
key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
)
else:
selected = min(candidate_endpoints, key=tracking_usage)
tracking_model = get_tracking_model(selected, model)
snapshot = None
if reserve:
usage_counts[selected][tracking_model] += 1
snapshot = _capture_snapshot()
if snapshot is not None:
await _distribute_snapshot(snapshot)
# Record / refresh affinity *after* releasing usage_lock.
if reserve and config.conversation_affinity and affinity_key:
expires_at = time.monotonic() + config.conversation_affinity_ttl
async with _affinity_lock:
_affinity_map[affinity_key] = (selected, model, expires_at)
if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
now = time.monotonic()
for k in [k for k, v in _affinity_map.items() if v[2] < now]:
_affinity_map.pop(k, None)
return selected, tracking_model

14
security.py Normal file
View file

@ -0,0 +1,14 @@
"""Secret-masking helpers used when logging or surfacing backend errors."""
import re
def _mask_secrets(text: str) -> str:
"""
Mask common API key patterns to avoid leaking secrets in logs or error payloads.
"""
if not text:
return text
# OpenAI-style keys (sk-...) and generic "api key" mentions
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
return text

62
sse.py Normal file
View file

@ -0,0 +1,62 @@
"""Server-sent-events plumbing.
Captures the current ``usage_counts`` / ``token_usage_counts`` snapshot and
fan-outs it to every subscribed asyncio.Queue. Routes that need a live
dashboard feed call ``subscribe`` / ``unsubscribe`` to obtain a queue.
"""
import asyncio
from typing import Dict
import orjson
from state import (
usage_counts,
token_usage_counts,
_subscribers,
_subscribers_lock,
)
def _capture_snapshot() -> str:
"""Capture current usage counts as a JSON string. Caller must hold at least one of usage_lock/token_usage_lock."""
return orjson.dumps({
"usage_counts": dict(usage_counts),
"token_usage_counts": dict(token_usage_counts)
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
async def _distribute_snapshot(snapshot: str) -> None:
"""Push a pre-captured snapshot to all SSE subscribers. Must be called outside any usage lock."""
async with _subscribers_lock:
for q in _subscribers:
if q.full():
try:
await q.get()
except asyncio.QueueEmpty:
pass
await q.put(snapshot)
async def close_all_sse_queues():
for q in list(_subscribers):
# sentinel value that the generator will recognise
await q.put(None)
async def subscribe() -> asyncio.Queue:
"""
Returns a new Queue that will receive every snapshot.
"""
q: asyncio.Queue = asyncio.Queue(maxsize=10)
async with _subscribers_lock:
_subscribers.add(q)
return q
async def unsubscribe(q: asyncio.Queue):
async with _subscribers_lock:
_subscribers.discard(q)
async def get_usage_counts() -> Dict:
return dict(usage_counts) # shallow copy

116
state.py Normal file
View file

@ -0,0 +1,116 @@
"""Shared mutable router state.
All process-wide caches, locks, in-flight task maps, queues, counters and
buffers used by the router live here. These names are only ever *mutated*
(dict/set updates, lock acquisitions, queue put/get) never rebound so
importing them via ``from state import `` is safe from every module.
Rebound singletons (``config``, ``db``, ``token_worker_task``,
``flush_task``) intentionally stay in router.py so their reassignment on
startup is visible to all callers.
"""
import asyncio
from collections import defaultdict
from typing import Dict, Set
# ------------------------------------------------------------------
# Inmemory caches
# ------------------------------------------------------------------
# Successful results are cached for 300s
_models_cache: dict[str, tuple[Set[str], float]] = {}
_loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
# Transient errors are cached separately per concern so that a failure
# in one path does not poison the other.
_available_error_cache: dict[str, float] = {}
_loaded_error_cache: dict[str, float] = {}
# Per-(endpoint, model) completion-path failures. A llama-server in router
# mode can keep returning /v1/models 200 OK after its delegated worker for
# a specific model dies — the probe-level caches above will not catch this.
# We record signals observed during actual completion attempts so
# choose_endpoint can avoid the affected (endpoint, model) pair without
# poisoning unrelated models on the same backend.
_completion_error_cache: dict[tuple[str, str], float] = {}
_COMPLETION_ERROR_TTL = 300
# ------------------------------------------------------------------
# Cache locks
# ------------------------------------------------------------------
_models_cache_lock = asyncio.Lock()
_loaded_models_cache_lock = asyncio.Lock()
_available_error_cache_lock = asyncio.Lock()
_loaded_error_cache_lock = asyncio.Lock()
_completion_error_cache_lock = asyncio.Lock()
# ------------------------------------------------------------------
# In-flight request tracking (prevents cache stampede)
# ------------------------------------------------------------------
_inflight_available_models: dict[str, asyncio.Task] = {}
_inflight_loaded_models: dict[str, asyncio.Task] = {}
_inflight_lock = asyncio.Lock()
_bg_refresh_available: dict[str, asyncio.Task] = {}
_bg_refresh_loaded: dict[str, asyncio.Task] = {}
_bg_refresh_lock = asyncio.Lock()
# ------------------------------------------------------------------
# Queues
# ------------------------------------------------------------------
_subscribers: Set[asyncio.Queue] = set()
_subscribers_lock = asyncio.Lock()
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
# ------------------------------------------------------------------
# HTTP client / connector cache
# ------------------------------------------------------------------
app_state = {
"session": None,
"connector": None,
"probe_session": None, # dedicated session for health/introspection probes
"probe_connector": None, # connection pool isolated from proxy traffic
"socket_sessions": {}, # endpoint -> aiohttp.ClientSession(UnixConnector) for .sock endpoints
"httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints
# Long-lived backend clients, reused across requests. Constructing these is
# expensive (~40 ms each — every new client builds an SSL context and reloads
# the OS trust store via truststore), so building one per request serializes
# the event loop and caps throughput. Created once at startup, closed on
# shutdown. See backends.sessions.get_ollama_client / _make_openai_client.
"ollama_clients": {}, # endpoint -> ollama.AsyncClient
"openai_clients": {}, # (endpoint, api_key) -> openai.AsyncOpenAI
}
# Default outbound HTTP headers attached to every backend request.
default_headers = {
"HTTP-Referer": "https://nomyo.ai",
"Referer": "https://nomyo.ai",
"X-Title": "NOMYO Router",
}
# ------------------------------------------------------------------
# Token Count Buffer (for write-behind pattern)
# ------------------------------------------------------------------
# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
# Time series buffer with timestamp
time_series_buffer: list[dict[str, int | str]] = []
# Lock to protect buffer access from race conditions
buffer_lock = asyncio.Lock()
# Configuration for periodic flushing
FLUSH_INTERVAL = 10 # seconds
# ------------------------------------------------------------------
# Perendpoint permodel active connection counters
# ------------------------------------------------------------------
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
usage_lock = asyncio.Lock() # protects access to usage_counts
token_usage_lock = asyncio.Lock()
# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic).
# Keeps the same conversation pinned to the endpoint that already has its
# KV-cache prefix warm. Model is stored so the dashboard can aggregate live
# entries per (endpoint, model) without recomputing fingerprints.
# Never held together with usage_lock.
_affinity_map: Dict[str, tuple[str, str, float]] = {}
_affinity_lock = asyncio.Lock()
_AFFINITY_MAX_ENTRIES = 10000

View file

@ -1171,11 +1171,13 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
} }
}; };
source.onerror = async (err) => { source.onerror = (err) => {
console.error("SSE connection error. Retrying...", err); // EventSource auto-reconnects on transient drops as long as we
source.close(); // don't close it. Don't treat a dropped stream as an auth failure:
await showApiKeyModal("Enter the NOMYO Router API key to view live usage."); // auth prompting is handled by loadEndpoints()/authedFetch() on the
loadUsage(); // REST endpoints. A genuine 401 closes the stream permanently here
// (no reconnect loop), and the REST path surfaces the modal.
console.error("SSE connection error; awaiting auto-reconnect.", err);
}; };
window.addEventListener("beforeunload", () => source.close()); window.addEventListener("beforeunload", () => source.close());
} }

138
test/load/README.md Normal file
View file

@ -0,0 +1,138 @@
# Load testing the NOMYO Router
`loadtest.py` is a self-contained load generator (asyncio + httpx) with a built-in
**mock backend** so you can measure the router's own concurrency ceiling on a given
machine — independent of real GPU/backend compute.
It answers the question *"how many concurrent connections can the router sustain
on this box?"* by hammering it with N concurrent virtual clients and reporting
throughput, latency percentiles and (for streaming) time-to-first-token.
Run everything from the project root with the project venv active:
```bash
source ~/.venv/nomyo-router/bin/activate # whatever venv has the router deps
```
## The three modes
### 1. `--mock-backend` (recommended) — fully self-contained
Spawns a fast fake Ollama/OpenAI backend **and** the router (wired to it via a
temporary config), drives load against the router, then tears both down. Because
the backend is trivial, the numbers reflect the **router's proxy overhead**, not
model inference time.
```bash
python test/load/loadtest.py --mock-backend --stream --concurrency 128 --duration 30
```
### 2. Default — drive an already-running router
```bash
python test/load/loadtest.py --url http://127.0.0.1:12434 \
--api ollama --stream --concurrency 64 --duration 30 --model llama3
```
### 3. `--serve-mock` — just the mock backend
Run only the fake backend and point your own router `config.yaml` at it
(`endpoints: [http://127.0.0.1:11434]`):
```bash
python test/load/loadtest.py --serve-mock --mock-port 11434 --mock-tokens 64
```
## Finding the concurrency knee
`--ramp` sweeps several concurrency levels and prints a table. The knee is where
`req/s` stops rising and `p99` latency starts climbing sharply:
```bash
python test/load/loadtest.py --mock-backend --stream \
--ramp 8,32,64,128,256 --duration 15
```
```
conc req ok err req/s p50ms p90ms p99ms maxms ttftP50 ttftP99
---------------------------------------------------------------------------------------------
8 120 120 0 19.8 404.6 448.3 478.6 501.4 358.4 391.7
32 140 140 0 21.5 1487.1 1641.8 2341.8 2397.4 1269.8 1476.3
64 148 148 0 21.3 2953.0 4632.5 5204.3 5267.0 1207.8 3031.7
128 168 168 0 19.0 6376.4 8608.9 8726.9 8739.8 2843.1 8348.6
```
> Reading the table above: throughput stays flat (~20 req/s) while latency grows
> linearly with concurrency — the classic signature of a **single-worker
> serialization bottleneck**. Raising `--router-workers` lets throughput scale
> across CPU cores; the per-worker ceiling is what each table row measures.
## Streaming vs non-streaming, Ollama vs OpenAI
| flag | effect |
|------|--------|
| `--stream` / `--no-stream` | streamed response (default) vs a single buffered response |
| `--api ollama` | drives `POST /api/chat` (default) |
| `--api openai` | drives `POST /v1/chat/completions` |
Streaming runs additionally report **TTFT** (time-to-first-token), which isolates
prefill/routing latency from total stream duration.
## Shaping the mock backend (the "fake GPU")
The mock's latency is fully configurable, so you can model anything from an
instant echo (measure pure proxy overhead) to a slow, long-streaming model
(measure how many slow streams the box holds open at once):
| flag | meaning |
|------|---------|
| `--mock-ttft-ms` | prefill latency before the first token (ms) |
| `--mock-tokens` | number of completion tokens emitted |
| `--mock-tok-ms` | per-token decode delay (ms) — inverse of tokens/sec |
| `--mock-models` | comma-separated model names advertised in `/api/tags` & `/api/ps` |
Example — simulate a realistic 40 tok/s model with 300 ms prefill emitting 200
tokens, and see how many concurrent such streams the router holds:
```bash
python test/load/loadtest.py --mock-backend --stream --ramp 16,64,256 \
--mock-ttft-ms 300 --mock-tokens 200 --mock-tok-ms 25 --duration 20
```
## Load shape & misc flags
| flag | default | meaning |
|------|---------|---------|
| `--concurrency N` | 32 | concurrent virtual clients |
| `--duration S` | 20 | seconds per stage (ignored if `--requests` set) |
| `--requests N` | — | send exactly N requests instead of timing out |
| `--warmup S` | 2 | unmeasured warmup before each stage (hot caches/connections) |
| `--timeout S` | 120 | per-request timeout |
| `--model NAME` | `mock` | model name requested (must match what the backend advertises) |
| `--prompt STR` | … | user prompt sent in every request |
| `--json PATH` | — | also write the full results as JSON |
### `--mock-backend` orchestration knobs
| flag | default | meaning |
|------|---------|---------|
| `--router-workers N` | 1 | `uvicorn --workers` for the spawned router |
| `--router-max-conc N` | = peak concurrency | `max_concurrent_connections` in the generated config (so the router doesn't queue unless you want it to) |
| `--router-port` / `--mock-port` | auto | fix the ports instead of auto-picking free ones |
| `--keep-config` | off | keep the generated temp `config.yaml` for inspection |
## Notes & caveats
- **Single-machine bias.** With `--mock-backend`, the driver, router and mock all
share the same CPU, so they compete for cores. For an upper-bound number, run
the driver on a separate machine against a real router (`--url`), or pin
processes to different cores.
- The generated config sets `conversation_affinity: false` and
`cache_enabled: false` to measure the raw proxy path. The temp config and a
throwaway token DB (under the system temp dir) are deleted on exit.
- To measure the router's *admission* limit instead of raw throughput, set
`--router-max-conc` low (e.g. `2`) — requests beyond the limit queue on the
least-busy endpoint rather than erroring.
- Requires the router's own dependencies (`aiohttp`, `httpx`, `uvicorn`, …); it
reuses the project venv, no extra packages needed.
```

803
test/load/loadtest.py Normal file
View file

@ -0,0 +1,803 @@
#!/usr/bin/env python3
"""
NOMYO Router load test asyncio + httpx driver with a built-in mock backend.
Three modes
-----------
1. Drive an already-running router (default)::
python test/load/loadtest.py --url http://127.0.0.1:12434 \
--concurrency 64 --duration 30 --stream
2. Fully self-contained "mock backend" mode spins up a fast fake Ollama/OpenAI
backend AND the router (wired to that backend via a temp config), load-tests
them, then tears both down. This isolates the *router's* proxy overhead from
real GPU compute, so the numbers tell you how many concurrent connections the
router itself can sustain on this machine::
python test/load/loadtest.py --mock-backend \
--concurrency 128 --duration 30 --stream
3. Run just the mock backend (point your own router config at it)::
python test/load/loadtest.py --serve-mock --mock-port 11434
Both streaming and non-streaming are supported (--stream / --no-stream), against
either the Ollama API (--api ollama -> POST /api/chat) or the OpenAI-compatible
API (--api openai -> POST /v1/chat/completions).
Finding the concurrency ceiling
-------------------------------
Use --ramp to sweep concurrency levels and print a table; the "knee" is where
p99 latency climbs sharply or req/s stops increasing::
python test/load/loadtest.py --mock-backend --stream \
--ramp 16,32,64,128,256 --duration 15
The mock backend is a configurable "fake GPU": --mock-ttft-ms (prefill latency),
--mock-tokens (completion length) and --mock-tok-ms (per-token decode delay) let
you model anything from an instant echo (measure pure proxy overhead) to a slow,
long-streaming model (measure how many slow streams the box holds open).
"""
from __future__ import annotations
import argparse
import asyncio
import contextlib
import json
import math
import os
import signal
import socket
import statistics
import sys
import tempfile
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import httpx
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
_WORDS = (
"the quick brown fox jumps over the lazy dog while a router proxies many "
"concurrent streaming completions across several ollama and openai backends "
"without dropping a single token under sustained synthetic load testing"
).split()
def _gen_text(n_tokens: int) -> str:
"""Deterministic pseudo-completion of roughly ``n_tokens`` space-separated tokens."""
return " ".join(_WORDS[i % len(_WORDS)] for i in range(max(0, n_tokens)))
def _rfc3339_now() -> str:
# Ollama-style timestamp, e.g. 2024-01-01T00:00:00.000000Z
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"
def _count_prompt_tokens(messages: list) -> int:
total = 0
for m in messages or []:
c = m.get("content")
if isinstance(c, str):
total += len(c.split())
elif isinstance(c, list):
for part in c:
if isinstance(part, dict) and isinstance(part.get("text"), str):
total += len(part["text"].split())
return max(1, total)
def _free_port() -> int:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("127.0.0.1", 0))
port = s.getsockname()[1]
s.close()
return port
# ===========================================================================
# Mock backend (a fast, configurable fake Ollama + OpenAI-compatible server)
# ===========================================================================
def build_mock_app(models: list[str], ttft_ms: float, tokens: int, tok_ms: float):
"""Construct the aiohttp mock-backend application.
Serves the native-Ollama surface the router uses for discovery and the
`/api/chat` path (`/api/version`, `/api/tags`, `/api/ps`, `/api/chat`,
`/api/generate`) plus the OpenAI-compatible surface used by the
`/v1/chat/completions` path (`/v1/models`, `/v1/chat/completions`,
`/v1/completions`).
"""
from aiohttp import web # imported lazily so the driver has no hard aiohttp dep
ttft = ttft_ms / 1000.0
tok_delay = tok_ms / 1000.0
def _tag_entry(name: str) -> dict:
return {
"name": name,
"model": name,
"modified_at": _rfc3339_now(),
"size": 4_000_000_000,
"digest": "0" * 64,
"details": {
"parent_model": "",
"format": "gguf",
"family": "mock",
"families": ["mock"],
"parameter_size": "7B",
"quantization_level": "Q4_0",
},
}
async def version(_req):
return web.json_response({"version": "0.0.0-nomyo-mock"})
async def tags(_req):
return web.json_response({"models": [_tag_entry(m) for m in models]})
async def ps(_req):
# Report every advertised model as loaded with VRAM so choose_endpoint
# treats this endpoint as "loaded + free".
out = []
for m in models:
e = _tag_entry(m)
e["size_vram"] = e["size"]
e["expires_at"] = "2999-01-01T00:00:00Z"
out.append(e)
return web.json_response({"models": out})
async def v1_models(_req):
now = int(time.time())
return web.json_response({
"object": "list",
"data": [{"id": m, "object": "model", "created": now, "owned_by": "mock"} for m in models],
})
# ----- Ollama /api/chat -------------------------------------------------
async def api_chat(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", True)
prompt_tok = _count_prompt_tokens(payload.get("messages", []))
t0 = time.perf_counter()
if stream:
resp = web.StreamResponse(
status=200, headers={"Content-Type": "application/x-ndjson"}
)
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
line = {
"model": model,
"created_at": _rfc3339_now(),
"message": {"role": "assistant", "content": _WORDS[i % len(_WORDS)] + " "},
"done": False,
}
await resp.write(json.dumps(line).encode() + b"\n")
dur_ns = int((time.perf_counter() - t0) * 1e9)
final = {
"model": model,
"created_at": _rfc3339_now(),
"message": {"role": "assistant", "content": ""},
"done": True,
"done_reason": "stop",
"total_duration": dur_ns,
"load_duration": 0,
"prompt_eval_count": prompt_tok,
"prompt_eval_duration": int(ttft * 1e9),
"eval_count": tokens,
"eval_duration": dur_ns,
}
await resp.write(json.dumps(final).encode() + b"\n")
await resp.write_eof()
return resp
# non-streaming: simulate the whole generation latency, then one object
await asyncio.sleep(ttft + tokens * tok_delay)
dur_ns = int((time.perf_counter() - t0) * 1e9)
return web.json_response({
"model": model,
"created_at": _rfc3339_now(),
"message": {"role": "assistant", "content": _gen_text(tokens)},
"done": True,
"done_reason": "stop",
"total_duration": dur_ns,
"load_duration": 0,
"prompt_eval_count": prompt_tok,
"prompt_eval_duration": int(ttft * 1e9),
"eval_count": tokens,
"eval_duration": dur_ns,
})
# ----- Ollama /api/generate --------------------------------------------
async def api_generate(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", True)
prompt_tok = max(1, len(str(payload.get("prompt", "")).split()))
t0 = time.perf_counter()
if stream:
resp = web.StreamResponse(status=200, headers={"Content-Type": "application/x-ndjson"})
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
await resp.write(json.dumps({
"model": model, "created_at": _rfc3339_now(),
"response": _WORDS[i % len(_WORDS)] + " ", "done": False,
}).encode() + b"\n")
dur_ns = int((time.perf_counter() - t0) * 1e9)
await resp.write(json.dumps({
"model": model, "created_at": _rfc3339_now(), "response": "", "done": True,
"done_reason": "stop", "total_duration": dur_ns,
"prompt_eval_count": prompt_tok, "eval_count": tokens, "eval_duration": dur_ns,
}).encode() + b"\n")
await resp.write_eof()
return resp
await asyncio.sleep(ttft + tokens * tok_delay)
dur_ns = int((time.perf_counter() - t0) * 1e9)
return web.json_response({
"model": model, "created_at": _rfc3339_now(), "response": _gen_text(tokens),
"done": True, "done_reason": "stop", "total_duration": dur_ns,
"prompt_eval_count": prompt_tok, "eval_count": tokens, "eval_duration": dur_ns,
})
# ----- OpenAI /v1/chat/completions -------------------------------------
async def v1_chat(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", False)
want_usage = bool((payload.get("stream_options") or {}).get("include_usage"))
prompt_tok = _count_prompt_tokens(payload.get("messages", []))
created = int(time.time())
cid = "chatcmpl-mock"
if stream:
resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
def _sse(obj: dict) -> bytes:
return b"data: " + json.dumps(obj).encode() + b"\n\n"
# first chunk carries the role
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
}))
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": {"content": _WORDS[i % len(_WORDS)] + " "}, "finish_reason": None}],
}))
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}))
if want_usage:
await resp.write(_sse({
"id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
"choices": [],
"usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
"total_tokens": prompt_tok + tokens},
}))
await resp.write(b"data: [DONE]\n\n")
await resp.write_eof()
return resp
await asyncio.sleep(ttft + tokens * tok_delay)
return web.json_response({
"id": cid, "object": "chat.completion", "created": created, "model": model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": _gen_text(tokens)},
"finish_reason": "stop", "logprobs": None}],
"usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
"total_tokens": prompt_tok + tokens},
})
# ----- OpenAI /v1/completions ------------------------------------------
async def v1_completions(req):
payload = await req.json()
model = payload.get("model", models[0] if models else "mock")
stream = payload.get("stream", False)
prompt_tok = max(1, len(str(payload.get("prompt", "")).split()))
created = int(time.time())
cid = "cmpl-mock"
if stream:
resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
await resp.prepare(req)
if ttft:
await asyncio.sleep(ttft)
for i in range(tokens):
if tok_delay and i:
await asyncio.sleep(tok_delay)
await resp.write(b"data: " + json.dumps({
"id": cid, "object": "text_completion", "created": created, "model": model,
"choices": [{"index": 0, "text": _WORDS[i % len(_WORDS)] + " ", "finish_reason": None}],
}).encode() + b"\n\n")
await resp.write(b"data: [DONE]\n\n")
await resp.write_eof()
return resp
await asyncio.sleep(ttft + tokens * tok_delay)
return web.json_response({
"id": cid, "object": "text_completion", "created": created, "model": model,
"choices": [{"index": 0, "text": _gen_text(tokens), "finish_reason": "stop"}],
"usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
"total_tokens": prompt_tok + tokens},
})
app = web.Application(client_max_size=64 * 1024 * 1024)
app.add_routes([
web.get("/api/version", version),
web.get("/api/tags", tags),
web.get("/api/ps", ps),
web.post("/api/chat", api_chat),
web.post("/api/generate", api_generate),
web.get("/v1/models", v1_models),
web.post("/v1/chat/completions", v1_chat),
web.post("/v1/completions", v1_completions),
])
return app
def serve_mock(args) -> None:
from aiohttp import web
models = [m.strip() for m in args.mock_models.split(",") if m.strip()]
app = build_mock_app(models, args.mock_ttft_ms, args.mock_tokens, args.mock_tok_ms)
print(f"[mock] serving models={models} on http://{args.mock_host}:{args.mock_port} "
f"(ttft={args.mock_ttft_ms}ms tokens={args.mock_tokens} tok={args.mock_tok_ms}ms)",
flush=True)
web.run_app(app, host=args.mock_host, port=args.mock_port, print=None)
# ===========================================================================
# Load driver
# ===========================================================================
@dataclass
class Sample:
ok: bool
status: int
latency: float # full request wall time (s)
ttft: Optional[float] # time-to-first-byte for streaming (s), else None
err: Optional[str] = None
@dataclass
class Stats:
concurrency: int
wall: float = 0.0
samples: list = field(default_factory=list)
@property
def ok(self) -> list:
return [s for s in self.samples if s.ok]
@property
def n_total(self) -> int:
return len(self.samples)
@property
def n_ok(self) -> int:
return len(self.ok)
def _pct(values: list[float], p: float) -> float:
if not values:
return float("nan")
s = sorted(values)
if len(s) == 1:
return s[0]
k = (len(s) - 1) * (p / 100.0)
lo = math.floor(k)
hi = math.ceil(k)
if lo == hi:
return s[int(k)]
return s[lo] + (s[hi] - s[lo]) * (k - lo)
def _build_request(args):
"""Return (path, json_payload) for a single request."""
messages = [{"role": "user", "content": args.prompt}]
if args.api == "openai":
path = "/v1/chat/completions"
body = {"model": args.model, "messages": messages, "stream": args.stream}
else:
path = "/api/chat"
body = {"model": args.model, "messages": messages, "stream": args.stream}
return path, body
async def _one_request(client: httpx.AsyncClient, url: str, body: dict, stream: bool) -> Sample:
t0 = time.perf_counter()
try:
if stream:
ttft = None
async with client.stream("POST", url, json=body) as resp:
status = resp.status_code
async for _chunk in resp.aiter_bytes():
if ttft is None:
ttft = time.perf_counter() - t0
# drain complete
lat = time.perf_counter() - t0
ok = 200 <= status < 300
return Sample(ok=ok, status=status, latency=lat, ttft=ttft,
err=None if ok else f"HTTP {status}")
else:
resp = await client.post(url, json=body)
lat = time.perf_counter() - t0
ok = 200 <= resp.status_code < 300
# touch body so the full response is received
_ = resp.content
return Sample(ok=ok, status=resp.status_code, latency=lat, ttft=None,
err=None if ok else f"HTTP {resp.status_code}")
except Exception as e: # noqa: BLE001 — record any transport error as a failed sample
lat = time.perf_counter() - t0
return Sample(ok=False, status=0, latency=lat, ttft=None,
err=f"{type(e).__name__}: {str(e)[:120]}")
async def run_stage(args, concurrency: int) -> Stats:
path, body = _build_request(args)
url = args.url.rstrip("/") + path
stats = Stats(concurrency=concurrency)
limits = httpx.Limits(max_connections=concurrency + 50,
max_keepalive_connections=concurrency + 50)
timeout = httpx.Timeout(args.timeout, connect=15.0)
async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
# warmup (unmeasured): make a few requests so caches/connections are hot
if args.warmup > 0:
warm_deadline = time.perf_counter() + args.warmup
async def _warm():
while time.perf_counter() < warm_deadline:
await _one_request(client, url, body, args.stream)
await asyncio.gather(*[_warm() for _ in range(min(concurrency, 8))])
use_duration = args.requests is None
deadline = time.perf_counter() + args.duration if use_duration else None
remaining = args.requests if not use_duration else None
remaining_lock = asyncio.Lock()
async def worker():
nonlocal remaining
while True:
if use_duration:
if time.perf_counter() >= deadline:
return
else:
async with remaining_lock:
if remaining <= 0:
return
remaining -= 1
s = await _one_request(client, url, body, args.stream)
stats.samples.append(s)
wall0 = time.perf_counter()
await asyncio.gather(*[worker() for _ in range(concurrency)])
stats.wall = time.perf_counter() - wall0
return stats
def _print_stage(stats: Stats, args, header: bool) -> None:
lat = [s.latency * 1000 for s in stats.ok]
ttfts = [s.ttft * 1000 for s in stats.ok if s.ttft is not None]
rps = stats.n_ok / stats.wall if stats.wall else 0.0
errs = stats.n_total - stats.n_ok
if args.ramp:
if header:
cols = f"{'conc':>5} {'req':>7} {'ok':>7} {'err':>5} {'req/s':>9} " \
f"{'p50ms':>8} {'p90ms':>8} {'p99ms':>9} {'maxms':>9}"
if args.stream:
cols += f" {'ttftP50':>8} {'ttftP99':>8}"
print(cols)
print("-" * len(cols))
row = (f"{stats.concurrency:>5} {stats.n_total:>7} {stats.n_ok:>7} {errs:>5} "
f"{rps:>9.1f} {_pct(lat,50):>8.1f} {_pct(lat,90):>8.1f} "
f"{_pct(lat,99):>9.1f} {(max(lat) if lat else float('nan')):>9.1f}")
if args.stream:
row += f" {_pct(ttfts,50):>8.1f} {_pct(ttfts,99):>8.1f}"
print(row, flush=True)
return
# single-stage detailed report
print(f"\n=== Results (concurrency={stats.concurrency}, "
f"{'stream' if args.stream else 'non-stream'}, api={args.api}) ===")
print(f" wall time : {stats.wall:8.2f} s")
print(f" requests : {stats.n_total} total, {stats.n_ok} ok, {errs} failed")
print(f" throughput : {rps:8.1f} req/s")
if lat:
print(f" latency p50 : {_pct(lat,50):8.1f} ms")
print(f" p90 : {_pct(lat,90):8.1f} ms")
print(f" p95 : {_pct(lat,95):8.1f} ms")
print(f" p99 : {_pct(lat,99):8.1f} ms")
print(f" max : {max(lat):8.1f} ms")
print(f" mean : {statistics.mean(lat):8.1f} ms")
if ttfts:
print(f" TTFT p50 : {_pct(ttfts,50):8.1f} ms")
print(f" p90 : {_pct(ttfts,90):8.1f} ms")
print(f" p99 : {_pct(ttfts,99):8.1f} ms")
if errs:
by_err: dict[str, int] = {}
for s in stats.samples:
if not s.ok:
by_err[s.err or "unknown"] = by_err.get(s.err or "unknown", 0) + 1
print(" errors:")
for k, v in sorted(by_err.items(), key=lambda kv: -kv[1]):
print(f" {v:>6} {k}")
async def run_driver(args) -> list[Stats]:
stages = ([int(x) for x in args.ramp.split(",")] if args.ramp else [args.concurrency])
results: list[Stats] = []
for i, c in enumerate(stages):
stats = await run_stage(args, c)
_print_stage(stats, args, header=(i == 0))
results.append(stats)
return results
# ===========================================================================
# Orchestration: --mock-backend (spawn mock + router, run, tear down)
# ===========================================================================
PROJECT_ROOT = Path(__file__).resolve().parents[2]
async def _wait_http_ok(url: str, timeout: float, accept=(200,)) -> bool:
deadline = time.perf_counter() + timeout
async with httpx.AsyncClient(timeout=5.0) as client:
while time.perf_counter() < deadline:
try:
r = await client.get(url)
if r.status_code in accept:
return True
except Exception:
pass
await asyncio.sleep(0.25)
return False
def _write_temp_config(mock_url: str, models: list[str], max_conc: int) -> Path:
fd, path = tempfile.mkstemp(prefix="nomyo_loadtest_", suffix=".yaml")
os.close(fd)
cfg = (
"# Auto-generated by test/load/loadtest.py --mock-backend. Safe to delete.\n"
"endpoints:\n"
f" - {mock_url}\n"
"llama_server_endpoints: []\n"
f"max_concurrent_connections: {max_conc}\n"
"priority_routing: false\n"
"conversation_affinity: false\n"
"cache_enabled: false\n"
"nomyo-router-api-key: \"\"\n"
"api_keys:\n"
f" \"{mock_url}\": \"mock\"\n"
)
Path(path).write_text(cfg)
return Path(path)
async def run_with_mock_backend(args) -> list[Stats]:
mock_port = args.mock_port or _free_port()
router_port = args.router_port or _free_port()
mock_url = f"http://127.0.0.1:{mock_port}"
router_url = f"http://127.0.0.1:{router_port}"
models = [m.strip() for m in args.mock_models.split(",") if m.strip()]
# Size the router's per-endpoint admission limit so it does not artificially
# serialize the load (unless the user explicitly wants to measure that).
peak = max([int(x) for x in args.ramp.split(",")]) if args.ramp else args.concurrency
max_conc = args.router_max_conc if args.router_max_conc else max(peak, 1)
cfg_path = _write_temp_config(mock_url, models, max_conc)
db_path = Path(tempfile.gettempdir()) / f"nomyo_loadtest_{os.getpid()}.db"
env = dict(os.environ)
env["NOMYO_ROUTER_CONFIG_PATH"] = str(cfg_path)
env["NOMYO_ROUTER_DB_PATH"] = str(db_path)
mock_proc = None
router_proc = None
try:
# 1. mock backend first, so the router never caches it as "down"
mock_cmd = [
sys.executable, str(Path(__file__).resolve()), "--serve-mock",
"--mock-host", "127.0.0.1", "--mock-port", str(mock_port),
"--mock-models", args.mock_models,
"--mock-ttft-ms", str(args.mock_ttft_ms),
"--mock-tokens", str(args.mock_tokens),
"--mock-tok-ms", str(args.mock_tok_ms),
]
print(f"[orchestrator] starting mock backend: {mock_url}", flush=True)
mock_proc = await asyncio.create_subprocess_exec(*mock_cmd)
if not await _wait_http_ok(f"{mock_url}/api/version", timeout=15):
raise RuntimeError("mock backend did not become ready")
# 2. router
router_cmd = [
sys.executable, "-m", "uvicorn", "router:app",
"--host", "127.0.0.1", "--port", str(router_port),
# Per-request access logging is pure noise (and overhead) under load.
"--no-access-log",
]
if args.router_workers and args.router_workers > 1:
router_cmd += ["--workers", str(args.router_workers)]
print(f"[orchestrator] starting router: {router_url} "
f"(workers={args.router_workers}, max_concurrent_connections={max_conc})", flush=True)
router_proc = await asyncio.create_subprocess_exec(
*router_cmd, cwd=str(PROJECT_ROOT), env=env
)
# /health returns 200 only once it can reach the (healthy) mock backend
if not await _wait_http_ok(f"{router_url}/health", timeout=40, accept=(200,)):
raise RuntimeError("router did not become healthy")
print("[orchestrator] router healthy — starting load\n", flush=True)
# 3. drive load against the router
args.url = router_url
return await run_driver(args)
finally:
for name, proc in (("router", router_proc), ("mock", mock_proc)):
if proc and proc.returncode is None:
with contextlib.suppress(ProcessLookupError):
proc.send_signal(signal.SIGINT)
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(proc.wait(), timeout=8)
if proc.returncode is None:
with contextlib.suppress(ProcessLookupError):
proc.kill()
print(f"[orchestrator] stopped {name}", flush=True)
if args.keep_config:
print(f"[orchestrator] kept config: {cfg_path}", flush=True)
else:
with contextlib.suppress(FileNotFoundError):
cfg_path.unlink()
for suffix in ("", "-shm", "-wal"):
with contextlib.suppress(FileNotFoundError):
Path(str(db_path) + suffix).unlink()
# ===========================================================================
# CLI
# ===========================================================================
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
description="NOMYO Router load test (asyncio + httpx) with a built-in mock backend.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
mode = p.add_argument_group("mode")
mode.add_argument("--serve-mock", action="store_true",
help="Run ONLY the mock backend (foreground) and exit on Ctrl-C.")
mode.add_argument("--mock-backend", action="store_true",
help="Spawn mock backend + router, load-test them, then tear down.")
tgt = p.add_argument_group("target (driver)")
tgt.add_argument("--url", default="http://127.0.0.1:12434",
help="Router base URL to drive (ignored with --mock-backend).")
tgt.add_argument("--api", choices=["ollama", "openai"], default="ollama",
help="ollama -> POST /api/chat ; openai -> POST /v1/chat/completions")
tgt.add_argument("--model", default="mock", help="Model name to request.")
tgt.add_argument("--prompt", default="Say hello and count to ten.",
help="User prompt sent in every request.")
stream_grp = tgt.add_mutually_exclusive_group()
stream_grp.add_argument("--stream", dest="stream", action="store_true",
help="Stream the response (default).")
stream_grp.add_argument("--no-stream", dest="stream", action="store_false",
help="Request a single non-streamed response.")
p.set_defaults(stream=True)
load = p.add_argument_group("load shape")
load.add_argument("--concurrency", type=int, default=32,
help="Number of concurrent virtual clients.")
load.add_argument("--duration", type=float, default=20.0,
help="Seconds to run each stage (ignored if --requests given).")
load.add_argument("--requests", type=int, default=None,
help="Send exactly N requests instead of running for --duration.")
load.add_argument("--ramp", default=None,
help="Comma-separated concurrency stages, e.g. 16,32,64,128 "
"(prints a table to find the knee).")
load.add_argument("--warmup", type=float, default=2.0,
help="Seconds of unmeasured warmup before each stage.")
load.add_argument("--timeout", type=float, default=120.0,
help="Per-request timeout (seconds).")
load.add_argument("--json", dest="json_out", default=None,
help="Also write the results as JSON to this path.")
mock = p.add_argument_group("mock backend tuning")
mock.add_argument("--mock-host", default="127.0.0.1")
mock.add_argument("--mock-port", type=int, default=0,
help="Mock backend port (0 = auto-pick a free port).")
mock.add_argument("--mock-models", default="mock",
help="Comma-separated model names the mock advertises.")
mock.add_argument("--mock-ttft-ms", type=float, default=0.0,
help="Simulated prefill latency before the first token (ms).")
mock.add_argument("--mock-tokens", type=int, default=64,
help="Completion length in tokens the mock emits.")
mock.add_argument("--mock-tok-ms", type=float, default=0.0,
help="Simulated per-token decode delay (ms) = inverse of tok/s.")
orch = p.add_argument_group("router orchestration (--mock-backend only)")
orch.add_argument("--router-port", type=int, default=0,
help="Router port (0 = auto-pick a free port).")
orch.add_argument("--router-workers", type=int, default=1,
help="uvicorn --workers for the spawned router.")
orch.add_argument("--router-max-conc", type=int, default=0,
help="max_concurrent_connections in the generated config "
"(0 = match peak concurrency so the router does not queue).")
orch.add_argument("--keep-config", action="store_true",
help="Do not delete the generated temp config on exit.")
return p
def _dump_json(path: str, args, results: list[Stats]) -> None:
out = {
"config": {k: getattr(args, k) for k in (
"api", "model", "stream", "duration", "requests", "warmup", "timeout",
"mock_tokens", "mock_ttft_ms", "mock_tok_ms")},
"stages": [],
}
for st in results:
lat = [s.latency * 1000 for s in st.ok]
ttfts = [s.ttft * 1000 for s in st.ok if s.ttft is not None]
out["stages"].append({
"concurrency": st.concurrency,
"wall_s": st.wall,
"requests": st.n_total,
"ok": st.n_ok,
"errors": st.n_total - st.n_ok,
"rps": (st.n_ok / st.wall) if st.wall else 0.0,
"latency_ms": {p: _pct(lat, p) for p in (50, 90, 95, 99)} | (
{"max": max(lat), "mean": statistics.mean(lat)} if lat else {}),
"ttft_ms": {p: _pct(ttfts, p) for p in (50, 90, 99)} if ttfts else {},
})
Path(path).write_text(json.dumps(out, indent=2))
print(f"\n[driver] wrote JSON results to {path}", flush=True)
def main() -> None:
args = build_parser().parse_args()
if args.serve_mock:
try:
serve_mock(args)
except KeyboardInterrupt:
pass
return
if args.requests is not None and args.requests <= 0:
print("--requests must be > 0", file=sys.stderr)
sys.exit(2)
if args.mock_backend:
results = asyncio.run(run_with_mock_backend(args)) or []
else:
results = asyncio.run(run_driver(args))
if args.json_out and results:
_dump_json(args.json_out, args, results)
if __name__ == "__main__":
main()

View file

@ -1,4 +1,3 @@
pytest>=8.0 pytest>=8.0
pytest-asyncio>=0.24 pytest-asyncio>=0.24
pytest-cov>=5.0 pytest-cov>=5.0
aioresponses>=0.7

View file

@ -1,11 +1,19 @@
"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking.""" """Tests for fetch.available_models and fetch.loaded_models.
The backend probes obtain their HTTP client via ``backends.probe.get_probe_session``
and only ever call ``async with client.get(url, headers=...) as resp``. We patch that
seam with a tiny fake session instead of mocking aiohttp's internals (aioresponses),
so the suite stays independent of aiohttp's private ClientResponse/ConnectionKey
structure across version bumps.
"""
import time import time
from contextlib import contextmanager
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import pytest import pytest
from aioresponses import aioresponses
import router import router
import backends.probe as probe
from conftest import TEST_OLLAMA, TEST_LLAMA from conftest import TEST_OLLAMA, TEST_LLAMA
MOCK_OLLAMA_EP = "http://mock-ollama:11434" MOCK_OLLAMA_EP = "http://mock-ollama:11434"
@ -22,6 +30,73 @@ def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
return cfg return cfg
# ── Fake probe session ────────────────────────────────────────────────────────
class _MockResponse:
"""Minimal stand-in for the aiohttp response used by the probes."""
def __init__(self, *, status=200, payload=None, text=None):
self.status = status
self._payload = payload
self._text = text if text is not None else ""
async def json(self):
return self._payload
async def text(self):
return self._text
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
return False
class _RaisingCtx:
"""``async with client.get(...)`` that raises on entry — mimics a failed connection."""
def __init__(self, exc):
self._exc = exc
async def __aenter__(self):
raise self._exc
async def __aexit__(self, *exc):
return False
class _MockProbeSession:
"""Stand-in for the aiohttp ClientSession returned by ``get_probe_session``.
Routes are registered by exact URL via :meth:`add_get`. A registered exception
is raised when the route is entered; otherwise a :class:`_MockResponse` is yielded.
An unregistered GET fails loudly so tests can't silently pass on a wrong URL.
"""
def __init__(self):
self._routes = {}
def add_get(self, url, *, status=200, payload=None, text=None, exception=None):
self._routes[url] = exception if exception is not None else _MockResponse(
status=status, payload=payload, text=text
)
def get(self, url, **kwargs):
if url not in self._routes:
raise AssertionError(f"unexpected probe GET {url}")
entry = self._routes[url]
return _RaisingCtx(entry) if isinstance(entry, Exception) else entry
@contextmanager
def mock_probe():
"""Patch the probe's session factory to return a fresh :class:`_MockProbeSession`."""
session = _MockProbeSession()
with patch.object(probe, "get_probe_session", lambda endpoint: session):
yield session
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_caches(aio_session): def clear_caches(aio_session):
"""aio_session fixture already clears caches and sets up app_state.""" """aio_session fixture already clears caches and sets up app_state."""
@ -31,8 +106,8 @@ def clear_caches(aio_session):
class TestFetchAvailableModels: class TestFetchAvailableModels:
async def test_ollama_tags(self): async def test_ollama_tags(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags", f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [ payload={"models": [
{"name": "llama3.2:latest"}, {"name": "llama3.2:latest"},
@ -44,8 +119,8 @@ class TestFetchAvailableModels:
async def test_openai_compatible_models_endpoint(self): async def test_openai_compatible_models_endpoint(self):
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_LLAMA_EP}/models", f"{MOCK_LLAMA_EP}/models",
payload={"data": [{"id": "unsloth/model:Q8_0"}]}, payload={"data": [{"id": "unsloth/model:Q8_0"}]},
) )
@ -54,8 +129,8 @@ class TestFetchAvailableModels:
async def test_caches_successful_result(self): async def test_caches_successful_result(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags", f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [{"name": "llama3.2:latest"}]}, payload={"models": [{"name": "llama3.2:latest"}]},
) )
@ -66,20 +141,19 @@ class TestFetchAvailableModels:
async def test_returns_empty_on_http_500(self): async def test_returns_empty_on_http_500(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"}) m.add_get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
models = await router.fetch.available_models(MOCK_OLLAMA_EP) models = await router.fetch.available_models(MOCK_OLLAMA_EP)
assert models == set() assert models == set()
async def test_returns_empty_on_connection_error(self): async def test_returns_empty_on_connection_error(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
import aiohttp import aiohttp
with patch.object(router, "config", cfg), aioresponses() as m: cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
m.get( with patch.object(router, "config", cfg), mock_probe() as m:
m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags", f"{MOCK_OLLAMA_EP}/api/tags",
exception=aiohttp.ClientConnectorError( exception=aiohttp.ClientConnectionError(
connection_key=MagicMock(host="mock-ollama", port=11434), "Cannot connect to host mock-ollama:11434 [Connection refused]"
os_error=OSError(111, "refused"),
), ),
) )
models = await router.fetch.available_models(MOCK_OLLAMA_EP) models = await router.fetch.available_models(MOCK_OLLAMA_EP)
@ -87,8 +161,8 @@ class TestFetchAvailableModels:
async def test_stale_cache_returned_while_refresh_runs(self): async def test_stale_cache_returned_while_refresh_runs(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags", f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [{"name": "llama3.2:latest"}]}, payload={"models": [{"name": "llama3.2:latest"}]},
) )
@ -99,8 +173,8 @@ class TestFetchAvailableModels:
models, _ = router._models_cache[MOCK_OLLAMA_EP] models, _ = router._models_cache[MOCK_OLLAMA_EP]
router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400) router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400)
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_OLLAMA_EP}/api/tags", f"{MOCK_OLLAMA_EP}/api/tags",
payload={"models": [{"name": "llama3.2:latest"}]}, payload={"models": [{"name": "llama3.2:latest"}]},
) )
@ -114,8 +188,8 @@ class TestFetchAvailableModels:
async with router._available_error_cache_lock: async with router._available_error_cache_lock:
router._available_error_cache[MOCK_OLLAMA_EP] = time.time() router._available_error_cache[MOCK_OLLAMA_EP] = time.time()
with patch.object(router, "config", cfg), aioresponses(): with patch.object(router, "config", cfg), mock_probe():
# No HTTP mock registered — if a call happens it will raise # No route registered — if a call happens it raises AssertionError
models = await router.fetch.available_models(MOCK_OLLAMA_EP) models = await router.fetch.available_models(MOCK_OLLAMA_EP)
assert models == set() assert models == set()
@ -123,8 +197,8 @@ class TestFetchAvailableModels:
class TestFetchLoadedModels: class TestFetchLoadedModels:
async def test_ollama_ps(self): async def test_ollama_ps(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_OLLAMA_EP}/api/ps", f"{MOCK_OLLAMA_EP}/api/ps",
payload={"models": [{"name": "llama3.2:latest"}]}, payload={"models": [{"name": "llama3.2:latest"}]},
) )
@ -133,8 +207,8 @@ class TestFetchLoadedModels:
async def test_llama_server_filters_loaded(self): async def test_llama_server_filters_loaded(self):
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_LLAMA_EP}/models", f"{MOCK_LLAMA_EP}/models",
payload={"data": [ payload={"data": [
{"id": "model-a", "status": {"value": "loaded"}}, {"id": "model-a", "status": {"value": "loaded"}},
@ -146,8 +220,8 @@ class TestFetchLoadedModels:
async def test_llama_server_no_status_field_always_loaded(self): async def test_llama_server_no_status_field_always_loaded(self):
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP]) cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_LLAMA_EP}/models", f"{MOCK_LLAMA_EP}/models",
payload={"data": [{"id": "always-on-model"}]}, payload={"data": [{"id": "always-on-model"}]},
) )
@ -156,8 +230,8 @@ class TestFetchLoadedModels:
async def test_returns_empty_on_error(self): async def test_returns_empty_on_error(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={}) m.add_get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
models = await router.fetch.loaded_models(MOCK_OLLAMA_EP) models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
assert models == set() assert models == set()
@ -170,8 +244,8 @@ class TestFetchLoadedModels:
async def test_caches_result(self): async def test_caches_result(self):
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_OLLAMA_EP}/api/ps", f"{MOCK_OLLAMA_EP}/api/ps",
payload={"models": [{"name": "qwen:7b"}]}, payload={"models": [{"name": "qwen:7b"}]},
) )
@ -183,15 +257,15 @@ class TestFetchLoadedModels:
# Regression: issue #83 — /api/ps failures must be recorded so # Regression: issue #83 — /api/ps failures must be recorded so
# `choose_endpoint` can exclude unhealthy backends from routing. # `choose_endpoint` can exclude unhealthy backends from routing.
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[]) cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={}) m.add_get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={})
await router.fetch.loaded_models(MOCK_OLLAMA_EP) await router.fetch.loaded_models(MOCK_OLLAMA_EP)
assert MOCK_OLLAMA_EP in router._loaded_error_cache assert MOCK_OLLAMA_EP in router._loaded_error_cache
async def test_records_error_for_llama_server_on_failure(self): async def test_records_error_for_llama_server_on_failure(self):
cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP]) cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP])
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get(f"{MOCK_LLAMA_EP}/models", status=502, payload={}) m.add_get(f"{MOCK_LLAMA_EP}/models", status=502, payload={})
await router.fetch.loaded_models(MOCK_LLAMA_EP) await router.fetch.loaded_models(MOCK_LLAMA_EP)
assert MOCK_LLAMA_EP in router._loaded_error_cache assert MOCK_LLAMA_EP in router._loaded_error_cache
@ -201,8 +275,8 @@ class TestFetchLoadedModels:
# network probe instead of short-circuiting on the error cache. # network probe instead of short-circuiting on the error cache.
async with router._loaded_error_cache_lock: async with router._loaded_error_cache_lock:
router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301 router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301
with patch.object(router, "config", cfg), aioresponses() as m: with patch.object(router, "config", cfg), mock_probe() as m:
m.get( m.add_get(
f"{MOCK_OLLAMA_EP}/api/ps", f"{MOCK_OLLAMA_EP}/api/ps",
payload={"models": [{"name": "qwen:7b"}]}, payload={"models": [{"name": "qwen:7b"}]},
) )

View file

@ -10,6 +10,7 @@ import pytest
from fastapi import HTTPException from fastapi import HTTPException
import router import router
from api import openai as api_openai
_BYPASS = HTTPException(status_code=599, detail="bypassed") _BYPASS = HTTPException(status_code=599, detail="bypassed")
@ -47,8 +48,8 @@ class TestOpenAIChatCompletionsCacheHit:
# Patch the route's references to both helpers — they're imported by name # Patch the route's references to both helpers — they're imported by name
# into router's namespace at module load time. # into router's namespace at module load time.
with ( with (
patch.object(router, "get_llm_cache", return_value=fake), patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(router, "choose_endpoint", patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))), AsyncMock(side_effect=AssertionError("backend must not be reached"))),
): ):
resp = await client.post( resp = await client.post(
@ -70,8 +71,8 @@ class TestOpenAIChatCompletionsCacheHit:
async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload): async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
fake = _FakeCache(cache_hit_payload) fake = _FakeCache(cache_hit_payload)
with ( with (
patch.object(router, "get_llm_cache", return_value=fake), patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(router, "choose_endpoint", patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))), AsyncMock(side_effect=AssertionError("backend must not be reached"))),
): ):
resp = await client.post( resp = await client.post(
@ -98,8 +99,8 @@ class TestOpenAIChatCompletionsCacheHit:
"""When nomyo.cache=False, get_chat is never called even if a cache exists.""" """When nomyo.cache=False, get_chat is never called even if a cache exists."""
fake = _FakeCache(b"") # has a response, but should never be consulted fake = _FakeCache(b"") # has a response, but should never be consulted
with ( with (
patch.object(router, "get_llm_cache", return_value=fake), patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(router, "choose_endpoint", patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=_BYPASS)), AsyncMock(side_effect=_BYPASS)),
): ):
resp = await client.post( resp = await client.post(
@ -117,8 +118,8 @@ class TestOpenAIChatCompletionsCacheHit:
async def test_no_cache_configured_bypasses_cache_check(self, client): async def test_no_cache_configured_bypasses_cache_check(self, client):
"""get_llm_cache() returning None should not break the route.""" """get_llm_cache() returning None should not break the route."""
with ( with (
patch.object(router, "get_llm_cache", return_value=None), patch.object(api_openai, "get_llm_cache", return_value=None),
patch.object(router, "choose_endpoint", patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=_BYPASS)), AsyncMock(side_effect=_BYPASS)),
): ):
resp = await client.post( resp = await client.post(
@ -140,8 +141,8 @@ class TestOpenAICompletionsCacheHit:
async def test_nonstream_cache_hit(self, client, cache_hit_payload): async def test_nonstream_cache_hit(self, client, cache_hit_payload):
fake = _FakeCache(cache_hit_payload) fake = _FakeCache(cache_hit_payload)
with ( with (
patch.object(router, "get_llm_cache", return_value=fake), patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(router, "choose_endpoint", patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))), AsyncMock(side_effect=AssertionError("backend must not be reached"))),
): ):
resp = await client.post( resp = await client.post(
@ -163,8 +164,8 @@ class TestOpenAICompletionsCacheHit:
async def test_stream_cache_hit(self, client, cache_hit_payload): async def test_stream_cache_hit(self, client, cache_hit_payload):
fake = _FakeCache(cache_hit_payload) fake = _FakeCache(cache_hit_payload)
with ( with (
patch.object(router, "get_llm_cache", return_value=fake), patch.object(api_openai, "get_llm_cache", return_value=fake),
patch.object(router, "choose_endpoint", patch.object(api_openai, "choose_endpoint",
AsyncMock(side_effect=AssertionError("backend must not be reached"))), AsyncMock(side_effect=AssertionError("backend must not be reached"))),
): ):
resp = await client.post( resp = await client.post(

151
test/test_stream_errors.py Normal file
View file

@ -0,0 +1,151 @@
"""
Unit tests for transitive backend-error handling in the four Ollama-native
streaming generators (``/api/generate``, ``/api/chat``, ``/api/embeddings``,
``/api/embed``).
These reproduce the reported failure mode: a backend (nginx in front of ollama)
returns a 504 Gateway Time-out *while the response is being streamed*, so the
``ollama`` client raises ``ResponseError`` from inside the StreamingResponse
generator. Before the fix this escaped as an opaque "Exception in ASGI
application" traceback; now ``_handle_stream_error`` logs the endpoint/model and
emits a terminal Ollama-format ``{"error": ..., "status_code": ...}`` line.
No real backend required the ollama client and routing are mocked.
"""
import json
from contextlib import ExitStack
from unittest.mock import AsyncMock, patch
import httpx
import ollama
import openai
import pytest
from conftest import TEST_OLLAMA
pytestmark = pytest.mark.asyncio
# ── Fakes ─────────────────────────────────────────────────────────────────────
class _Chunk:
"""Minimal Ollama-native streaming chunk the generators can consume."""
prompt_eval_count = 0
eval_count = 0
done = False
message = None
response = None
done_reason = None
def model_dump_json(self):
return '{"model": "fake", "done": false}'
def _one_then_raise(exc):
"""Async generator: yield one valid chunk, then fail mid-stream."""
async def _gen():
yield _Chunk()
raise exc
return _gen()
class _FakeAsyncClient:
"""Stand-in for ``ollama.AsyncClient`` that fails with ``exc``.
Streaming methods (chat/generate) fail *after* one chunk to mimic a
mid-stream 504; the embedding methods fail on the initial await.
"""
def __init__(self, exc, *args, **kwargs):
self._exc = exc
async def chat(self, **kwargs):
return _one_then_raise(self._exc)
async def generate(self, **kwargs):
return _one_then_raise(self._exc)
async def embeddings(self, **kwargs):
raise self._exc
async def embed(self, **kwargs):
raise self._exc
def _patches(exc, mark_unhealthy):
"""Patch routing + the ollama client so the native path hits ``exc``."""
stack = ExitStack()
stack.enter_context(
patch("api.ollama.choose_endpoint", AsyncMock(return_value=(TEST_OLLAMA, "fake")))
)
stack.enter_context(patch("api.ollama.is_openai_compatible", lambda ep: False))
stack.enter_context(patch("api.ollama.decrement_usage", AsyncMock()))
stack.enter_context(patch("api.ollama._mark_backend_unhealthy", mark_unhealthy))
# The native path now fetches a cached client via get_ollama_client() rather
# than constructing ollama.AsyncClient inline, so patch that seam.
stack.enter_context(
patch("api.ollama.get_ollama_client", lambda *a, **k: _FakeAsyncClient(exc))
)
return stack
# Route → request payload. stream=True only matters for chat/generate.
_ROUTES = {
"/api/chat": {"model": "fake", "stream": True, "messages": [{"role": "user", "content": "hi"}]},
"/api/generate": {"model": "fake", "stream": True, "prompt": "hi"},
"/api/embeddings": {"model": "fake", "prompt": "hi"},
"/api/embed": {"model": "fake", "input": "hi"},
}
def _last_json_line(text):
lines = [l for l in text.strip().split("\n") if l.strip()]
assert lines, "expected at least one ndjson line in the response body"
return json.loads(lines[-1])
# ── Tests ─────────────────────────────────────────────────────────────────────
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_504_surfaces_as_error_line(client, route, payload):
"""A 504 ResponseError becomes a terminal {"error", "status_code"} line."""
exc = ollama.ResponseError("<html>504 Gateway Time-out</html>", 504)
mark = AsyncMock()
with _patches(exc, mark):
resp = await client.post(route, json=payload)
# Streaming already started (or single-shot) → HTTP status is 200, the
# error is delivered in-band rather than as a 5xx crash.
assert resp.status_code == 200
err = _last_json_line(resp.text)
assert "error" in err
assert "504" in err["error"]
assert err["status_code"] == 504
# A plain 504 is not a connection-class failure → endpoint stays healthy.
mark.assert_not_called()
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_no_asgi_500_on_backend_failure(client, route, payload):
"""The generator must never let the backend error escape as a 500."""
exc = ollama.ResponseError("boom", 502)
with _patches(exc, AsyncMock()):
resp = await client.post(route, json=payload)
assert resp.status_code == 200
assert resp.status_code != 500
async def test_connection_error_marks_backend_unhealthy(client):
"""A connection-class failure mid-stream marks (endpoint, model) unhealthy."""
exc = openai.APIConnectionError(request=httpx.Request("POST", "http://x"))
mark = AsyncMock()
with _patches(exc, mark):
resp = await client.post("/api/chat", json=_ROUTES["/api/chat"])
assert resp.status_code == 200
err = _last_json_line(resp.text)
assert "error" in err
mark.assert_awaited_once()
# Called with the routed endpoint + model.
called_ep, called_model = mark.await_args.args[0], mark.await_args.args[1]
assert called_ep == TEST_OLLAMA
assert called_model == "fake"

177
tokens.py Normal file
View file

@ -0,0 +1,177 @@
"""Token-count write-behind pipeline.
``token_worker`` drains ``token_queue`` into the in-memory buffer (and into
``token_usage_counts`` for immediate SSE reporting). ``flush_buffer``
periodically persists the buffer to SQLite via ``TokenDatabase``.
``flush_remaining_buffers`` is invoked on shutdown to drain whatever is left.
The lock order is ``buffer_lock`` then ``token_usage_lock`` see
choose_endpoint for why we never combine them with usage_lock.
"""
import asyncio
from datetime import datetime, timezone
from state import (
token_queue,
token_buffer,
time_series_buffer,
buffer_lock,
token_usage_counts,
token_usage_lock,
FLUSH_INTERVAL,
)
from sse import _capture_snapshot, _distribute_snapshot
from db import get_db
async def token_worker() -> None:
try:
while True:
endpoint, model, prompt, comp = await token_queue.get()
# Calculate timestamp once before acquiring lock
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
# Accumulate counts in memory buffer (protected by lock)
async with buffer_lock:
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
# Add to time series buffer with timestamp (UTC)
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
'input_tokens': prompt,
'output_tokens': comp,
'total_tokens': prompt + comp,
'timestamp': timestamp
})
# Update in-memory counts for immediate reporting
async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.CancelledError:
# Gracefully handle task cancellation during shutdown
print("[token_worker] Task cancelled, processing remaining queue items...")
# Process any remaining items in the queue before exiting
while not token_queue.empty():
try:
endpoint, model, prompt, comp = token_queue.get_nowait()
# Calculate timestamp once before acquiring lock
now = datetime.now(tz=timezone.utc)
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
async with buffer_lock:
token_buffer[endpoint][model] = (
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
token_buffer[endpoint].get(model, (0, 0))[1] + comp
)
time_series_buffer.append({
'endpoint': endpoint,
'model': model,
'input_tokens': prompt,
'output_tokens': comp,
'total_tokens': prompt + comp,
'timestamp': timestamp
})
async with token_usage_lock:
token_usage_counts[endpoint][model] += (prompt + comp)
snapshot = _capture_snapshot()
await _distribute_snapshot(snapshot)
except asyncio.QueueEmpty:
break
print("[token_worker] Task cancelled, remaining items processed.")
raise
async def flush_buffer() -> None:
"""Periodically flush accumulated token counts to the database."""
try:
while True:
await asyncio.sleep(FLUSH_INTERVAL)
# Flush token counts and time series (protected by lock)
async with buffer_lock:
if token_buffer:
# Copy buffer before releasing lock for DB operation
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
token_buffer.clear()
else:
buffer_copy = None
if time_series_buffer:
ts_copy = list(time_series_buffer)
time_series_buffer.clear()
else:
ts_copy = None
# Perform DB operations outside the lock to avoid blocking
db = get_db()
if buffer_copy:
await db.update_batched_counts(buffer_copy)
if ts_copy:
await db.add_batched_time_series(ts_copy)
except asyncio.CancelledError:
# Gracefully handle task cancellation during shutdown
print("[flush_buffer] Task cancelled, flushing remaining buffers...")
# Flush any remaining data before exiting
try:
async with buffer_lock:
if token_buffer:
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
token_buffer.clear()
else:
buffer_copy = None
if time_series_buffer:
ts_copy = list(time_series_buffer)
time_series_buffer.clear()
else:
ts_copy = None
db = get_db()
if buffer_copy:
await db.update_batched_counts(buffer_copy)
if ts_copy:
await db.add_batched_time_series(ts_copy)
print("[flush_buffer] Task cancelled, remaining buffers flushed.")
except Exception as e:
print(f"[flush_buffer] Error during shutdown flush: {e}")
raise
async def flush_remaining_buffers() -> None:
"""
Flush any in-memory buffers to the database on shutdown.
This is designed to be safely invoked during shutdown and should not raise.
"""
try:
flushed_entries = 0
async with buffer_lock:
if token_buffer:
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
flushed_entries += sum(len(v) for v in token_buffer.values())
token_buffer.clear()
else:
buffer_copy = None
if time_series_buffer:
ts_copy = list(time_series_buffer)
flushed_entries += len(time_series_buffer)
time_series_buffer.clear()
else:
ts_copy = None
# Perform DB operations outside the lock
db = get_db()
if buffer_copy:
await db.update_batched_counts(buffer_copy)
if ts_copy:
await db.add_batched_time_series(ts_copy)
if flushed_entries:
print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
else:
print("[shutdown] No in-memory entries to flush on shutdown.")
except Exception as e:
# Do not raise during shutdown log and continue teardown
print(f"[shutdown] Error flushing remaining buffers: {e}")

File diff suppressed because it is too large Load diff