Compare commits
15 commits
dev-0.9.x-
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 10b9a4e910 | |||
| f827e0deba | |||
|
|
4ed75cfea6 | ||
|
|
04ce9d92da | ||
| 0de6a5a65d | |||
| ff8cfce9c7 | |||
| f5e08aa896 | |||
| fd5b131ff4 | |||
| 775829796c | |||
|
|
3e1e206740 | ||
| 7cb17cb791 | |||
| 2d1f2506b9 | |||
| 2e58a383c5 | |||
|
|
9bba10d7f4 | ||
| 770de3b93f |
35 changed files with 4052 additions and 106320 deletions
|
|
@ -11,7 +11,7 @@ jobs:
|
|||
opencode:
|
||||
if: |
|
||||
contains(github.event.comment.body, '/oc') ||
|
||||
contains(github.event.comment.body, '/opencode')
|
||||
contains(github.event.review.body, '/oc')
|
||||
runs-on: docker-amd64
|
||||
container:
|
||||
image: node:lts-bookworm
|
||||
|
|
@ -54,9 +54,7 @@ jobs:
|
|||
uses: ./.opencode-action
|
||||
with:
|
||||
nomyo_api_key: ${{ secrets.NOMYO_API_KEY }}
|
||||
model: nomyo/unsloth/Qwen3.6-35B-A3B-GGUF:UD-Q4_K_M
|
||||
model: nomyo/unsloth/Qwen3.6-35B-A3B-MTP-GGUF:Q4_K_XL
|
||||
forgejo_api_url: https://bitfreedom.net/code/
|
||||
forgejo_token: ${{ secrets.FORGEJO_TOKEN }}
|
||||
forgejo_push_token: ${{ secrets.FORGEJO_PUSH_TOKEN }}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,278 +0,0 @@
|
|||
"""Management / observability routes.
|
||||
|
||||
Read-only endpoints used by the dashboard and external monitoring:
|
||||
* usage counters and token-counts breakdown,
|
||||
* conversation-affinity introspection,
|
||||
* endpoint health summary,
|
||||
* LLM-response cache stats and invalidation,
|
||||
* SSE live-stream of usage updates,
|
||||
* hostname and ``/health`` probe.
|
||||
"""
|
||||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from cache import get_llm_cache
|
||||
from config import get_config
|
||||
from db import get_db
|
||||
from state import (
|
||||
usage_counts,
|
||||
token_usage_counts,
|
||||
_affinity_map,
|
||||
_affinity_lock,
|
||||
)
|
||||
from sse import subscribe, unsubscribe
|
||||
from backends.normalize import _normalize_llama_model_name
|
||||
from backends.probe import _endpoint_health
|
||||
|
||||
|
||||
# Route collection for this module; the handlers below attach themselves to it
# through the ``@router.get`` / ``@router.post`` decorators.
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/api/token_counts")
async def token_counts_proxy():
    """Return the overall token total together with a per-row breakdown.

    Each row yielded by the database's ``load_token_counts()`` is reduced to
    the five reported keys; ``total_tokens`` of all rows are summed up.
    """
    reported_keys = ("endpoint", "model", "input_tokens", "output_tokens", "total_tokens")
    rows = [
        {key: row[key] for key in reported_keys}
        async for row in get_db().load_token_counts()
    ]
    grand_total = sum(row["total_tokens"] for row in rows)
    return {"total_tokens": grand_total, "breakdown": rows}
|
||||
|
||||
|
||||
@router.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
    """
    Aggregate time_series entries older than ``days`` into daily aggregates by endpoint/model/date.

    The optional JSON body may carry ``days`` (default 30) and ``trim_old``
    (default False). An empty, malformed or wrongly shaped body falls back to
    both defaults instead of failing the request.

    Returns:
        A dict with ``status``, the effective ``days`` / ``trim_old`` values and
        ``aggregated_groups`` (whatever the database layer reports).
    """
    days, trim_old = 30, False

    # Reading the body is kept outside the fallback: a transport failure must
    # not be mistaken for "use the defaults" and trigger an aggregation run.
    body_bytes = await request.body()
    if body_bytes:
        try:
            # orjson accepts bytes directly and validates UTF-8 itself.
            payload = orjson.loads(body_bytes)
            parsed_days = int(payload.get("days", 30))
            parsed_trim = bool(payload.get("trim_old", False))
        except (orjson.JSONDecodeError, ValueError, TypeError, AttributeError, OverflowError):
            # Best-effort parsing: an unusable body keeps both defaults.
            pass
        else:
            days, trim_old = parsed_days, parsed_trim

    aggregated = await get_db().aggregate_time_series_older_than(days, trim_old=trim_old)
    return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
|
||||
|
||||
|
||||
@router.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
    """
    Return token usage statistics for a specific model.

    The model is taken from the ``model`` query parameter or, when that is
    absent, from the ``model`` field of the JSON request body.

    Returns:
        A dict with the model name, its input/output/total token counts, its
        time series and its endpoint distribution.

    Raises:
        HTTPException: 400 when the body is not valid JSON or no model is
            given; 404 when the database has no token data for the model.
    """
    if not model:
        # The body is only needed (and only parsed) when the query string did
        # not already name the model.
        body_bytes = await request.body()
        if body_bytes:
            try:
                # Passing bytes lets orjson report invalid UTF-8 as a
                # JSONDecodeError instead of a stray UnicodeDecodeError.
                payload = orjson.loads(body_bytes)
            except orjson.JSONDecodeError as e:
                raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
            if isinstance(payload, dict):
                model = payload.get("model")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    db = get_db()
    token_data = await db.get_token_counts_for_model(model)
    if not token_data:
        raise HTTPException(
            status_code=404, detail="No token data found for this model"
        )

    time_series = [
        entry async for entry in db.get_time_series_for_model(model)
    ]
    endpoint_distribution = await db.get_endpoint_distribution_for_model(model)

    return {
        'model': model,
        'input_tokens': token_data['input_tokens'],
        'output_tokens': token_data['output_tokens'],
        'total_tokens': token_data['total_tokens'],
        'time_series': time_series,
        'endpoint_distribution': endpoint_distribution,
    }
|
||||
|
||||
|
||||
@router.get("/api/affinity_stats")
async def affinity_stats(request: Request):
    """
    List the live conversation-affinity pins, one entry per pinned conversation.

    An entry carries only the endpoint, the model and the remaining TTL in
    seconds; fingerprints and content are never exposed. Pins whose TTL has
    run out are removed from the map while it is being scanned. With
    conversation_affinity switched off, `entries` is always empty.
    """
    cfg = get_config()
    if not cfg.conversation_affinity:
        return {"enabled": False, "ttl": cfg.conversation_affinity_ttl, "entries": []}

    started = time.monotonic()
    llama_endpoints = set(cfg.llama_server_endpoints)
    live: list[dict] = []

    async with _affinity_lock:
        # Iterate over a snapshot so expired pins can be dropped on the fly.
        for fingerprint, (endpoint, model_name, deadline) in tuple(_affinity_map.items()):
            ttl_left = deadline - started
            if ttl_left > 0:
                # Same normalisation as /api/ps_details, so the dashboard can
                # join affinity entries to PS rows by (endpoint, model).
                if endpoint in llama_endpoints:
                    shown_model = _normalize_llama_model_name(model_name)
                else:
                    shown_model = model_name
                live.append({
                    "endpoint": endpoint,
                    "model": shown_model,
                    "remaining": round(ttl_left, 2),
                })
            else:
                _affinity_map.pop(fingerprint, None)

    return {
        "enabled": True,
        "ttl": cfg.conversation_affinity_ttl,
        "entries": live,
    }
|
||||
|
||||
|
||||
@router.get("/api/usage")
async def usage_proxy(request: Request):
    """
    Expose the current usage counters and token usage counters.

    Intended for debugging and monitoring.
    """
    counters = {
        "usage_counts": usage_counts,
        "token_usage_counts": token_usage_counts,
    }
    return counters
|
||||
|
||||
|
||||
@router.get("/api/config")
async def config_proxy(request: Request):
    """
    Report the configured Ollama endpoints and llama_server_endpoints
    together with their health, for display in the front-end.

    Every endpoint is probed through ``_endpoint_health`` with a 5 second
    timeout. Status is "error" when either liveness (/api/version) or routing
    health (/api/ps) fails — see issue #83.
    """
    cfg = get_config()

    async def probe(url: str) -> dict:
        health = await _endpoint_health(url, timeout=5)
        return {"url": url, **health}

    ollama_status = await asyncio.gather(*map(probe, cfg.endpoints))

    if cfg.llama_server_endpoints:
        llama_status = await asyncio.gather(*map(probe, cfg.llama_server_endpoints))
    else:
        llama_status = []

    return {
        "endpoints": ollama_status,
        "llama_server_endpoints": llama_status,
        "require_router_api_key": bool(cfg.router_api_key),
    }
|
||||
|
||||
|
||||
@router.get("/api/cache/stats")
async def cache_stats():
    """Report whether the LLM response cache is enabled and, if so, its stats."""
    cache = get_llm_cache()
    if cache is not None:
        return {"enabled": True, **cache.stats()}
    return {"enabled": False}
|
||||
|
||||
|
||||
@router.post("/api/cache/invalidate")
async def cache_invalidate():
    """Clear the LLM response cache, reporting whether anything was cleared."""
    cache = get_llm_cache()
    if cache is not None:
        await cache.clear()
        return {"enabled": True, "cleared": True}
    return {"enabled": False, "cleared": False}
|
||||
|
||||
|
||||
@router.get("/health")
async def health_proxy(request: Request):
    """
    Health-check endpoint for monitoring the proxy.

    * Every configured endpoint is probed for both liveness and routing health:
        Ollama endpoints at `/api/version` AND `/api/ps`,
        OpenAI-compatible endpoints at `/models`.
    * The JSON reply contains:
        - `status`: "ok" if every endpoint passed every probe, otherwise "error".
        - `endpoints`: a mapping of endpoint URL → `{status, version|detail}`.
    * HTTP 200 is returned when everything is healthy, 503 otherwise.
    """
    cfg = get_config()

    # Ollama endpoints expose /api/version (liveness) and /api/ps (routing
    # health — required by `choose_endpoint`), whereas OpenAI-compatible
    # endpoints (vLLM, llama-server, external) expose /models for both.
    # Checking /api/version alone would overlook an Ollama process that is up
    # while /api/ps is failing — see issue #83.
    targets = list(cfg.endpoints)
    targets.extend(ep for ep in cfg.llama_server_endpoints if ep not in cfg.endpoints)

    # All probes run concurrently.
    results = await asyncio.gather(*[_endpoint_health(ep) for ep in targets])

    healthy = all(result.get("status") == "ok" for result in results)
    body = {
        "status": "ok" if healthy else "error",
        "endpoints": dict(zip(targets, results)),
    }
    return JSONResponse(content=body, status_code=200 if healthy else 503)
|
||||
|
||||
|
||||
@router.get("/api/hostname")
async def get_hostname():
    """Report the hostname of the machine the router runs on."""
    payload = {"hostname": socket.gethostname()}
    return JSONResponse(content=payload)
|
||||
|
||||
|
||||
@router.get("/api/usage-stream")
async def usage_stream(request: Request):
    """
    Server-Sent-Events stream that forwards every snapshot published on the
    broadcast channel (one message per change of the global `usage_counts`).
    """
    async def sse_events():
        # Dedicated queue that receives *every* new snapshot.
        inbox = await subscribe()
        try:
            # Stop as soon as the client is found to be gone.
            while not await request.is_disconnected():
                snapshot = await inbox.get()
                if snapshot is None:
                    # A None item ends the stream.
                    return
                # One snapshot == one SSE message.
                yield f"data: {snapshot}\n\n"
        finally:
            # Always leave the broadcast channel again.
            await unsubscribe(inbox)

    return StreamingResponse(sse_events(), media_type="text/event-stream")
|
||||
1134
api/ollama.py
1134
api/ollama.py
File diff suppressed because it is too large
Load diff
804
api/openai.py
804
api/openai.py
|
|
@ -1,804 +0,0 @@
|
|||
"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``,
|
||||
``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``).
|
||||
|
||||
The chat-completions and completions handlers carry the full reactive-trim
|
||||
logic for ``exceed_context_size_error`` plus connection-failure rerouting
|
||||
(``_mark_backend_unhealthy``). The streaming branches assemble cached
|
||||
responses on the fly so caching works for both streaming and non-streaming
|
||||
clients.
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import math
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from cache import get_llm_cache, openai_nonstream_to_sse
|
||||
from config import get_config
|
||||
from context_window import (
|
||||
_count_message_tokens,
|
||||
_trim_messages_for_context,
|
||||
_calibrated_trim_target,
|
||||
_endpoint_nctx,
|
||||
_CTX_TRIM_SMALL_LIMIT,
|
||||
)
|
||||
from fingerprint import _conversation_fingerprint
|
||||
from security import _mask_secrets
|
||||
from state import token_queue, app_state, default_headers
|
||||
from backends.health import _is_backend_connection_error, _mark_backend_unhealthy
|
||||
from backends.normalize import (
|
||||
dedupe_on_keys,
|
||||
ep2base,
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
_normalize_llama_model_name,
|
||||
)
|
||||
from backends.probe import fetch
|
||||
from backends.sessions import _make_openai_client, get_session
|
||||
from requests.messages import _strip_assistant_prefill, _strip_images_from_messages
|
||||
from requests.rechunk import rechunk
|
||||
from routing import choose_endpoint, decrement_usage
|
||||
|
||||
|
||||
# Route collection for the OpenAI-compatible handlers declared below via the
# ``@router`` decorators.
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to a backend and reply with embeddings.

    Multimodal ``input`` is reduced to its text parts: dict items of type
    ``text`` contribute their ``text``, other dict items are dropped, and
    non-dict items are kept unchanged. A list that ends up with exactly one
    element is unwrapped. Non-finite floats in the returned embeddings are
    replaced by ``0.0`` so the response stays JSON-serialisable.

    Raises:
        HTTPException: 400 when the body is not a JSON object or ``model`` /
            ``input`` is missing or empty.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        # orjson takes the raw bytes and reports invalid UTF-8 as a
        # JSONDecodeError, so no separate decode step is needed.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    doc = payload.get("input")

    # Normalize multimodal input: extract only text parts for embedding models
    if isinstance(doc, list):
        normalized = []
        for item in doc:
            if not isinstance(item, dict):
                normalized.append(item)
            elif item.get("type") == "text":
                normalized.append(item.get("text", ""))
            # image_url and other non-text parts are skipped
        doc = normalized[0] if len(normalized) == 1 else normalized
    elif isinstance(doc, dict) and doc.get("type") == "text":
        doc = doc.get("text", "")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not doc:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        # Client construction lives inside the try so that the usage counter
        # is released even when it fails.
        if is_openai_compatible(endpoint):
            api_key = config.api_keys.get(endpoint, "no-key")
        else:
            api_key = "ollama"
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)

        response = await oclient.embeddings.create(input=doc, model=model)
        result = response.model_dump()
        for item in result.get("data", []):
            emb = item.get("embedding")
            if emb:
                # NaN / ±inf cannot be represented in JSON.
                item["embedding"] = [
                    0.0 if isinstance(v, float) and not math.isfinite(v) else v
                    for v in emb
                ]
        return JSONResponse(content=result)
    finally:
        await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
|
||||
def _ochat_is_context_size_error(message: str) -> bool:
    """Return True when *message* is a backend "prompt exceeds context" error."""
    return (
        "exceed_context_size_error" in message
        or "exceeds the available context size" in message
    )


def _ochat_reject_svg_images(messages: list) -> None:
    """Raise HTTP 400 when any ``image_url`` content part refers to an SVG image."""
    for msg in messages:
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, list):
            continue
        for item in content:
            if not isinstance(item, dict) or item.get("type") != "image_url":
                continue
            image_url = item.get("image_url")
            # Tolerate both {"url": ...} and a bare string.
            url = image_url.get("url", "") if isinstance(image_url, dict) else image_url
            if not isinstance(url, str):
                continue
            if url.startswith("data:image/svg") or url.lower().endswith(".svg"):
                raise HTTPException(
                    status_code=400,
                    detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
                )


def _ochat_extract_ctx_limits(exc: Exception, message: str) -> tuple:
    """Return ``(n_ctx, n_prompt_tokens)`` reported by a context-size error.

    The structured error body is preferred; when it yields no ``n_ctx`` both
    values are parsed from the error text instead (the SDK may not parse
    llama-server errors). Unknown values are reported as 0.
    """
    body = getattr(exc, "body", {}) or {}
    detail = body.get("error", {}) if isinstance(body, dict) else {}
    if not isinstance(detail, dict):
        detail = {}
    n_ctx = detail.get("n_ctx", 0)
    n_prompt = detail.get("n_prompt_tokens", 0)
    if not n_ctx:
        import re

        found = re.search(r"'n_ctx':\s*(\d+)", message)
        if found:
            n_ctx = int(found.group(1))
        found = re.search(r"'n_prompt_tokens':\s*(\d+)", message)
        if found:
            n_prompt = int(found.group(1))
    return n_ctx, n_prompt


async def _ochat_inline_remote_images(msgs: list) -> list:
    """Fetch remote image URLs and convert them to base64 data URLs so
    Ollama/llama-server can handle them without making outbound HTTP requests.

    Parts that cannot be fetched are passed through unchanged.
    """
    # NOTE(review): the URLs come from the client, so this performs
    # server-side requests to caller-chosen hosts — consider an allow-list.
    resolved = []
    for msg in msgs:
        content = msg.get("content") if isinstance(msg, dict) else None
        if not isinstance(content, list):
            resolved.append(msg)
            continue
        new_content = []
        for item in content:
            url = ""
            if isinstance(item, dict) and item.get("type") == "image_url":
                image_url = item.get("image_url")
                if isinstance(image_url, dict):
                    url = image_url.get("url", "")
            if not isinstance(url, str) or not url or url.startswith("data:"):
                new_content.append(item)
                continue
            try:
                http: aiohttp.ClientSession = app_state["session"]
                async with http.get(url) as resp:
                    # Do not embed an error page as if it were the image.
                    resp.raise_for_status()
                    ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
                    img_bytes = await resp.read()
                b64 = base64.b64encode(img_bytes).decode("utf-8")
                new_content.append({
                    "type": "image_url",
                    "image_url": {"url": f"data:{ctype};base64,{b64}"},
                })
            except Exception as _ie:
                print(f"[image] Failed to fetch image URL: {_ie}")
                new_content.append(item)
        resolved.append({**msg, "content": new_content})
    return resolved


@router.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request to a backend and
    reply with a streaming (SSE) or single JSON response.

    Flow:
      1. parse/validate the body and build the upstream parameters,
      2. serve from the LLM cache when the request opted in (``nomyo.cache``),
      3. pick an endpoint, inline remote images for non-external backends and
         proactively trim the history for endpoints with a known small context,
      4. call the backend, retrying once without tools, with a trimmed history
         (and then without tools), or without images depending on the error,
      5. stream the established response, report token usage and fill the cache.

    The usage counter taken by ``choose_endpoint`` is released on every
    failure before the response starts and by the response generator afterwards.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing
            ``model``, malformed ``messages`` or SVG images.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    messages = payload.get("messages")
    stream = payload.get("stream")
    stream_options = payload.get("stream_options")
    nomyo_opts = payload.get("nomyo")
    _cache_enabled = nomyo_opts.get("cache", False) if isinstance(nomyo_opts, dict) else False

    if not model or not isinstance(model, str):
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing required field 'messages' (must be a list)"
        )
    if not all(isinstance(m, dict) for m in messages):
        raise HTTPException(
            status_code=400, detail="Each entry of 'messages' must be an object"
        )

    if ":latest" in model:
        model = model.split(":latest")[0]

    # One flag decides both the branch taken and the response media type.
    is_streaming = bool(stream)
    if is_streaming:
        # Ask for the usage chunk unless the client chose its own options.
        stream_options = stream_options or {"include_usage": True}

    messages = _strip_assistant_prefill(messages)
    params = {
        "messages": messages,
        "model": model,
    }
    optional_params = {
        "tools": payload.get("tools"),
        "response_format": payload.get("response_format"),
        # stream_options is only meaningful together with stream=true, so no
        # default is injected into non-streaming requests.
        "stream_options": stream_options,
        "max_completion_tokens": payload.get("max_completion_tokens"),
        "max_tokens": payload.get("max_tokens"),
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "seed": payload.get("seed"),
        "presence_penalty": payload.get("presence_penalty"),
        "frequency_penalty": payload.get("frequency_penalty"),
        "stop": payload.get("stop"),
        "stream": stream,
        "logprobs": payload.get("logprobs"),
        "top_logprobs": payload.get("top_logprobs"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    # Reject unsupported image formats (SVG) before doing any work
    _ochat_reject_svg_images(messages)

    # Cache lookup — before endpoint selection
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_chat", model, messages)
        if _cached is not None:
            if is_streaming:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ochat_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")

            async def _serve_cached_ochat_json():
                yield _cached

            return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    # Key under which the context size of this (endpoint, model) is cached;
    # reads and writes below use the same key.
    _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model

    # 3. API call — done in handler scope: try/except inside async generators is
    # unreliable with Starlette's streaming machinery, so errors are resolved
    # here before the generator starts. Any failure (including cancellation)
    # releases the usage counter exactly once in the outer handler.
    try:
        oclient = _make_openai_client(
            endpoint,
            default_headers=default_headers,
            api_key=config.api_keys.get(endpoint, "no-key"),
        )
        create = oclient.chat.completions.create

        send_params = params
        if not is_ext_openai_endpoint(endpoint):
            resolved_msgs = await _ochat_inline_remote_images(params.get("messages", []))
            send_params = {**params, "messages": resolved_msgs}

        # Proactive trim: only for small-ctx models we've already seen run out of space
        _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
        if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
            _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
            _pre_msgs = send_params.get("messages", [])
            _pre_est = _count_message_tokens(_pre_msgs)
            if _pre_est > _pre_target:
                _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                _dropped = len(_pre_msgs) - len(_pre_trimmed)
                print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                send_params = {**send_params, "messages": _pre_trimmed}

        try:
            async_gen = await create(**send_params)
        except Exception as e:
            _e_str = str(e)
            _is_ctx_err = _ochat_is_context_size_error(_e_str)
            print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
            if "does not support tools" in _e_str:
                # Model doesn't support tools — retry without them
                print("[ochat] retry: no tools", flush=True)
                params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
                async_gen = await create(**params_without_tools)
            elif _is_ctx_err:
                # Backend context limit hit — apply sliding-window trim (context-shift at message level)
                n_ctx_limit, actual_tokens = _ochat_extract_ctx_limits(e, _e_str)
                print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
                if not n_ctx_limit:
                    raise
                if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                    # Remember small contexts so later requests are trimmed up front.
                    _endpoint_nctx[(endpoint, _lookup_model)] = n_ctx_limit

                msgs_to_trim = send_params.get("messages", [])
                try:
                    cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                    trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                except Exception as _helper_exc:
                    print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
                    raise
                dropped = len(msgs_to_trim) - len(trimmed_messages)
                print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
                try:
                    async_gen = await create(**{**send_params, "messages": trimmed_messages})
                    print("[ctx-trim] retry-1 ok", flush=True)
                except Exception as e2:
                    if not _ochat_is_context_size_error(str(e2)):
                        raise
                    # Still too large — tool definitions likely consuming too many tokens, strip them too
                    print("[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
                    params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
                    async_gen = await create(**{**params_no_tools, "messages": trimmed_messages})
                    print("[ctx-trim] retry-2 ok", flush=True)
            elif _is_backend_connection_error(e):
                # Upstream connection failed (e.g. llama-server in router mode
                # whose delegated worker died). Mark (endpoint, model) so the
                # next request reroutes; the client will retry this one.
                print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, _e_str)
                raise
            elif "image input is not supported" in _e_str:
                # Model doesn't support images — strip and retry
                print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
                async_gen = await create(**{
                    **send_params,
                    "messages": _strip_images_from_messages(send_params.get("messages", [])),
                })
            else:
                raise
    except BaseException:
        # BaseException so that a cancelled request releases the counter too.
        await decrement_usage(endpoint, tracking_model)
        raise

    def _dump(obj) -> str:
        """Serialise an SDK object (or plain data) to a JSON string."""
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json()
        return orjson.dumps(obj).decode("utf-8")

    # 4. Async generator — only streams the already-established async_gen
    async def stream_ochat_response():
        try:
            if is_streaming:
                content_parts: list[str] = []
                usage_snapshot: dict = {}
                # Guard for the context-exhaustion detection below: when the
                # caller set a token budget, finish_reason=length may just mean
                # that budget ran out, not the context window.
                _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
                async for chunk in async_gen:
                    data = _dump(chunk)
                    if chunk.choices:
                        delta = chunk.choices[0].delta
                        has_content = delta.content is not None
                        has_reasoning = (
                            getattr(delta, "reasoning_content", None) is not None
                            or getattr(delta, "reasoning", None) is not None
                        )
                        has_tool_calls = getattr(delta, "tool_calls", None) is not None
                        if has_content or has_reasoning or has_tool_calls:
                            yield f"data: {data}\n\n".encode("utf-8")
                        if has_content and delta.content:
                            content_parts.append(delta.content)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {
                            "prompt_tokens": prompt_tok,
                            "completion_tokens": comp_tok,
                            "total_tokens": prompt_tok + comp_tok,
                        }
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                    # Detect context exhaustion mid-generation for small-ctx models.
                    if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
                        _inferred_nctx = (prompt_tok + comp_tok) or 0
                        if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                            _endpoint_nctx[(endpoint, _lookup_model)] = _inferred_nctx
                            print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{_lookup_model})", flush=True)

                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and content_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{
                            "index": 0,
                            "message": {"role": "assistant", "content": "".join(content_parts)},
                            "finish_reason": "stop",
                        }],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_chat", model, messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                cache_bytes = _dump(async_gen).encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_chat", model, messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        media_type="text/event-stream" if is_streaming else "application/json",
    )
|
||||
|
||||
|
||||
@router.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI-compatible text completions request to a backend endpoint.

    The JSON body must contain ``model`` and ``prompt``. A fixed set of optional
    sampling fields is forwarded upstream when present. ``nomyo.cache`` in the
    body opts the request into the LLM response cache.

    Returns a StreamingResponse: Server-Sent Events when ``stream`` is set,
    otherwise a single JSON document.

    Raises:
        HTTPException(400): body is not valid JSON, is not a JSON object, or
            lacks ``model`` / ``prompt``.
    """
    config = get_config()

    def _to_json_text(obj) -> str:
        # orjson.dumps returns bytes; normalise to str so both serialisation
        # paths can be encoded / interpolated the same way.
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json()
        return orjson.dumps(obj).decode("utf-8")

    # 1. Parse and validate request
    try:
        # orjson decodes the raw bytes itself and reports malformed UTF-8 as a
        # JSONDecodeError, so bad input becomes a 400 rather than an unhandled
        # UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    prompt = payload.get("prompt")
    stream = payload.get("stream")
    stream_options = payload.get("stream_options")
    # "nomyo" may be absent, null, or a non-object; only a dict can enable caching.
    nomyo_options = payload.get("nomyo")
    _cache_enabled = nomyo_options.get("cache", False) if isinstance(nomyo_options, dict) else False

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )

    if ":latest" in model:
        model = model.split(":latest")[0]

    # Ask for usage in the final stream chunk by default, but only inject that
    # default on streaming requests: stream_options is documented as valid only
    # together with stream=true. An explicit client value is forwarded as-is.
    if not stream_options and stream:
        stream_options = {"include_usage": True}

    params = {
        "prompt": prompt,
        "model": model,
    }
    optional_params = {
        "frequency_penalty": payload.get("frequency_penalty"),
        "presence_penalty": payload.get("presence_penalty"),
        "seed": payload.get("seed"),
        "stop": payload.get("stop"),
        "stream": stream,
        "stream_options": stream_options or None,
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_tokens": payload.get("max_tokens"),
        "max_completion_tokens": payload.get("max_completion_tokens"),
        "suffix": payload.get("suffix"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    # Cache lookup — completions prompt mapped to a single-turn messages list
    _cache = get_llm_cache()
    _compl_messages = [{"role": "user", "content": prompt}]
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_completions", model, _compl_messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ocompl_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")
            else:
                async def _serve_cached_ocompl_json():
                    yield _cached

                return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Make the API call in handler scope (try/except inside async generators is unreliable).
    # Client construction is inside the try as well, so any failure after
    # choose_endpoint still releases the usage counter.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        async_gen = await oclient.completions.create(**params)
    except Exception as e:
        if _is_backend_connection_error(e):
            print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
            await _mark_backend_unhealthy(endpoint, model, str(e))
        await decrement_usage(endpoint, tracking_model)
        raise

    async def stream_ocompletions_response(model=model):
        """Yield the upstream result, record token usage, and fill the cache."""
        try:
            # Equality (not truthiness) is kept so behaviour is unchanged for
            # non-boolean "stream" values.
            if stream == True:
                text_parts: list[str] = []
                usage_snapshot: dict = {}
                async for chunk in async_gen:
                    data = _to_json_text(chunk)
                    if chunk.choices:
                        choice = chunk.choices[0]
                        has_text = getattr(choice, "text", None) is not None
                        has_reasoning = (
                            getattr(choice, "reasoning_content", None) is not None
                            or getattr(choice, "reasoning", None) is not None
                        )
                        # Drop chunks that carry neither text, reasoning nor a finish reason.
                        if has_text or has_reasoning or choice.finish_reason is not None:
                            yield f"data: {data}\n\n".encode("utf-8")
                        if has_text and choice.text:
                            text_parts.append(choice.text)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                # Cache assembled streaming response — before [DONE] so it always runs.
                # NOTE(review): this entry uses a chat-style "message" shape while the
                # non-streaming branch caches the raw completions response; confirm
                # cache readers handle both shapes.
                if _cache is not None and _cache_enabled and text_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
                    except Exception as _ce:
                        # Caching is best-effort; never fail the response because of it.
                        print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")

                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                cache_bytes = _to_json_text(async_gen).encode("utf-8") + b"\n"
                yield cache_bytes

                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
|
||||
|
||||
|
||||
@router.get("/v1/models")
async def openai_models_proxy(request: Request):
    """
    Return one merged, de-duplicated OpenAI-style model list for all backends.

    Three groups of endpoints are probed:
      * endpoints in ``config.endpoints`` without ``/v1`` — via ``/api/tags``
      * external OpenAI-compatible endpoints — via ``/models``
      * ``config.llama_server_endpoints`` — via ``/models`` (every model the
        server reports; no status filtering is applied here)

    Every entry ends up with both ``id`` and ``name`` set, and the result is
    de-duplicated on ``name``.
    """
    config = get_config()
    probe_opts = {"skip_error_cache": True, "timeout": 8}

    # 1. Ollama-style endpoints: all models via /api/tags
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/tags", "models", **probe_opts)
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    # 2. External OpenAI endpoints (Groq, OpenAI, etc.) via /models
    ext_openai_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), **probe_opts)
        for ep in config.endpoints
        if is_ext_openai_endpoint(ep)
    ]
    # 3. llama-server endpoints via /models. The set covers endpoints listed in
    #    llama_server_endpoints whether or not they also appear in config.endpoints.
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), **probe_opts)
        for ep in set(config.llama_server_endpoints)
    ]

    # The probes are independent, so run them all concurrently. gather() keeps
    # argument order, so results stay grouped Ollama → external → llama-server.
    model_lists = await asyncio.gather(*ollama_tasks, *ext_openai_tasks, *llama_tasks)

    collected = []
    for modellist in model_lists:
        for model in modellist:
            if "id" not in model:
                # Relabel Ollama-style entries: OpenAI Model.id comes from Model.name
                model["id"] = model.get("name", "")
            else:
                model["name"] = model["id"]
            collected.append(model)

    # 4. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"data": dedupe_on_keys(collected, ["name"])},
        status_code=200,
    )
|
||||
|
||||
|
||||
@router.post("/v1/rerank")
@router.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
            "id": "...",
            "model": "...",
            "usage": {"prompt_tokens": N, "total_tokens": N},
            "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException(400): invalid JSON, non-object body, or missing fields.
        HTTPException(404): no endpoint could be chosen for the model.
        HTTPException(501): the chosen endpoint is not OpenAI-compatible.
        HTTPException(502): the upstream replied with a body that is not JSON.
        HTTPException(<upstream status>): the upstream replied with status >= 400.
    """
    config = get_config()
    try:
        # orjson parses bytes directly and reports malformed UTF-8 as a
        # JSONDecodeError, so bad input maps to 400 instead of an unhandled 500.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # Determine which endpoint serves this model
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # Everything after a successful choose_endpoint runs under one try/finally so
    # the usage counter is released exactly once on every exit path.
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )

        if ":latest" in model:
            model = model.split(":latest")[0]

        # Build upstream rerank request body – forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        # Upstream URL is "<base ending in /v1>/rerank". ep2base appends /v1 only
        # when the endpoint does not already contain it, which covers both
        # llama-server and external OpenAI-compatible endpoints.
        rerank_url = f"{ep2base(endpoint)}/rerank"

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        client: aiohttp.ClientSession = get_session(endpoint)
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            # A 2xx/3xx reply with a non-JSON body is an upstream fault.
            raise HTTPException(status_code=502, detail=f"Upstream returned invalid JSON: {e}") from e

        # Record token usage if the upstream returned a usage object
        usage = (data.get("usage") if isinstance(data, dict) else None) or {}
        prompt_tok = usage.get("prompt_tokens") or 0
        total_tok = usage.get("total_tokens") or 0
        # For reranking there are no completion tokens; we record prompt tokens only
        if prompt_tok or total_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""Static-asset and dashboard routes."""
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from starlette.responses import HTMLResponse, RedirectResponse
|
||||
|
||||
# Directory containing static files (resolved relative to project root).
|
||||
STATIC_DIR = Path(__file__).resolve().parent.parent / "static"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/favicon.ico")
async def redirect_favicon():
    """Redirect root-level favicon requests to the copy under ``/static``."""
    favicon_url = "/static/favicon.ico"
    return RedirectResponse(url=favicon_url)
|
||||
|
||||
|
||||
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """
    Serve the NOMYO Router dashboard page.

    Reads ``index.html`` from ``STATIC_DIR`` and returns it as HTML.

    Raises:
        HTTPException(404): the file does not exist.
        HTTPException(500): any other failure while reading or building the response.
    """
    try:
        page_html = (STATIC_DIR / "index.html").read_text(encoding="utf-8")
        return HTMLResponse(content=page_html, status_code=200)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Page not found")
    except Exception:
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
|
@ -1,136 +0,0 @@
|
|||
"""Backend health probes and error classification helpers.
|
||||
|
||||
Contains:
|
||||
* cache-freshness check (``_is_fresh``)
|
||||
* aiohttp response success assertion (``_ensure_success``)
|
||||
* human-readable connection-issue formatter
|
||||
* upstream-error detection that distinguishes connection failures from
|
||||
legitimate 4xx responses (``_is_backend_connection_error``)
|
||||
* per-(endpoint, model) unhealthy marker that feeds ``choose_endpoint``
|
||||
* llama-server status interpretation (``_is_llama_model_loaded`` etc.)
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
from fastapi import HTTPException
|
||||
|
||||
from security import _mask_secrets
|
||||
from state import _completion_error_cache, _completion_error_cache_lock
|
||||
|
||||
|
||||
def _is_fresh(cached_at: float, ttl: int) -> bool:
|
||||
return (time.time() - cached_at) < ttl
|
||||
|
||||
|
||||
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
    """Raise an HTTPException when ``resp`` carries a status of 400 or above.

    The exception reuses the upstream status code; its detail is the response
    text passed through ``_mask_secrets``. Returns None for any lower status.
    """
    if resp.status < 400:
        return
    body_text = await resp.text()
    raise HTTPException(status_code=resp.status, detail=_mask_secrets(body_text))
|
||||
|
||||
|
||||
def _format_connection_issue(url: str, error: Exception) -> str:
    """
    Provide a human-friendly error string for connection failures so operators
    know which endpoint and address failed from inside the container.

    Args:
        url: The URL that was being contacted.
        error: The exception raised while contacting it.

    Returns:
        A message tailored to ``aiohttp.ClientConnectorError`` (resolved
        host/port, Docker localhost hint, OS error), to
        ``asyncio.TimeoutError``, or a generic fallback for anything else.
        This function does not raise for malformed URLs.
    """
    if isinstance(error, aiohttp.ClientConnectorError):
        # URL-derived hints are only fallbacks for this branch. Parsing is
        # guarded because urlparse() / .port raise ValueError on malformed
        # netlocs or ports, and an error formatter must not fail itself.
        host_hint = ""
        port_hint = ""
        try:
            parsed = urlparse(url)
            host_hint = parsed.hostname or ""
            port_hint = parsed.port or ""
        except ValueError:
            pass

        resolved_host = getattr(error, "host", host_hint) or host_hint or "?"
        resolved_port = getattr(error, "port", port_hint) or port_hint or "?"
        parts = [
            f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).",
            "Ensure the endpoint address is reachable from within the container.",
        ]
        if resolved_host in {"localhost", "127.0.0.1"}:
            parts.append(
                "Inside Docker, 'localhost' refers to the container itself; use "
                "'host.docker.internal' or a Docker network alias if the service "
                "runs on the host machine."
            )
        os_error = getattr(error, "os_error", None)
        if isinstance(os_error, OSError):
            errno = getattr(os_error, "errno", None)
            strerror = os_error.strerror or str(os_error)
            if errno is not None or strerror:
                parts.append(f"OS error [{errno}]: {strerror}.")
        elif os_error:
            parts.append(f"OS error: {os_error}.")
        parts.append(f"Original error: {error}.")
        return " ".join(parts)

    if isinstance(error, asyncio.TimeoutError):
        return (
            f"Timed out waiting for {url}. "
            "The remote endpoint may be offline or slow to respond."
        )

    return f"Error while contacting {url}: {error}"
|
||||
|
||||
|
||||
def _is_backend_connection_error(exc: Exception) -> bool:
    """Tell whether ``exc`` is an upstream connection-class failure.

    Intended for the situation where a llama-server in router mode still answers
    /v1/models while the worker behind a specific model is dead, so completion
    calls come back as 5xx with 'proxy error: Could not establish connection'
    (or the SDK raises APIConnectionError directly).

    Returns True for any ``openai.APIConnectionError``, and for an
    ``openai.InternalServerError`` whose lower-cased message contains one of the
    known connection-failure phrases. Everything else — including
    BadRequestError with exceed_context_size_error, which must stay on the
    reactive-trim path — yields False.
    """
    if isinstance(exc, openai.APIConnectionError):
        return True
    if not isinstance(exc, openai.InternalServerError):
        return False

    lowered = str(exc).lower()
    connection_markers = (
        "proxy error",
        "could not establish connection",
        "connection refused",
    )
    return any(marker in lowered for marker in connection_markers)
|
||||
|
||||
|
||||
async def _mark_backend_unhealthy(endpoint: str, model: str, reason: str = "") -> None:
    """Flag the (endpoint, model) pair as broken so choose_endpoint avoids it.

    Stores the current time in ``_completion_error_cache`` under the pair, while
    holding ``_completion_error_cache_lock``, then logs the first 120 characters
    of ``reason``.

    The entry is cleared only by TTL: the dead-worker failure mode is invisible
    to the /v1/models and /api/ps probes that clear _loaded_error_cache, so a
    successful probe cannot be used as a recovery signal.
    """
    cache_key = (endpoint, model)
    async with _completion_error_cache_lock:
        _completion_error_cache[cache_key] = time.time()
    short_reason = reason[:120]
    print(f"[health] marked unhealthy ep={endpoint} model={model} reason={short_reason}", flush=True)
|
||||
|
||||
|
||||
def _is_llama_model_loaded(item: dict) -> bool:
|
||||
"""Return True if a llama-server /v1/models item has status 'loaded'.
|
||||
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
|
||||
If no status field is present, the model is always-loaded (not dynamically managed)."""
|
||||
status = item.get("status")
|
||||
if status is None:
|
||||
return True # No status field: model is always loaded (e.g. single-model servers)
|
||||
if isinstance(status, dict):
|
||||
return status.get("value") == "loaded"
|
||||
if isinstance(status, str):
|
||||
return status == "loaded"
|
||||
return False
|
||||
|
||||
|
||||
def _is_llama_model_loaded_or_sleeping(item: dict) -> bool:
|
||||
"""Return True if status is 'loaded' or 'sleeping'.
|
||||
Newer llama-server versions report 'sleeping' in /v1/models when a model is idle;
|
||||
ps_details needs to include these so _fetch_llama_props can detect and unload them."""
|
||||
status = item.get("status")
|
||||
if status is None:
|
||||
return True
|
||||
if isinstance(status, dict):
|
||||
return status.get("value") in ("loaded", "sleeping")
|
||||
if isinstance(status, str):
|
||||
return status in ("loaded", "sleeping")
|
||||
return False
|
||||
|
|
@ -1,113 +0,0 @@
|
|||
"""Endpoint URL, model-name, and endpoint-classification helpers.
|
||||
|
||||
The endpoint classifiers read live config via ``get_config()`` so that the
|
||||
startup-time rebind of ``config`` in router.py is picked up at call time.
|
||||
"""
|
||||
from config import get_config
|
||||
|
||||
|
||||
def _normalize_llama_model_name(name: str) -> str:
|
||||
"""Extract the model name from a huggingface-style identifier.
|
||||
e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF'
|
||||
"""
|
||||
if "/" in name:
|
||||
name = name.rsplit("/", 1)[1]
|
||||
if ":" in name:
|
||||
name = name.split(":")[0]
|
||||
return name
|
||||
|
||||
|
||||
def _extract_llama_quant(name: str) -> str:
|
||||
"""Extract the quantization level from a huggingface-style identifier.
|
||||
e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0'
|
||||
Returns empty string if no quant suffix is present.
|
||||
"""
|
||||
if ":" in name:
|
||||
return name.rsplit(":", 1)[1]
|
||||
return ""
|
||||
|
||||
|
||||
def ep2base(ep):
    """Return ``ep`` unchanged if it contains ``/v1``, otherwise ``ep`` + ``/v1``."""
    return ep if "/v1" in ep else ep + "/v1"
|
||||
|
||||
|
||||
def dedupe_on_keys(dicts, key_fields):
    """
    Drop entries whose values for ``key_fields`` repeat an earlier entry.

    The first dict seen for each key combination is kept (the same object, not
    a copy) and input order is preserved. A missing field counts as ``None``.
    """
    first_by_key = {}
    for entry in dicts:
        # Tuple of the values for the chosen keys identifies the entry
        signature = tuple(entry.get(field) for field in key_fields)
        # setdefault keeps the earliest entry; dicts preserve insertion order.
        first_by_key.setdefault(signature, entry)
    return list(first_by_key.values())
|
||||
|
||||
|
||||
def is_ext_openai_endpoint(endpoint: str) -> bool:
    """
    Decide whether ``endpoint`` is an external OpenAI-compatible service
    (e.g. OpenAI.com, Groq) rather than Ollama or llama-server.

    Returns False when any of these hold, checked in order:
      - the endpoint is listed in ``llama_server_endpoints``
      - the endpoint does not contain ``/v1``
      - the endpoint with ``/v1`` removed is itself a configured endpoint
        (treated as Ollama's /v1)
      - the endpoint contains the default Ollama port ``:11434``

    Returns True otherwise.
    """
    cfg = get_config()

    if endpoint in cfg.llama_server_endpoints or "/v1" not in endpoint:
        return False

    # The same host configured without /v1 means this is Ollama's /v1 facade.
    if endpoint.replace('/v1', '') in cfg.endpoints:
        return False

    # Whatever is left is external unless it uses the default Ollama port.
    return ':11434' not in endpoint
|
||||
|
||||
|
||||
def is_openai_compatible(endpoint: str) -> bool:
    """
    Tell whether the endpoint speaks the OpenAI API (not native Ollama).

    True when the endpoint contains ``/v1`` or is listed in the configured
    ``llama_server_endpoints``.
    """
    if "/v1" in endpoint:
        return True
    return endpoint in get_config().llama_server_endpoints
|
||||
|
||||
|
||||
def get_tracking_model(endpoint: str, model: str) -> str:
    """
    Normalize a model name for usage tracking so it matches the PS table key.

    - External OpenAI endpoints: returned as-is (not shown in PS)
    - llama-server endpoints: HF prefix and quantization suffix stripped
    - Anything else (Ollama): ``:latest`` appended when the name has no ``:``
    """
    if is_ext_openai_endpoint(endpoint):
        return model

    if endpoint in get_config().llama_server_endpoints:
        return _normalize_llama_model_name(model)

    # Ollama: a name that already carries a tag is kept untouched.
    has_tag = ":" in model
    return model if has_tag else model + ":latest"
|
||||
|
|
@ -1,456 +0,0 @@
|
|||
"""Backend probe / discovery primitives.
|
||||
|
||||
The ``fetch`` class wraps the three discovery paths the router uses:
|
||||
* ``available_models`` — what the endpoint advertises (Ollama ``/api/tags``
|
||||
or OpenAI-style ``/v1/models``)
|
||||
* ``loaded_models`` — what is currently resident (Ollama ``/api/ps`` or
|
||||
llama-server ``/v1/models`` filtered on ``status == "loaded"``)
|
||||
* ``endpoint_details`` — arbitrary detail fetch used by management routes
|
||||
|
||||
Each path goes through three layers of cache: success cache, error cache,
|
||||
and an in-flight request map. Stale-while-revalidate refreshes happen in
|
||||
background tasks tracked by the ``_bg_refresh_*`` maps in ``state``.
|
||||
|
||||
``_raw_probe`` and ``_endpoint_health`` are the lower-level dual probes
|
||||
used by ``/health`` and ``/api/config`` to distinguish a healthy daemon
|
||||
with a broken model-introspection path from a dead daemon.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional, Set
|
||||
|
||||
import aiohttp
|
||||
|
||||
from config import get_config
|
||||
from state import (
|
||||
_models_cache,
|
||||
_models_cache_lock,
|
||||
_loaded_models_cache,
|
||||
_loaded_models_cache_lock,
|
||||
_available_error_cache,
|
||||
_available_error_cache_lock,
|
||||
_loaded_error_cache,
|
||||
_loaded_error_cache_lock,
|
||||
_inflight_available_models,
|
||||
_inflight_loaded_models,
|
||||
_inflight_lock,
|
||||
_bg_refresh_available,
|
||||
_bg_refresh_loaded,
|
||||
_bg_refresh_lock,
|
||||
default_headers,
|
||||
)
|
||||
from backends.sessions import get_probe_session
|
||||
from backends.health import (
|
||||
_is_fresh,
|
||||
_ensure_success,
|
||||
_format_connection_issue,
|
||||
_is_llama_model_loaded,
|
||||
)
|
||||
from backends.normalize import is_ext_openai_endpoint, is_openai_compatible
|
||||
|
||||
|
||||
class fetch:
|
||||
async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
    """
    Perform the HTTP request that lists the models an endpoint advertises.

    Called by available_models() after its cache and in-flight checks.

    The URL and response key depend on the endpoint:
      * contains ``/v1``                    -> ``<endpoint>/models``, key ``data``
      * llama-server endpoint without /v1   -> ``<endpoint>/v1/models``, key ``data``
      * anything else                       -> ``<endpoint>/api/tags``, key ``models``

    On success the result is stored in ``_models_cache`` with a timestamp. On
    any exception the failure is logged, the time is recorded in
    ``_available_error_cache``, and an empty set is returned.
    """
    cfg = get_config()

    request_headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
    if api_key is not None:
        request_headers["Authorization"] = "Bearer " + api_key

    is_llama_server = endpoint in cfg.llama_server_endpoints
    if "/v1" in endpoint:
        path, key = "/models", "data"
    elif is_llama_server:
        path, key = "/v1/models", "data"
    else:
        path, key = "/api/tags", "models"
    endpoint_url = endpoint.rstrip("/") + path

    client: aiohttp.ClientSession = get_probe_session(endpoint)
    try:
        async with client.get(endpoint_url, headers=request_headers) as resp:
            await _ensure_success(resp)
            data = await resp.json()

        # Each entry is identified by "id", falling back to "name"; entries
        # with neither are skipped.
        models = set()
        for item in data.get(key, []):
            identifier = item.get("id") or item.get("name")
            if identifier:
                models.add(identifier)

        async with _models_cache_lock:
            _models_cache[endpoint] = (models, time.time())
        return models
    except Exception as e:
        # Treat any error as if the endpoint offers no models
        message = _format_connection_issue(endpoint_url, e)
        print(f"[fetch.available_models] {message}")
        # Update error cache with lock protection
        async with _available_error_cache_lock:
            _available_error_cache[endpoint] = time.time()
        return set()
|
||||
|
||||
async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None:
    """
    Refresh the available-models cache for ``endpoint`` (stale-while-revalidate).

    Only one refresh runs per endpoint at a time: if ``_bg_refresh_available``
    already holds an unfinished task for the endpoint, this call returns
    immediately. Otherwise a fetch task is registered, awaited, and removed
    from the registry again. Failures are logged and never propagated.
    """
    async with _bg_refresh_lock:
        running = _bg_refresh_available.get(endpoint)
        if running is not None and not running.done():
            return  # A refresh is already running for this endpoint
        task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
        _bg_refresh_available[endpoint] = task

    try:
        await task
    except Exception as e:
        # Silently fail - cache will remain stale but functional
        print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
    finally:
        async with _bg_refresh_lock:
            # Only remove the entry if it is still ours; a newer refresh may
            # have replaced it.
            if _bg_refresh_available.get(endpoint) is task:
                _bg_refresh_available.pop(endpoint, None)
|
||||
|
||||
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
    """Return the set of model names that ``endpoint`` advertises.

    "Advertised" means installed on the backend (``<endpoint>/api/tags`` for
    Ollama), whether or not the model is currently loaded into memory. The
    HTTP request itself is delegated to ``_fetch_available_models_internal``.

    Lookup order:

    1. Model cache. An entry fresh within 300s is returned as is. An entry
       fresh within 600s is returned immediately while a background refresh
       is scheduled (stale-while-revalidate, so a transient timeout does not
       black out the models). Older entries are evicted.
    2. Error cache. A failure younger than 30s returns an empty set at once.
       A failure 30-300s old also returns an empty set but schedules a
       background probe in case the endpoint recovered. Older entries are
       evicted.
    3. Request coalescing. Concurrent callers share one in-flight fetch task
       per endpoint, so an expired cache causes a single upstream request.

    Per the original contract, a failed fetch (timeout, 5xx, malformed
    response) yields an empty set.
    """
    async with _models_cache_lock:
        entry = _models_cache.get(endpoint)
        if entry is not None:
            models, cached_at = entry

            # FRESH: serve straight from the cache.
            if _is_fresh(cached_at, 300):
                return models

            # STALE: serve the old data, revalidate in the background.
            if _is_fresh(cached_at, 600):
                asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
                return models

            # EXPIRED: too old to trust, force a synchronous fetch below.
            del _models_cache[endpoint]

    async with _available_error_cache_lock:
        failed_at = _available_error_cache.get(endpoint)
        if failed_at is not None:
            err_age = time.time() - failed_at
            if err_age < 300:
                if err_age >= 30:
                    # Older failure: the endpoint may be back, probe quietly.
                    asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
                # Either way the caller is not made to wait on a shaky endpoint.
                return set()
            # Error record expired: drop it and fall through to a fresh fetch.
            del _available_error_cache[endpoint]

    async with _inflight_lock:
        task = _inflight_available_models.get(endpoint)
        if task is None:
            # Nobody is fetching yet: start the fetch others will join.
            task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
            _inflight_available_models[endpoint] = task

    try:
        # Await the shared fetch (ours or another caller's).
        return await task
    finally:
        # Whoever finishes first unregisters the task; later waiters find it gone.
        async with _inflight_lock:
            if _inflight_available_models.get(endpoint) is task:
                _inflight_available_models.pop(endpoint, None)
|
||||
|
||||
|
||||
async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
    """Perform the HTTP probe behind ``loaded_models`` and update its caches.

    * llama-server endpoints (listed in ``cfg.llama_server_endpoints``):
      GET ``<endpoint>/models`` and keep the ids of entries that
      ``_is_llama_model_loaded`` accepts. The configured API key, if any, is
      sent as a Bearer token so that a build which protects ``/models`` does
      not answer this probe with 401.
    * all other (Ollama) endpoints: GET ``<endpoint>/api/ps`` and collect the
      ``name`` of every entry under ``models``.

    On success the result is stored in ``_loaded_models_cache`` and any error
    recorded for the endpoint is cleared. On any exception the failure time is
    written to ``_loaded_error_cache`` and an empty set is returned.
    """
    session: aiohttp.ClientSession = get_probe_session(endpoint)
    cfg = get_config()
    is_llama_server = endpoint in cfg.llama_server_endpoints

    if is_llama_server:
        probe_url = f"{endpoint}/models"
        probe_headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
        configured_key = cfg.api_keys.get(endpoint)
        if configured_key is not None:
            probe_headers["Authorization"] = "Bearer " + configured_key
        request_kwargs = {"headers": probe_headers}
    else:
        # Plain Ollama probe: no extra headers are sent.
        probe_url = f"{endpoint}/api/ps"
        request_kwargs = {}

    try:
        async with session.get(probe_url, **request_kwargs) as resp:
            await _ensure_success(resp)
            payload = await resp.json()

        if is_llama_server:
            # /models lists everything; keep only entries reported as loaded.
            found = {
                entry.get("id")
                for entry in payload.get("data", [])
                if entry.get("id") and _is_llama_model_loaded(entry)
            }
        else:
            # Response shape: {"models": [{"name": "model1"}, {"name": "model2"}]}
            found = {
                entry.get("name")
                for entry in payload.get("models", [])
                if entry.get("name")
            }

        async with _loaded_models_cache_lock:
            _loaded_models_cache[endpoint] = (found, time.time())
        # A successful probe clears any stale error so the endpoint becomes
        # routable again.
        async with _loaded_error_cache_lock:
            _loaded_error_cache.pop(endpoint, None)
        return found
    except Exception as e:
        # Anything going wrong is treated as "no models loaded here".
        message = _format_connection_issue(probe_url, e)
        print(f"[fetch.loaded_models] {message}")
        # Remember the failure so repeated probes short-circuit and routing
        # can avoid the unhealthy backend.
        async with _loaded_error_cache_lock:
            _loaded_error_cache[endpoint] = time.time()
        return set()
|
||||
|
||||
async def _refresh_loaded_models(endpoint: str) -> None:
    """Revalidate the loaded-models cache for ``endpoint`` in the background.

    Counterpart of ``_refresh_available_models`` for the loaded-models probe.
    At most one refresh per endpoint runs at a time (tracked in
    ``_bg_refresh_loaded``); failures are logged and swallowed so the cache
    simply stays stale.
    """
    async with _bg_refresh_lock:
        running = _bg_refresh_loaded.get(endpoint)
        if running is not None and not running.done():
            # A refresh for this endpoint is already under way.
            return
        refresh = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
        _bg_refresh_loaded[endpoint] = refresh

    try:
        await refresh
    except Exception as e:
        # Best effort: report the failure, keep serving whatever is cached.
        print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
    finally:
        async with _bg_refresh_lock:
            # Only unregister our own task; a newer refresh may have replaced it.
            if _bg_refresh_loaded.get(endpoint) is refresh:
                _bg_refresh_loaded.pop(endpoint, None)
|
||||
|
||||
async def loaded_models(endpoint: str) -> Set[str]:
    """Return the names of the models currently loaded on ``endpoint``.

    External OpenAI endpoints (``is_ext_openai_endpoint``) are never probed
    and always yield an empty set. For everything else the HTTP request is
    delegated to ``_fetch_loaded_models_internal``.

    Lookup order:

    1. Loaded-models cache. An entry fresh within 10s is returned as is. An
       entry fresh within 60s is returned immediately while a background
       refresh is scheduled. Older entries are evicted.
    2. Error cache. A failure recorded within the last 300s returns an empty
       set without probing. Older entries are evicted.
    3. Request coalescing. Concurrent callers share one in-flight fetch task
       per endpoint.

    Per the original contract, a failed request (timeout, 5xx) yields an
    empty set.
    """
    if is_ext_openai_endpoint(endpoint):
        return set()

    async with _loaded_models_cache_lock:
        entry = _loaded_models_cache.get(endpoint)
        if entry is not None:
            models, cached_at = entry

            # FRESH: serve straight from the cache.
            if _is_fresh(cached_at, 10):
                return models

            # STALE: serve the old data, fire-and-forget a background refresh.
            if _is_fresh(cached_at, 60):
                asyncio.create_task(fetch._refresh_loaded_models(endpoint))
                return models

            # EXPIRED: too old to trust, force a synchronous fetch below.
            del _loaded_models_cache[endpoint]

    async with _loaded_error_cache_lock:
        failed_at = _loaded_error_cache.get(endpoint)
        if failed_at is not None:
            if _is_fresh(failed_at, 300):
                return set()
            # Error record expired: drop it and probe again.
            del _loaded_error_cache[endpoint]

    async with _inflight_lock:
        task = _inflight_loaded_models.get(endpoint)
        if task is None:
            # Nobody is fetching yet: start the fetch others will join.
            task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
            _inflight_loaded_models[endpoint] = task

    try:
        # Await the shared fetch (ours or another caller's).
        return await task
    finally:
        # Whoever finishes first unregisters the task; later waiters find it gone.
        async with _inflight_lock:
            if _inflight_loaded_models.get(endpoint) is task:
                _inflight_loaded_models.pop(endpoint, None)
|
||||
|
||||
async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None, skip_error_cache: bool = False, timeout: Optional[float] = None) -> List[dict]:
    """Fetch ``<endpoint>/<route>`` and return the ``detail`` entry of its JSON body.

    Args:
        endpoint: Base URL of the backend.
        route: Path appended to ``endpoint``; duplicate slashes at the join
            are collapsed.
        detail: Key to read from the decoded JSON object.
        api_key: When given, sent as ``Authorization: Bearer <api_key>``.
        skip_error_cache: When False (the default) the call short-circuits if
            the endpoint failed within the last 300s according to
            ``_available_error_cache``, and a new failure is recorded there.
            Health-check routes that must always probe pass True, which skips
            both the check and the recording.
        timeout: Total timeout in seconds for this single request, overriding
            the session default. ``None`` keeps the session default.

    Returns:
        ``data.get(detail, [])`` on success. An empty list when the endpoint
        is short-circuited or when anything goes wrong (connection error,
        non-success status, undecodable body).
    """
    # Fast-fail if the endpoint is known to be down (unless the caller opts out).
    if not skip_error_cache:
        async with _available_error_cache_lock:
            failed_at = _available_error_cache.get(endpoint)
            if failed_at is not None and _is_fresh(failed_at, 300):
                return []

    headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
    if api_key is not None:
        headers["Authorization"] = "Bearer " + api_key

    request_url = f"{endpoint.rstrip('/')}/{route.lstrip('/')}"
    client: aiohttp.ClientSession = get_probe_session(endpoint)
    req_kwargs = {}
    if timeout is not None:
        req_kwargs["timeout"] = aiohttp.ClientTimeout(total=timeout)

    try:
        async with client.get(request_url, headers=headers, **req_kwargs) as resp:
            await _ensure_success(resp)
            data = await resp.json()
        # Read the requested key without rebinding the ``detail`` parameter.
        return data.get(detail, [])
    except Exception as e:
        # Nothing useful to report: log and answer with an empty list.
        message = _format_connection_issue(request_url, e)
        print(f"[fetch.endpoint_details] {message}")
        if not skip_error_cache:
            async with _available_error_cache_lock:
                _available_error_cache[endpoint] = time.time()
        return []
|
||||
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# Endpoint health probes (shared by /api/config and /health)
|
||||
# -------------------------------------------------------------
|
||||
async def _raw_probe(
    ep: str,
    route: str,
    api_key: Optional[str] = None,
    timeout: Optional[float] = None,
) -> tuple[bool, object]:
    """GET ``<ep>/<route>`` and report explicitly whether it worked.

    Unlike ``fetch.endpoint_details``, which answers ``[]`` for both an empty
    result and a failure, this returns ``(ok, payload_or_error_message)``:
    ``(True, <decoded JSON>)`` on success, ``(False, <formatted error>)`` on
    any exception. It does not read or write the error caches.

    ``api_key`` is sent as a Bearer token when given; ``timeout`` (seconds,
    total) overrides the session default for this one request.
    """
    headers = {"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")}
    if api_key is not None:
        headers["Authorization"] = "Bearer " + api_key

    url = f"{ep.rstrip('/')}/{route.lstrip('/')}"
    req_kwargs = {} if timeout is None else {"timeout": aiohttp.ClientTimeout(total=timeout)}

    try:
        # Session lookup sits inside the try so a missing session is reported
        # as a probe failure rather than raised.
        session: aiohttp.ClientSession = get_probe_session(ep)
        async with session.get(url, headers=headers, **req_kwargs) as resp:
            await _ensure_success(resp)
            payload = await resp.json()
        return True, payload
    except Exception as exc:
        return False, _format_connection_issue(url, exc)
|
||||
|
||||
|
||||
async def _endpoint_health(ep: str, *, timeout: Optional[float] = None) -> dict:
    """Probe ``ep`` and summarise its health as ``{status, version?, detail?}``.

    OpenAI-compatible endpoints get a single ``/models`` probe (sent with the
    configured API key) and report the placeholder version ``"latest"``.

    Ollama endpoints get ``/api/version`` and ``/api/ps`` probed concurrently,
    so a daemon that is reachable but whose model-introspection path is broken
    (issue #83) is reported as ``error`` rather than ``ok``:

    * both succeed   -> ``{"status": "ok", "version": ...}``
    * both fail      -> ``error`` with the ``/api/version`` failure as detail
    * only ps fails  -> ``error`` plus ``version``, so the operator can see
      the daemon itself is alive
    * only version fails -> ``error`` with the ``/api/version`` failure
    """
    if is_openai_compatible(ep):
        ok, payload = await _raw_probe(
            ep, "/models", get_config().api_keys.get(ep), timeout=timeout,
        )
        if not ok:
            return {"status": "error", "detail": str(payload)}
        return {"status": "ok", "version": "latest"}

    version_result, ps_result = await asyncio.gather(
        _raw_probe(ep, "/api/version", timeout=timeout),
        _raw_probe(ep, "/api/ps", timeout=timeout),
    )
    version_ok, version_payload = version_result
    ps_ok, ps_payload = ps_result

    version_value = None
    if version_ok and isinstance(version_payload, dict):
        version_value = version_payload.get("version")

    if not version_ok:
        if ps_ok:
            # Partial failure: only the version probe is broken.
            return {
                "status": "error",
                "detail": f"/api/version: {version_payload}",
            }
        # Nothing answered: the endpoint is down.
        return {"status": "error", "detail": str(version_payload)}

    if not ps_ok:
        # Partial failure: daemon alive, model introspection broken.
        return {
            "status": "error",
            "version": version_value,
            "detail": f"/api/ps: {ps_payload}",
        }

    return {"status": "ok", "version": version_value}
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
"""aiohttp / OpenAI client factories aware of Unix-socket endpoints.
|
||||
|
||||
Unix socket endpoints follow the ``.sock`` hostname convention (e.g.
|
||||
``http://192.168.0.52.sock/v1``) and resolve to ``/run/user/<uid>/<host>``.
|
||||
Their sessions/clients live in ``state.app_state`` so that startup can
|
||||
populate them once and routes can reuse them.
|
||||
"""
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
import ollama
|
||||
import openai
|
||||
|
||||
from state import app_state
|
||||
from backends.normalize import ep2base
|
||||
|
||||
|
||||
def _is_unix_socket_endpoint(endpoint: str) -> bool:
|
||||
"""Return True if endpoint uses Unix socket (.sock hostname convention).
|
||||
|
||||
Detects URLs like http://192.168.0.52.sock/v1 where the host ends with
|
||||
.sock, indicating the connection should use a Unix domain socket at
|
||||
/tmp/<host> instead of TCP.
|
||||
"""
|
||||
try:
|
||||
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
|
||||
return host.endswith(".sock")
|
||||
except IndexError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_socket_path(endpoint: str) -> str:
|
||||
"""Derive Unix socket file path from a .sock endpoint URL.
|
||||
|
||||
http://192.168.0.52.sock/v1 -> /run/user/<uid>/192.168.0.52.sock
|
||||
"""
|
||||
host = endpoint.split("//", 1)[1].split("/")[0].split(":")[0]
|
||||
return f"/run/user/{os.getuid()}/{host}"
|
||||
|
||||
|
||||
def get_session(endpoint: str) -> aiohttp.ClientSession:
    """Pick the aiohttp session that should carry requests to ``endpoint``.

    Unix socket endpoints (``.sock`` hosts) use the dedicated session stored
    under ``app_state["socket_sessions"]`` when one is registered for them.
    Every other endpoint, and a socket endpoint without a registered session,
    uses the shared main session ``app_state["session"]``.
    """
    if not _is_unix_socket_endpoint(endpoint):
        return app_state["session"]

    dedicated = app_state["socket_sessions"].get(endpoint)
    return dedicated if dedicated is not None else app_state["session"]
|
||||
|
||||
|
||||
def get_probe_session(endpoint: str) -> aiohttp.ClientSession:
    """Pick the session used for lightweight health and introspection probes.

    Probes (available/loaded models, endpoint health) run on a connection
    pool kept apart from the proxy/streaming session. Otherwise a burst of
    long-lived completion requests could starve them: a probe would queue
    for a connection, hit its deadline, and mark a perfectly healthy endpoint
    as unavailable under load.

    Resolution order:

    * Unix socket endpoints keep their dedicated per-endpoint session when
      one is registered.
    * Everything else uses ``app_state["probe_session"]``.
    * If that is missing or falsy (e.g. in tests where the probe pool was
      never initialised), the main ``app_state["session"]`` is used.
    """
    if _is_unix_socket_endpoint(endpoint):
        dedicated = app_state["socket_sessions"].get(endpoint)
        if dedicated is not None:
            return dedicated

    probe = app_state.get("probe_session")
    return probe if probe else app_state["session"]
|
||||
|
||||
|
||||
def get_ollama_client(endpoint: str) -> ollama.AsyncClient:
    """Return the shared ``ollama.AsyncClient`` for ``endpoint``, building it on first use.

    ``ollama.AsyncClient`` wraps an ``httpx.AsyncClient`` whose construction
    builds an SSL context and reloads the OS trust store (~40 ms). It is safe
    to reuse concurrently, so one instance per endpoint is kept in
    ``app_state["ollama_clients"]``. Building a fresh one per request would
    spend that 40 ms of CPU on the event loop every time and cap
    single-worker throughput at ~25 req/s.
    """
    clients = app_state["ollama_clients"]
    existing = clients.get(endpoint)
    if existing is not None:
        return existing

    created = ollama.AsyncClient(host=endpoint)
    clients[endpoint] = created
    return created
|
||||
|
||||
|
||||
def _make_openai_client(
    endpoint: str,
    default_headers: dict | None = None,
    api_key: str = "no-key",
) -> openai.AsyncOpenAI:
    """Return the shared ``AsyncOpenAI`` client for ``(endpoint, api_key)``.

    Clients are memoised in ``app_state["openai_clients"]`` because building
    one constructs an SSL context and reloads the OS trust store (~40 ms),
    which would serialise the event loop if done per request.

    For Unix socket endpoints that have a pre-created httpx client in
    ``app_state["httpx_clients"]``, that client is injected so the SDK talks
    over the socket, and the base URL is replaced by the placeholder
    ``http://localhost/v1``.

    NOTE(review): ``default_headers`` is not part of the cache key, so it only
    takes effect for the call that first creates the client for a given
    ``(endpoint, api_key)`` pair.
    """
    clients = app_state["openai_clients"]
    cache_key = (endpoint, api_key)
    cached = clients.get(cache_key)
    if cached is not None:
        return cached

    base_url = ep2base(endpoint)
    client_kwargs: dict = {"api_key": api_key}
    if default_headers is not None:
        client_kwargs["default_headers"] = default_headers

    if _is_unix_socket_endpoint(endpoint):
        uds_http_client = app_state["httpx_clients"].get(endpoint)
        if uds_http_client is not None:
            # Route through the UDS transport; the URL host is then irrelevant.
            client_kwargs["http_client"] = uds_http_client
            base_url = "http://localhost/v1"

    created = openai.AsyncOpenAI(base_url=base_url, **client_kwargs)
    clients[cache_key] = created
    return created
|
||||
139
config.py
139
config.py
|
|
@ -1,139 +0,0 @@
|
|||
"""Router configuration loader.
|
||||
|
||||
Pydantic ``BaseSettings`` model populated from YAML (path resolved via
|
||||
``_config_path_from_env``) with ``${VAR}`` expansion, plus env-var overrides
|
||||
under the ``NOMYO_ROUTER_`` prefix.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Config(BaseSettings):
    """Router settings, loaded from YAML via ``Config.from_yaml``.

    Field defaults apply when the YAML file is missing or omits a key.
    Environment variables use the ``NOMYO_ROUTER_`` prefix (see the inner
    ``Config`` class).
    """

    # List of Ollama endpoints
    endpoints: list[str] = Field(
        default_factory=lambda: [
            "http://localhost:11434",
        ]
    )
    # List of llama-server endpoints (OpenAI-compatible with /v1/models status info)
    llama_server_endpoints: List[str] = Field(default_factory=list)
    # Max concurrent connections per endpoint-model pair, see OLLAMA_NUM_PARALLEL
    max_concurrent_connections: int = 1
    # Per-endpoint overrides: {endpoint_url: {max_concurrent_connections: N}}
    endpoint_config: Dict[str, Dict] = Field(default_factory=dict)
    # When True, config order = priority; routes by utilization ratio + config index (WRR)
    priority_routing: bool = Field(default=False)

    # Conversation affinity: route the same conversation back to the endpoint that
    # previously served it, to keep the llama.cpp / Ollama prompt cache (KV cache) warm.
    # Soft preference: falls back to the standard algorithm when the affine endpoint
    # is saturated or no longer has the model loaded.
    conversation_affinity: bool = Field(default=False)
    # TTL (seconds) for affinity entries. Defaults to Ollama's default keep_alive (5 min):
    # if the backend has already evicted the model, the KV cache is cold anyway.
    conversation_affinity_ttl: int = Field(default=300)

    # Per-endpoint upstream API keys: {endpoint_url: key}
    api_keys: Dict[str, str] = Field(default_factory=dict)
    # Optional router-level API key used to gate access to this service and dashboard.
    # NOTE(review): from_yaml() reads NOMYO_ROUTER_API_KEY explicitly as a fallback;
    # whether the `env=` keyword is honoured depends on the installed pydantic version.
    router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY")

    # Database configuration (the default is read from the environment at import time)
    db_path: str = Field(default=os.getenv("NOMYO_ROUTER_DB_PATH", "token_counts.db"))

    # Semantic LLM Cache configuration
    cache_enabled: bool = Field(default=False)
    # Backend: "memory" (default, in-process), "sqlite" (persistent), "redis" (distributed)
    cache_backend: str = Field(default="memory")
    # Cosine similarity threshold: 1.0 = exact match only, <1.0 = semantic (requires :semantic image)
    cache_similarity: float = Field(default=1.0)
    # TTL in seconds; None = cache forever
    cache_ttl: Optional[int] = Field(default=3600)
    # SQLite backend: path to cache database file
    cache_db_path: str = Field(default="llm_cache.db")
    # Redis backend: connection URL
    cache_redis_url: str = Field(default="redis://localhost:6379/0")
    # Weight of BM25-weighted chat-history embedding vs last-user-message embedding
    # 0.3 = 30% history context signal, 70% question signal
    cache_history_weight: float = Field(default=0.3)

    class Config:
        # YAML loading is handled manually via Config.from_yaml(); env vars use this prefix.
        env_prefix = "NOMYO_ROUTER_"

    @classmethod
    def _expand_env_refs(cls, obj):
        """Recursively replace strings of the exact form ``${VAR}`` with ``os.getenv('VAR')``.

        Dicts and lists are walked; any other value is returned unchanged.
        Only strings that consist entirely of one ``${VAR}`` reference are
        expanded (no partial interpolation). An unset variable expands to the
        empty string.
        """
        if isinstance(obj, dict):
            return {k: cls._expand_env_refs(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [cls._expand_env_refs(v) for v in obj]
        if isinstance(obj, str):
            # Only expand if it is exactly ${VAR}
            m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
            if m:
                return os.getenv(m.group(1), "")
        return obj

    @classmethod
    def from_yaml(cls, path: Path) -> "Config":
        """Build a ``Config`` from the YAML file at ``path``.

        When the file does not exist, a default-constructed ``Config`` is
        returned. Otherwise the document is parsed with ``yaml.safe_load``,
        ``${VAR}`` references are expanded, and the router API key is
        normalised:

        * every accepted spelling of the key is removed from the mapping and
          the highest-priority one present (the canonical ``router_api_key``
          first) is stored under ``router_api_key``;
        * if the result is missing or empty, ``NOMYO_ROUTER_API_KEY`` from the
          environment is used instead.
        """
        if path.exists():
            with path.open("r", encoding="utf-8") as fp:
                data = yaml.safe_load(fp) or {}
            cleaned = cls._expand_env_refs(data)
            if isinstance(cleaned, dict):
                # Accepted spellings, in priority order.
                key_aliases = (
                    # canonical field name
                    "router_api_key",
                    # lowercase, hyphen/underscore variants
                    "nomyo-router-api-key",
                    "nomyo_router_api_key",
                    "nomyo-router_api_key",
                    "nomyo_router-api_key",
                    # uppercase env-style variants
                    "NOMYO-ROUTER_API_KEY",
                    "NOMYO_ROUTER_API_KEY",
                )
                # Pop *every* spelling that is present. Stopping at the first
                # hit would leave the other spellings in the mapping, and they
                # would then be handed to the constructor as unknown fields.
                found = [cleaned.pop(alias) for alias in key_aliases if alias in cleaned]
                if found:
                    cleaned["router_api_key"] = found[0]
                # If not present in YAML (or empty), fall back to env var explicitly
                if not cleaned.get("router_api_key"):
                    env_key = os.getenv("NOMYO_ROUTER_API_KEY")
                    if env_key:
                        cleaned["router_api_key"] = env_key
            return cls(**cleaned)
        return cls()
|
||||
|
||||
|
||||
def _config_path_from_env() -> Path:
|
||||
"""
|
||||
Resolve the configuration file path. Defaults to `config.yaml`
|
||||
in the current working directory unless NOMYO_ROUTER_CONFIG_PATH
|
||||
is set.
|
||||
"""
|
||||
candidate = os.getenv("NOMYO_ROUTER_CONFIG_PATH")
|
||||
if candidate:
|
||||
return Path(candidate).expanduser()
|
||||
return Path("config.yaml")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared config accessor
|
||||
# ------------------------------------------------------------------
|
||||
# Submodules read config at call time via get_config() instead of importing
|
||||
# a bound name. The single source of truth is ``router.config`` — the lazy
|
||||
# import below resolves it after router.py has finished loading, and lets
|
||||
# tests that ``patch.object(router, "config", cfg)`` flow through.
|
||||
def get_config() -> "Config":
    """Return the ``config`` object currently bound on the ``router`` module.

    The lookup happens on every call, so whatever ``router.config`` points to
    at that moment is what the caller receives.
    """
    # Imported here rather than at module top to avoid a circular import
    # while the modules are loading.
    import router

    active = router.config
    return active
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
"""Sliding-window context-trim helpers.
|
||||
|
||||
Mirrors what llama.cpp's context-shift used to do: count tokens with tiktoken
|
||||
(cl100k_base) when available, drop oldest non-system messages until the prompt
|
||||
fits inside (n_ctx - safety_margin).
|
||||
|
||||
Also owns the per-(endpoint, model) n_ctx cache that the routes populate from
|
||||
exceed_context_size_error bodies and from finish_reason=="length" signals.
|
||||
"""
|
||||
import os
|
||||
|
||||
# Point tiktoken at the vendored cl100k_base vocab so the encoding loads offline,
|
||||
# without a network download. The download would otherwise fail anyway: this repo
|
||||
# has a top-level `requests` package that shadows the pip `requests` tiktoken's
|
||||
# downloader imports, so get_encoding() would silently fall back to char/4. See
|
||||
# vendor/tiktoken/. setdefault lets an explicit env override win.
|
||||
os.environ.setdefault(
|
||||
"TIKTOKEN_CACHE_DIR",
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "vendor", "tiktoken"),
|
||||
)
|
||||
|
||||
try:
|
||||
import tiktoken as _tiktoken
|
||||
_tiktoken_enc = _tiktoken.get_encoding("cl100k_base")
|
||||
except Exception:
|
||||
_tiktoken_enc = None
|
||||
|
||||
|
||||
def _count_message_tokens(messages: list) -> int:
    """Approximate the prompt token count of a chat message list.

    With tiktoken available (``_tiktoken_enc`` set) the count follows
    OpenAI's per-message accounting: 2 priming tokens, plus 4 tokens of
    role/separator overhead per message, plus the cl100k_base length of the
    content. String content is encoded directly; for list content
    (multimodal parts) only ``{"type": "text"}`` parts with a string ``text``
    are counted. Any other content contributes only the per-message overhead.

    Without tiktoken the estimate is the total length of
    ``str(content)`` across messages divided by 4.

    Special-token literals such as ``<|endoftext|>`` appearing in user text
    are encoded as ordinary text. By default tiktoken raises ``ValueError``
    on them, which would make counting fail on perfectly legal chat input.
    """
    if _tiktoken_enc is None:
        return sum(len(str(m.get("content", ""))) for m in messages) // 4

    def _encoded_len(text: str) -> int:
        # disallowed_special=() turns off the special-token check so the
        # literal is tokenised like any other text instead of raising.
        return len(_tiktoken_enc.encode(text, disallowed_special=()))

    total = 2  # priming tokens
    for msg in messages:
        total += 4  # per-message role/separator overhead
        content = msg.get("content", "")
        if isinstance(content, str):
            total += _encoded_len(content)
        elif isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text = part.get("text")
                    # Skip a missing or non-string "text" instead of crashing.
                    if isinstance(text, str):
                        total += _encoded_len(text)
    return total
|
||||
|
||||
|
||||
def _trim_messages_for_context(
    messages: list,
    n_ctx: int,
    safety_margin: int | None = None,
    target_tokens: int | None = None,
) -> list:
    """Sliding-window trim, mirroring what llama.cpp context-shift used to do.

    All system messages are kept and placed first. Non-system messages are
    dropped oldest-first (FIFO) until the estimated token count fits the
    target or only one non-system message is left. Afterwards any leading
    non-user messages (assistant/tool) are dropped as well, because chat
    templates require the first non-system message to be a user message.

    Note that this second step also applies to the final remaining message:
    if no user message survives, the result contains only the system
    messages.

    Args:
        messages: Chat messages as dicts with at least a ``role`` key. The
            input list is not modified.
        n_ctx: Context size of the model in tokens.
        safety_margin: Tokens reserved for the generated response (including
            RAG tool results and tool-call JSON synthesis). Defaults to a
            quarter of ``n_ctx``. Ignored when ``target_tokens`` is given.
        target_tokens: Explicit token budget overriding
            ``n_ctx - safety_margin``. Pass a calibrated value when the real
            prompt token count is known from the backend's error body, so
            that tiktoken's underestimation is corrected.

    Returns:
        A new list: system messages followed by the retained non-system ones.
    """
    if target_tokens is not None:
        target = target_tokens
    else:
        if safety_margin is None:
            safety_margin = n_ctx // 4
        target = n_ctx - safety_margin

    system_msgs = [m for m in messages if m.get("role") == "system"]
    non_system = [m for m in messages if m.get("role") != "system"]

    # Shed the oldest turns until the prompt fits (never below one message).
    while len(non_system) > 1 and _count_message_tokens(system_msgs + non_system) > target:
        non_system.pop(0)

    # The surviving window must open with a user turn.
    while non_system and non_system[0].get("role") != "user":
        non_system.pop(0)

    return system_msgs + non_system
|
||||
|
||||
|
||||
def _calibrated_trim_target(msgs: list, n_ctx: int, actual_tokens: int) -> int:
    """Compute a tiktoken-scale trim target from the backend's real token count.

    ``actual_tokens`` is the backend's count and covers messages, tool
    schemas and overhead, whereas ``_count_message_tokens`` only sees message
    text. A per-token scale derived from their ratio would therefore be
    wrong. Instead the *excess* is computed in backend tokens, inflated by
    20% (tiktoken underestimates llama tokenizers by ~15-20%), and subtracted
    from the current tiktoken count.

    A quarter of ``n_ctx`` is reserved for the generated output.

    Example: actual=17993, n_ctx=16384, headroom=4096 -> 5705 backend tokens
    over budget -> shed 6846 tiktoken tokens from the messages.

    The result is never below 1.
    """
    current_estimate = _count_message_tokens(msgs)
    prompt_budget = n_ctx - n_ctx // 4  # context minus output headroom
    backend_excess = max(0, actual_tokens - prompt_budget)
    # Backend tokens -> tiktoken tokens, with a 20% buffer.
    tiktoken_excess = int(backend_excess * 1.2)
    return max(1, current_estimate - tiktoken_excess)
|
||||
|
||||
|
||||
# Per-(endpoint, model) n_ctx cache.
|
||||
# Populated from two sources:
|
||||
# 1. 400 exceed_context_size_error body → n_ctx field
|
||||
# 2. finish_reason/done_reason == "length" in streaming → prompt_tokens + completion_tokens
|
||||
# Only used for proactive pre-trimming when n_ctx <= _CTX_TRIM_SMALL_LIMIT,
|
||||
# so large-context models (200k+ for coding) are never touched.
|
||||
_endpoint_nctx: dict[tuple[str, str], int] = {}
|
||||
_CTX_TRIM_SMALL_LIMIT = 32768 # only proactively trim models with n_ctx at or below this
|
||||
11
db.py
11
db.py
|
|
@ -4,17 +4,6 @@ from pathlib import Path
|
|||
from datetime import datetime, timezone
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def get_db() -> "TokenDatabase":
    """Return the ``db`` object currently bound on the ``router`` module.

    The lookup happens on every call, so whatever ``router.db`` points to at
    that moment is what the caller receives.
    """
    # Imported here rather than at module top to avoid a circular import
    # while the modules are loading.
    import router

    live_db = router.db
    return live_db
|
||||
|
||||
|
||||
class TokenDatabase:
|
||||
def __init__(self, db_path: str = "token_counts.db"):
|
||||
self.db_path = db_path
|
||||
|
|
|
|||
|
|
@ -1,35 +0,0 @@
|
|||
"""Conversation fingerprinting for prompt-cache-aware routing."""
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _conversation_fingerprint(model: str, messages: Optional[list],
|
||||
prompt: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
Stable hash over (model, first system + first user turn). That prefix
|
||||
determines whether the backend's prompt cache is reusable; later turns
|
||||
don't influence the routing decision because they extend the same prefix.
|
||||
Returns None when there is no usable prefix.
|
||||
"""
|
||||
parts: list[str] = [model or "_"]
|
||||
if messages:
|
||||
for m in messages:
|
||||
role = m.get("role") if isinstance(m, dict) else None
|
||||
if role not in ("system", "user"):
|
||||
continue
|
||||
content = m.get("content")
|
||||
if isinstance(content, list): # OpenAI multimodal parts
|
||||
content = "".join(
|
||||
p.get("text", "") for p in content
|
||||
if isinstance(p, dict) and p.get("type") == "text"
|
||||
)
|
||||
if not isinstance(content, str):
|
||||
continue
|
||||
parts.append(f"{role}:{content}")
|
||||
if role == "user":
|
||||
break
|
||||
elif prompt:
|
||||
parts.append(f"user:{prompt}")
|
||||
else:
|
||||
return None
|
||||
return hashlib.sha1("\x1f".join(parts).encode("utf-8", "replace")).hexdigest()
|
||||
66
images.py
66
images.py
|
|
@ -1,66 +0,0 @@
|
|||
"""Image and timestamp helpers used by the Ollama/OpenAI request pipeline."""
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def iso8601_ns():
    """Return the current UTC time as an ISO-8601 string with nanosecond precision.

    The value has the form ``YYYY-MM-DDTHH:MM:SS.nnnnnnnnnZ``. The nanosecond
    remainder comes straight from ``time.time_ns()`` so no precision is lost
    to float conversion.
    """
    whole_seconds, nanos = divmod(time.time_ns(), 1_000_000_000)
    utc_moment = datetime.fromtimestamp(whole_seconds, tz=timezone.utc)
    # Dropping tzinfo makes isoformat() omit the "+00:00" offset; the literal
    # "Z" suffix is appended after the fractional part instead.
    seconds_part = utc_moment.replace(tzinfo=None).isoformat(timespec="seconds")
    return f"{seconds_part}.{nanos:09d}Z"
|
||||
|
||||
|
||||
def is_base64(image_string):
    """Return True when *image_string* is a canonically base64-encoded ``str``.

    "Canonical" means that decoding and re-encoding reproduces the input
    exactly, so strings with embedded whitespace, stray characters or wrong
    padding are rejected.

    Args:
        image_string: Candidate value; anything that is not a ``str`` is
            rejected.

    Returns:
        ``True`` or ``False`` — always a real bool, never ``None``.
    """
    if not isinstance(image_string, str):
        return False
    try:
        decoded = base64.b64decode(image_string)
    except ValueError:
        # Covers binascii.Error (bad padding) and non-ASCII input; both are
        # ValueError subclasses/instances raised by b64decode.
        return False
    return base64.b64encode(decoded) == image_string.encode()
|
||||
|
||||
|
||||
def resize_image_if_needed(image_data, max_side: int = 512):
    """Downscale a base64 image so neither side exceeds *max_side* and re-encode it as PNG.

    Args:
        image_data: Base64-encoded image, optionally wrapped in a
            ``data:...,<payload>`` URL (only the part after the first comma
            is decoded).
        max_side: Largest allowed width/height in pixels. Images already
            within the limit are only re-encoded, not resized.

    Returns:
        The base64-encoded PNG as ``str``, or ``None`` when the input cannot
        be processed (the error is printed; this function is best-effort and
        does not raise).
    """
    try:
        # Strip a data-URL header if present; a "data:" string without a comma
        # is passed to the decoder unchanged.
        if image_data.startswith("data:"):
            _header, comma, payload = image_data.partition(",")
            if comma:
                image_data = payload
        image_bytes = base64.b64decode(image_data)
        with Image.open(io.BytesIO(image_bytes)) as image:
            if image.mode not in ("RGB", "L"):
                image = image.convert("RGB")

            width, height = image.size

            # Scale the longer side down to max_side, keeping the aspect ratio.
            if width > max_side or height > max_side:
                aspect_ratio = width / height
                if aspect_ratio > 1:  # Width is larger
                    new_width = max_side
                    # Clamp to 1 px: very wide images would otherwise compute
                    # a height of 0, which resize() rejects.
                    new_height = max(1, int(max_side / aspect_ratio))
                else:  # Height is larger (or the image is square)
                    new_height = max_side
                    new_width = max(1, int(max_side * aspect_ratio))

                image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Encode the (possibly resized) image back to base64 PNG.
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")

    except Exception as e:
        # Deliberate best-effort: callers treat None as "skip this image".
        print(f"Error processing image: {e}")
        return None
|
||||
218
requests/chat.py
218
requests/chat.py
|
|
@ -1,218 +0,0 @@
|
|||
"""High-level chat request orchestrator.
|
||||
|
||||
``_make_chat_request`` is the shared core that:
|
||||
* picks an endpoint via ``choose_endpoint`` (which atomically reserves a slot),
|
||||
* dispatches to either the native Ollama client or an OpenAI-compatible
|
||||
client based on endpoint type,
|
||||
* applies reactive context trimming when the backend rejects with
|
||||
``exceed_context_size_error``,
|
||||
* counts tokens for billing/SSE,
|
||||
* always releases the reservation via ``decrement_usage`` in ``finally``.
|
||||
|
||||
``_make_moe_requests`` builds on it to implement the
|
||||
"3 responses + 3 critiques + 1 final" mixture-of-experts dance.
|
||||
"""
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
|
||||
import ollama
|
||||
|
||||
import enhance
|
||||
from config import get_config
|
||||
from state import default_headers, token_queue
|
||||
from context_window import _trim_messages_for_context, _calibrated_trim_target
|
||||
from backends.normalize import is_openai_compatible
|
||||
from backends.sessions import _make_openai_client
|
||||
from routing import choose_endpoint, decrement_usage
|
||||
from requests.messages import (
|
||||
get_last_user_content,
|
||||
transform_tool_calls_to_openai,
|
||||
transform_images_to_data_urls,
|
||||
_strip_assistant_prefill,
|
||||
_strip_images_from_messages,
|
||||
_accumulate_openai_tc_delta,
|
||||
_build_ollama_tool_calls,
|
||||
)
|
||||
from requests.rechunk import rechunk
|
||||
|
||||
|
||||
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """
    Send one chat request to a routed endpoint and return an Ollama-style response.

    Steps:
      1. ``choose_endpoint`` picks the endpoint and reserves a slot.
      2. OpenAI-compatible endpoints get the messages converted (images, tool
         calls, trailing assistant prefill) and the Ollama ``options`` mapped
         onto OpenAI parameters; other endpoints use ``ollama.AsyncClient``.
      3. On a context-size rejection the messages are trimmed and the request
         retried (a second rejection retries once more without tools); on an
         "image input is not supported" rejection it is retried text-only.
      4. Token counts are pushed onto ``token_queue``.
      5. The reservation is released in ``finally``.

    Args:
        model: Requested model name (a ``:latest`` suffix is dropped for
            OpenAI-compatible endpoints).
        messages: Chat messages in Ollama format.
        tools, format, options, think, keep_alive: Forwarded to the backend;
            ``think`` and ``keep_alive`` are only used on the native Ollama path.
        stream: When true, the backend stream is fully consumed and only the
            final chunk (converted to Ollama format) is returned.

    Returns:
        The final response/chunk as returned or converted for Ollama clients.

    Raises:
        Whatever the backend client raises when no retry rule applies.
    """
    config = get_config()
    endpoint, tracking_model = await choose_endpoint(model)  # selects and atomically reserves
    # Everything after the reservation runs under try/finally so that a failure
    # while preparing the request (e.g. an invalid image payload) cannot leak
    # the reserved slot.
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            if messages:
                if any("images" in m for m in messages):
                    # Image decoding/resizing is blocking work; keep it off the event loop.
                    messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
                messages = transform_tool_calls_to_openai(messages)
                messages = _strip_assistant_prefill(messages)
            params = {
                "messages": messages,
                "model": model,
            }
            opts = options or {}
            optional_params = {
                "tools": tools,
                "stream": stream,
                "stream_options": {"include_usage": True} if stream else None,
                "max_tokens": opts.get("num_predict"),
                "frequency_penalty": opts.get("frequency_penalty"),
                "presence_penalty": opts.get("presence_penalty"),
                "seed": opts.get("seed"),
                "stop": opts.get("stop"),
                "top_p": opts.get("top_p"),
                "temperature": opts.get("temperature"),
                "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
            }
            # Only send parameters that were actually provided.
            params.update({k: v for k, v in optional_params.items() if v is not None})
            oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = ollama.AsyncClient(host=endpoint)

        if use_openai:
            start_ts = time.perf_counter()
            try:
                response = await oclient.chat.completions.create(**params)
            except Exception as e:
                _e_str = str(e)
                print(f"[_make_chat_request] caught {type(e).__name__}: {_e_str[:200]}")
                if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
                    # Prefer the structured error body; fall back to parsing the message text.
                    err_body = getattr(e, "body", {}) or {}
                    err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
                    if not isinstance(err_detail, dict):
                        err_detail = {}
                    n_ctx_limit = err_detail.get("n_ctx", 0)
                    actual_tokens = err_detail.get("n_prompt_tokens", 0)
                    if not n_ctx_limit:
                        _m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
                        if _m:
                            n_ctx_limit = int(_m.group(1))
                        _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                        if _m:
                            actual_tokens = int(_m.group(1))
                    if not n_ctx_limit:
                        # Without a known context size there is nothing to trim against.
                        raise
                    msgs_to_trim = params.get("messages", [])
                    cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                    trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                    print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
                    try:
                        response = await oclient.chat.completions.create(**{**params, "messages": trimmed})
                    except Exception as e2:
                        if "exceed_context_size_error" in str(e2) or "exceeds the available context size" in str(e2):
                            print("[_make_chat_request] Context still exceeded after trimming, also stripping tools")
                            params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                            response = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
                        else:
                            raise
                elif "image input is not supported" in _e_str:
                    print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
                    params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
                    response = await oclient.chat.completions.create(**params)
                else:
                    raise
            if stream:
                # For streaming, we need to collect all chunks
                chunks = []
                tc_acc = {}  # accumulate tool-call deltas
                async for chunk in response:
                    chunks.append(chunk)
                    _accumulate_openai_tc_delta(chunk, tc_acc)
                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                    else:
                        # Fall back to the llama-server style "timings" field when "usage" is absent.
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                # Convert to Ollama format
                if chunks:
                    response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
                    # Inject fully-accumulated tool calls into the final response
                    if tc_acc and response.message:
                        response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
            else:
                prompt_tok = 0
                comp_tok = 0
                if response.usage is not None:
                    prompt_tok = response.usage.prompt_tokens or 0
                    comp_tok = response.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(response)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
        else:
            response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
            if stream:
                # For streaming, collect all chunks
                chunks = []
                async for chunk in response:
                    chunks.append(chunk)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if chunks:
                    response = chunks[-1]
            else:
                prompt_tok = response.prompt_eval_count or 0
                comp_tok = response.eval_count or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

        return response
    finally:
        await decrement_usage(endpoint, tracking_model)
|
||||
|
||||
|
||||
async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """
    Helper function to make MOE (Multiple Opinions Ensemble) requests.

    Generates 3 candidate responses in parallel, asks for 3 critiques in
    parallel (one per candidate), then issues one final request built from the
    critiques and returns its response. All seven requests go through
    ``_make_chat_request`` with ``stream=False`` and ``temperature`` forced to 1.

    Args:
        model: Model name used for every request.
        messages: Conversation; the last user message is the query.
        tools, think, format, keep_alive: Forwarded unchanged.
        options: Ollama options; never modified — a copy with
            ``temperature=1`` is used.

    Raises:
        ValueError: When ``messages`` contains no user message with content.
    """
    import copy  # local import: only this helper needs it

    query = get_last_user_content(messages)
    if not query:
        raise ValueError("No user query found in messages")

    # Copy instead of assigning into the caller's dict.
    options = {**(options or {}), "temperature": 1}

    async def _ask(msgs: list) -> ollama.ChatResponse:
        return await _make_chat_request(model, msgs, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)

    # Generate 3 responses — choose_endpoint is called inside _make_chat_request and
    # atomically reserves a slot, so all 3 tasks see each other's load immediately.
    # Each request gets its own shallow copies of the message objects because the
    # OpenAI conversion path rewrites messages in place (e.g. pops "images").
    responses = await asyncio.gather(
        *(_ask([copy.copy(m) for m in messages]) for _ in range(3))
    )

    moe_reqs = [enhance.moe(query, n, r.message.content) for n, r in enumerate(responses)]

    # Generate 3 critiques, one per candidate response.
    critiques = await asyncio.gather(
        *(_ask([{"role": "user", "content": req}]) for req in moe_reqs)
    )

    # Select final response
    m = enhance.moe_select_candidate(query, critiques)

    # Generate final response
    return await _ask([{"role": "user", "content": m}])
|
||||
|
|
@ -1,187 +0,0 @@
|
|||
"""Message-shape transforms used across the chat/completions paths.
|
||||
|
||||
Covers the directions between Ollama's native message format and the
|
||||
OpenAI Chat Completions format:
|
||||
* tool-call normalization (Ollama → OpenAI),
|
||||
* images encoded as base64 lists → OpenAI multimodal ``image_url`` parts,
|
||||
* trailing-assistant prefill strip (rejected by Claude/OpenAI),
|
||||
* streaming tool_calls accumulation across deltas,
|
||||
* logprobs translation (OpenAI choice → Ollama ``Logprob``).
|
||||
"""
|
||||
import secrets
|
||||
|
||||
import ollama
|
||||
import orjson
|
||||
from ollama._types import TokenLogprob, Logprob
|
||||
|
||||
from images import is_base64, resize_image_if_needed
|
||||
|
||||
|
||||
def get_last_user_content(messages):
    """Return the ``content`` of the most recent message whose role is ``user``.

    ``messages`` is a list of dict-like messages. The list is scanned from the
    end, so the first hit is the latest user turn. ``None`` is returned when
    there is no user message (or when that message has no ``content`` key).
    """
    latest_first = reversed(messages)
    return next(
        (entry.get("content") for entry in latest_first if entry.get("role") == "user"),
        None,
    )
|
||||
|
||||
|
||||
def _strip_assistant_prefill(messages: list) -> list:
|
||||
"""Remove a trailing assistant message used as prefill.
|
||||
OpenAI-compatible endpoints (including Claude) do not support prefill and
|
||||
will reject requests where the last message has role 'assistant'."""
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
return messages[:-1]
|
||||
return messages
|
||||
|
||||
|
||||
def transform_tool_calls_to_openai(message_list):
    """
    Ensure tool_calls in assistant messages conform to the OpenAI format:
    - Each tool call must have "type": "function"
    - Each tool call must have an "id"
    - arguments must be a JSON string, not a dict
    Also ensure tool-role messages have a tool_call_id.

    Messages are modified in place and the same list is returned.

    Tool-role messages without a ``tool_call_id`` are matched by function
    name (``name`` or ``tool_name``) against the preceding assistant tool
    calls. IDs are handed out in call order per name, so several calls to the
    same function in one assistant message each get their own result message.
    """
    # name -> list of not-yet-consumed call ids, oldest first
    pending_ids = {}
    for msg in message_list:
        role = msg.get("role")
        if role == "assistant" and "tool_calls" in msg:
            ids_this_message = {}
            for tc in msg["tool_calls"] or ():
                if "type" not in tc:
                    tc["type"] = "function"
                if "id" not in tc:
                    tc["id"] = f"call_{secrets.token_hex(16)}"
                func = tc.get("function", {})
                if isinstance(func.get("arguments"), dict):
                    func["arguments"] = orjson.dumps(func["arguments"]).decode("utf-8")
                # Remember the id for the following tool-role message(s)
                name = func.get("name")
                if name:
                    ids_this_message.setdefault(name, []).append(tc["id"])
            # A newer assistant message replaces stale, unanswered ids for the
            # same function name; ids for other names are kept.
            pending_ids.update(ids_this_message)
        elif role == "tool" and "tool_call_id" not in msg:
            # Try to match by name from a preceding assistant tool_call
            name = msg.get("name") or msg.get("tool_name")
            waiting = pending_ids.get(name) if name else None
            if waiting:
                msg["tool_call_id"] = waiting.pop(0)
    return message_list
|
||||
|
||||
|
||||
def transform_images_to_data_urls(message_list):
    """Convert Ollama-style ``images`` lists into OpenAI multimodal content parts.

    For every message carrying an ``images`` key, the key is removed and each
    base64 image is resized via ``resize_image_if_needed`` and appended to the
    message content as an ``image_url`` part with a PNG data URL. Existing
    content is preserved: a non-empty string becomes a leading ``text`` part
    and an existing list of parts is kept in front of the images.

    Messages are modified in place and the same list is returned. If
    ``images`` is not a list, or no image could be processed, the message
    content is left untouched.

    Raises:
        ValueError: When an entry of ``images`` is not valid base64.
    """
    for message in message_list:
        if "images" not in message:
            continue
        images = message.pop("images")
        if not isinstance(images, list):
            continue
        image_parts = []
        for image in images:  # TODO: quality downsize if images are too big to fit into model context window size
            if not is_base64(image):
                raise ValueError("Image string is not a valid base64 encoded string.")
            resized_image = resize_image_if_needed(image)
            if resized_image:
                image_parts.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{resized_image}"
                    }
                })
        if not image_parts:
            # Nothing usable to attach: keep the original content as it was.
            continue
        existing = message.get("content")
        if isinstance(existing, str) and existing:
            # Keep the user's text; replacing it with images only would drop the prompt.
            leading_parts = [{"type": "text", "text": existing}]
        elif isinstance(existing, list):
            leading_parts = list(existing)
        else:
            leading_parts = []
        message["content"] = leading_parts + image_parts

    return message_list
|
||||
|
||||
|
||||
def _strip_images_from_messages(messages: list) -> list:
|
||||
"""Remove image_url parts from message content, keeping only text."""
|
||||
result = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
text_only = [p for p in content if p.get("type") != "image_url"]
|
||||
if len(text_only) == 1 and text_only[0].get("type") == "text":
|
||||
content = text_only[0]["text"]
|
||||
else:
|
||||
content = text_only
|
||||
result.append({**msg, "content": content})
|
||||
else:
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
|
||||
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
|
||||
|
||||
``accumulator`` is a dict mapping tool-call *index* to
|
||||
``{"id": str, "name": str, "arguments": str}`` where ``arguments``
|
||||
is the concatenation of all JSON fragments seen so far.
|
||||
"""
|
||||
if not chunk.choices:
|
||||
return
|
||||
delta = chunk.choices[0].delta
|
||||
tc_deltas = getattr(delta, "tool_calls", None)
|
||||
if not tc_deltas:
|
||||
return
|
||||
for tc in tc_deltas:
|
||||
idx = tc.index
|
||||
if idx not in accumulator:
|
||||
accumulator[idx] = {
|
||||
"id": getattr(tc, "id", None) or f"call_{secrets.token_hex(16)}",
|
||||
"name": tc.function.name if tc.function else None,
|
||||
"arguments": "",
|
||||
}
|
||||
else:
|
||||
if getattr(tc, "id", None):
|
||||
accumulator[idx]["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
accumulator[idx]["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
accumulator[idx]["arguments"] += tc.function.arguments
|
||||
|
||||
|
||||
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
|
||||
"""Convert accumulated tool-call data into Ollama-format tool_calls list."""
|
||||
if not accumulator:
|
||||
return None
|
||||
result = []
|
||||
for idx in sorted(accumulator.keys()):
|
||||
tc = accumulator[idx]
|
||||
try:
|
||||
args = orjson.loads(tc["arguments"]) if tc["arguments"] else {}
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
result.append(ollama.Message.ToolCall(
|
||||
function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=args)
|
||||
))
|
||||
return result
|
||||
|
||||
|
||||
def _convert_openai_logprobs(choice) -> list | None:
|
||||
"""Convert OpenAI logprobs from a choice into Ollama Logprob objects."""
|
||||
lp = getattr(choice, "logprobs", None)
|
||||
if lp is None:
|
||||
return None
|
||||
content = getattr(lp, "content", None)
|
||||
if not content:
|
||||
return None
|
||||
result = []
|
||||
for entry in content:
|
||||
top = [
|
||||
TokenLogprob(token=alt.token, logprob=alt.logprob)
|
||||
for alt in (entry.top_logprobs or [])
|
||||
]
|
||||
result.append(Logprob(
|
||||
token=entry.token,
|
||||
logprob=entry.logprob,
|
||||
top_logprobs=top or None,
|
||||
))
|
||||
return result
|
||||
|
|
@ -1,151 +0,0 @@
|
|||
"""OpenAI → Ollama response shape converters.
|
||||
|
||||
Methods on the ``rechunk`` class are called as bare functions
|
||||
(``rechunk.openai_chat_completion2ollama(...)``) — there is no instance
|
||||
state. The class is just a namespace.
|
||||
|
||||
``extract_usage_from_llama_timings`` reads the ``timings`` field that
|
||||
llama-server returns in place of OpenAI's ``usage`` so the router can still
|
||||
count tokens for those backends.
|
||||
"""
|
||||
import time
|
||||
|
||||
import ollama
|
||||
import orjson
|
||||
|
||||
from images import iso8601_ns
|
||||
from requests.messages import _convert_openai_logprobs
|
||||
|
||||
|
||||
class rechunk:
|
||||
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
|
||||
now = time.perf_counter()
|
||||
if chunk.choices == [] and chunk.usage is not None:
|
||||
return ollama.ChatResponse(
|
||||
model=chunk.model,
|
||||
created_at=iso8601_ns(),
|
||||
done=True,
|
||||
done_reason='stop',
|
||||
total_duration=int((now - start_ts) * 1_000_000_000),
|
||||
load_duration=100000,
|
||||
prompt_eval_count=int(chunk.usage.prompt_tokens),
|
||||
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
|
||||
eval_count=int(chunk.usage.completion_tokens),
|
||||
eval_duration=int((now - start_ts) * 1_000_000_000),
|
||||
message=ollama.Message(role="assistant", content=""),
|
||||
)
|
||||
with_thinking = chunk.choices[0] if chunk.choices[0] else None
|
||||
if stream == True:
|
||||
thinking = (getattr(with_thinking.delta, "reasoning_content", None) or getattr(with_thinking.delta, "reasoning", None)) if with_thinking else None
|
||||
role = chunk.choices[0].delta.role or "assistant"
|
||||
content = chunk.choices[0].delta.content or ''
|
||||
else:
|
||||
thinking = (getattr(with_thinking.message, "reasoning_content", None) or getattr(with_thinking.message, "reasoning", None)) if with_thinking else None
|
||||
role = chunk.choices[0].message.role or "assistant"
|
||||
content = chunk.choices[0].message.content or ''
|
||||
# Convert OpenAI tool_calls to Ollama format
|
||||
# In streaming mode, tool_calls arrive as partial deltas across multiple chunks
|
||||
# (name only in first delta, arguments as incremental JSON fragments).
|
||||
# Callers must accumulate deltas and inject the final result; skip here.
|
||||
ollama_tool_calls = None
|
||||
if not stream:
|
||||
raw_tool_calls = getattr(with_thinking.message, "tool_calls", None) if with_thinking else None
|
||||
if raw_tool_calls:
|
||||
ollama_tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
try:
|
||||
args = orjson.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else (tc.function.arguments or {})
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
ollama_tool_calls.append(ollama.Message.ToolCall(
|
||||
function=ollama.Message.ToolCall.Function(name=tc.function.name, arguments=args)
|
||||
))
|
||||
# Convert OpenAI logprobs to Ollama format
|
||||
ollama_logprobs = _convert_openai_logprobs(with_thinking) if with_thinking else None
|
||||
assistant_msg = ollama.Message(
|
||||
role=role,
|
||||
content=content,
|
||||
thinking=thinking,
|
||||
images=None,
|
||||
tool_name=None,
|
||||
tool_calls=ollama_tool_calls)
|
||||
rechunk = ollama.ChatResponse(
|
||||
model=chunk.model,
|
||||
created_at=iso8601_ns(),
|
||||
done=True if chunk.usage is not None else False,
|
||||
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
|
||||
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||
load_duration=100000,
|
||||
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
|
||||
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
|
||||
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
|
||||
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||
message=assistant_msg,
|
||||
logprobs=ollama_logprobs)
|
||||
return rechunk
|
||||
|
||||
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
|
||||
now = time.perf_counter()
|
||||
with_thinking = chunk.choices[0] if chunk.choices[0] else None
|
||||
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
|
||||
rechunk = ollama.GenerateResponse(
|
||||
model=chunk.model,
|
||||
created_at=iso8601_ns(),
|
||||
done=True if chunk.usage is not None else False,
|
||||
done_reason=chunk.choices[0].finish_reason,
|
||||
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||
load_duration=10000,
|
||||
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
|
||||
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
|
||||
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
|
||||
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
|
||||
response=chunk.choices[0].text or '',
|
||||
thinking=thinking)
|
||||
return rechunk
|
||||
|
||||
def openai_embeddings2ollama(chunk: dict) -> ollama.EmbeddingsResponse:
|
||||
rechunk = ollama.EmbeddingsResponse(embedding=chunk.data[0].embedding)
|
||||
return rechunk
|
||||
|
||||
def openai_embed2ollama(chunk: dict, model: str) -> ollama.EmbedResponse:
|
||||
rechunk = ollama.EmbedResponse(
|
||||
model=model,
|
||||
created_at=iso8601_ns(),
|
||||
done=None,
|
||||
done_reason=None,
|
||||
total_duration=None,
|
||||
load_duration=None,
|
||||
prompt_eval_count=None,
|
||||
prompt_eval_duration=None,
|
||||
eval_count=None,
|
||||
eval_duration=None,
|
||||
embeddings=[chunk.data[0].embedding])
|
||||
return rechunk
|
||||
|
||||
def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
|
||||
"""Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
|
||||
|
||||
llama-server returns a ``timings`` dict instead of the standard OpenAI
|
||||
``usage`` field::
|
||||
|
||||
"timings": {
|
||||
"cache_n": 236, // prompt tokens reused from cache
|
||||
"prompt_n": 1, // prompt tokens processed
|
||||
"predicted_n": 35 // predicted (completion) tokens
|
||||
}
|
||||
|
||||
prompt_tokens = prompt_n + cache_n
|
||||
completion_tokens = predicted_n
|
||||
|
||||
Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
|
||||
timings are found.
|
||||
"""
|
||||
timings = getattr(obj, "timings", None)
|
||||
if timings is None:
|
||||
return None
|
||||
if isinstance(timings, dict):
|
||||
prompt_n = timings.get("prompt_n", 0) or 0
|
||||
cache_n = timings.get("cache_n", 0) or 0
|
||||
predicted_n = timings.get("predicted_n", 0) or 0
|
||||
return (prompt_n + cache_n, predicted_n)
|
||||
return None
|
||||
|
|
@ -1,22 +1,22 @@
|
|||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.14.0
|
||||
aiohttp==3.13.5
|
||||
aiosignal==1.4.0
|
||||
annotated-types==0.7.0
|
||||
anyio==4.13.0
|
||||
async-timeout==5.0.1
|
||||
attrs==26.1.0
|
||||
certifi==2026.4.22
|
||||
certifi==2026.5.20
|
||||
click==8.4.0
|
||||
distro==1.9.0
|
||||
exceptiongroup==1.3.1
|
||||
fastapi==0.136.1
|
||||
fastapi==0.136.3
|
||||
fastapi-sse==1.1.1
|
||||
frozenlist==1.8.0
|
||||
h11==0.16.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
idna==3.15
|
||||
jiter==0.14.0
|
||||
jiter==0.15.0
|
||||
multidict==6.7.1
|
||||
ollama==0.6.2
|
||||
openai==2.37.0
|
||||
|
|
@ -30,10 +30,10 @@ pydantic_core==2.46.4
|
|||
python-dotenv==1.2.2
|
||||
PyYAML==6.0.3
|
||||
sniffio==1.3.1
|
||||
starlette==0.52.1
|
||||
starlette>=1.0.1
|
||||
truststore==0.10.4
|
||||
tiktoken==0.13.0
|
||||
tqdm==4.67.3
|
||||
tqdm==4.68.1
|
||||
typing-inspection==0.4.2
|
||||
typing_extensions==4.15.0
|
||||
uvicorn==0.47.0
|
||||
|
|
|
|||
294
routing.py
294
routing.py
|
|
@ -1,294 +0,0 @@
|
|||
"""Endpoint selection (load-balancing + conversation affinity).
|
||||
|
||||
``choose_endpoint`` is the heart of routing — it picks an endpoint that
|
||||
advertises the model, prefers ones with the model already loaded and a free
|
||||
slot, applies the conversation-affinity hint when available, and honors
|
||||
config-order priority routing when ``priority_routing`` is set.
|
||||
|
||||
``increment_usage`` / ``decrement_usage`` keep the per-(endpoint, model)
|
||||
counter that drives utilization-based selection; they fan out an SSE
|
||||
snapshot on every change.
|
||||
"""
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from config import get_config
|
||||
from state import (
|
||||
usage_counts,
|
||||
usage_lock,
|
||||
_loaded_error_cache,
|
||||
_loaded_error_cache_lock,
|
||||
_completion_error_cache,
|
||||
_completion_error_cache_lock,
|
||||
_COMPLETION_ERROR_TTL,
|
||||
_affinity_map,
|
||||
_affinity_lock,
|
||||
_AFFINITY_MAX_ENTRIES,
|
||||
)
|
||||
from sse import _capture_snapshot, _distribute_snapshot
|
||||
from backends.health import _is_fresh
|
||||
from backends.normalize import (
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
get_tracking_model,
|
||||
)
|
||||
from backends.probe import fetch
|
||||
|
||||
|
||||
async def increment_usage(endpoint: str, model: str) -> None:
    """Count one more in-flight request for (endpoint, model) and publish the change.

    The counter is bumped and a snapshot captured while holding
    ``usage_lock``; the snapshot is then handed to ``_distribute_snapshot``.
    """
    # NOTE(review): lock scope reconstructed — the snapshot is captured under
    # the lock and distributed after it is released; confirm against the original.
    async with usage_lock:
        per_endpoint = usage_counts[endpoint]
        per_endpoint[model] += 1
        captured = _capture_snapshot()
    await _distribute_snapshot(captured)
|
||||
|
||||
|
||||
async def decrement_usage(endpoint: str, model: str) -> None:
    """Release one in-flight request for (endpoint, model) and publish the change.

    The counter never drops below zero, and a model entry whose count reaches
    zero is removed. The per-endpoint mapping itself is kept even when it
    becomes empty.
    """
    # NOTE(review): lock scope reconstructed — the snapshot is captured under
    # the lock and distributed after it is released; confirm against the original.
    async with usage_lock:
        per_endpoint = usage_counts[endpoint]
        in_flight = per_endpoint.get(model, 0)
        if in_flight > 0:
            per_endpoint[model] = in_flight - 1
        # Drop the entry once nothing is in flight for this model.
        if per_endpoint.get(model, 0) == 0:
            per_endpoint.pop(model, None)
        captured = _capture_snapshot()
    await _distribute_snapshot(captured)
|
||||
|
||||
|
||||
def get_max_connections(ep: str) -> int:
    """Return the connection limit for ``ep``.

    Uses the endpoint's own ``max_concurrent_connections`` entry from
    ``endpoint_config`` when present, otherwise the global
    ``max_concurrent_connections`` value of the active config.
    """
    settings = get_config()
    global_limit = settings.max_concurrent_connections
    overrides = settings.endpoint_config.get(ep, {})
    return overrides.get("max_concurrent_connections", global_limit)
|
||||
|
||||
|
||||
async def choose_endpoint(model: str, reserve: bool = True,
                          affinity_key: Optional[str] = None) -> tuple[str, str]:
    """
    Determine which endpoint should serve ``model`` while respecting the
    per-endpoint connection limit and ensuring the chosen endpoint actually
    *advertises* the model.

    Selection algorithm:

    1. Query every endpoint (``config.endpoints`` plus any additional
       ``config.llama_server_endpoints``) for its advertised models.
    2. Keep the endpoints that advertise the requested model. If none do,
       retry without a ``:latest`` suffix (external OpenAI / llama-server
       endpoints only) and then with ``:latest`` appended to a tag-less name.
    3. Drop candidates whose loaded-model probe or completion path failed
       recently, unless that would leave no candidate at all.
    4. If conversation affinity is enabled and ``affinity_key`` maps to a
       live entry, prefer that endpoint, but only while it still has the
       model loaded and a free slot.
    5. Otherwise prefer endpoints with the model loaded and a free slot,
       then endpoints with merely a free slot, then the least-busy
       saturated endpoint (the request will queue there).

    Args:
        model: Requested model name.
        reserve: When true, the usage counter of the selected
            ``(endpoint, tracking_model)`` pair is incremented atomically
            with the selection and an SSE snapshot is distributed.
        affinity_key: Optional conversation fingerprint used for the
            affinity lookup and, when ``reserve`` is true, refreshed
            afterwards.

    Returns:
        ``(endpoint, tracking_model)`` where ``tracking_model`` is the
        normalized name returned by ``get_tracking_model``.

    Raises:
        RuntimeError: If no configured endpoint advertises the model.
    """
    config = get_config()

    # Step 1: gather advertised-model sets for all endpoints concurrently.
    # Include both config.endpoints and config.llama_server_endpoints.
    llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    all_endpoints = config.endpoints + llama_eps_extra
    llama_extra_set = set(llama_eps_extra)

    def advertised_models_task(ep: str):
        # OpenAI-compatible endpoints and extra llama-server endpoints are
        # queried with their configured API key; other endpoints without one.
        if ep in llama_extra_set or is_openai_compatible(ep):
            return fetch.available_models(ep, config.api_keys.get(ep))
        return fetch.available_models(ep)

    # Build the tasks in ``all_endpoints`` order. Grouping them by endpoint
    # type instead would misalign the zip() calls below whenever the config
    # lists an OpenAI-compatible endpoint before a native one, pairing
    # endpoints with another endpoint's model set.
    tag_tasks = [advertised_models_task(ep) for ep in all_endpoints]
    advertised_sets = await asyncio.gather(*tag_tasks)

    # Step 2: filter endpoints that advertise the requested model.
    candidate_endpoints = [
        ep for ep, models in zip(all_endpoints, advertised_sets)
        if model in models
    ]

    if not candidate_endpoints:
        if ":latest" in model:  # ollama naming convention not applicable to openai/llama-server
            model_without_latest = model.split(":latest")[0]
            candidate_endpoints = [
                ep for ep, models in zip(all_endpoints, advertised_sets)
                if model_without_latest in models
                and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints)
            ]
        if not candidate_endpoints:
            # Only add :latest suffix if model doesn't already have a version suffix
            if ":" not in model:
                model = model + ":latest"
            candidate_endpoints = [
                ep for ep, models in zip(all_endpoints, advertised_sets)
                if model in models
            ]
        if not candidate_endpoints:
            raise RuntimeError(
                f"None of the configured endpoints ({', '.join(all_endpoints)}) "
                f"advertise the model '{model}'."
            )

    # Step 3: among the candidates, find those that have the model *loaded*
    # (concurrently, but only for the filtered list).
    load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
    loaded_sets = await asyncio.gather(*load_tasks)

    # Step 3.5: exclude endpoints whose loaded-model probe has been failing
    # recently. Without this filter, an endpoint whose probe errors would
    # appear with an empty loaded set but still pass through to the
    # free-slot fallback, sending completion calls to an unhealthy backend.
    # See issue #83.
    async with _loaded_error_cache_lock:
        unhealthy = {
            ep for ep, ts in _loaded_error_cache.items()
            if _is_fresh(ts, 300)
        }
    if unhealthy:
        filtered = [
            (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
            if ep not in unhealthy
        ]
        if filtered:
            candidate_endpoints = [ep for ep, _ in filtered]
            loaded_sets = [models for _, models in filtered]
        # If *every* candidate is unhealthy we still fall through with the
        # original list: refusing to route is worse than retrying a
        # possibly-recovered backend.

    # Step 3.6: exclude (endpoint, model) pairs whose completion path has
    # recently failed with a backend connection error (e.g. llama-server in
    # router mode whose delegated worker for *this* model died). The model
    # listing keeps reporting OK in that case, so the probe-level filter
    # above cannot catch it.
    async with _completion_error_cache_lock:
        completion_broken = {
            ep for (ep, m), ts in _completion_error_cache.items()
            if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL)
        }
    if completion_broken:
        filtered = [
            (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
            if ep not in completion_broken
        ]
        if filtered:
            candidate_endpoints = [ep for ep, _ in filtered]
            loaded_sets = [models for _, models in filtered]
        # Same fallback: if every candidate is broken for this model, fall
        # through and let the upstream retry.

    # Look up a possible affinity hint *before* taking usage_lock. The two
    # locks are never held together to avoid lock-ordering issues.
    affine_ep: Optional[str] = None
    if config.conversation_affinity and affinity_key:
        async with _affinity_lock:
            entry = _affinity_map.get(affinity_key)
            if entry is not None:
                ep, _stored_model, expires_at = entry
                if expires_at < time.monotonic():
                    _affinity_map.pop(affinity_key, None)
                else:
                    affine_ep = ep

    # Protect all reads/writes of usage_counts with the lock so that selection
    # and reservation are atomic: concurrent callers see each other's pending load.
    async with usage_lock:
        # Current usage for (endpoint, model) using the same normalized key
        # that increment_usage/decrement_usage store; raw model names can
        # differ from tracking names.
        def tracking_usage(ep: str) -> int:
            return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)

        def utilization_ratio(ep: str) -> float:
            limit = get_max_connections(ep)
            if limit <= 0:
                # A zero limit means "never free"; rank it last instead of
                # raising ZeroDivisionError in the saturated fallback.
                return float("inf")
            return tracking_usage(ep) / limit

        # Priority map: position in all_endpoints list (lower = higher priority)
        ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}

        selected: Optional[str] = None

        # Step 4: conversation affinity preference. Only honour the hint when
        # the affine endpoint still has the model loaded *and* has a free slot.
        if affine_ep:
            ep_loaded = {
                ep: set(models)
                for ep, models in zip(candidate_endpoints, loaded_sets)
            }
            if (affine_ep in candidate_endpoints
                    and model in ep_loaded.get(affine_ep, set())
                    and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
                selected = affine_ep

        if selected is None:
            # Step 5a: endpoints that have the model loaded *and* a free slot
            loaded_and_free = [
                ep for ep, models in zip(candidate_endpoints, loaded_sets)
                if model in models and tracking_usage(ep) < get_max_connections(ep)
            ]

            if loaded_and_free:
                if config.priority_routing:
                    # Sort by config order first, then by utilization ratio.
                    # The stable sort preserves priority for equal-ratio endpoints.
                    loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
                    loaded_and_free.sort(key=utilization_ratio)
                    selected = loaded_and_free[0]
                else:
                    # Sort ascending for load balancing: all endpoints here
                    # already have the model loaded, so there is no
                    # model-switching cost to optimise for.
                    loaded_and_free.sort(key=tracking_usage)
                    # When all candidates are equally idle, randomise to avoid
                    # always picking the first entry in a stable sort.
                    if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
                        selected = random.choice(loaded_and_free)
                    else:
                        selected = loaded_and_free[0]
            else:
                # Step 5b: candidates that simply have a free slot
                endpoints_with_free_slot = [
                    ep for ep in candidate_endpoints
                    if tracking_usage(ep) < get_max_connections(ep)
                ]

                if endpoints_with_free_slot:
                    if config.priority_routing:
                        endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
                        endpoints_with_free_slot.sort(key=utilization_ratio)
                        selected = endpoints_with_free_slot[0]
                    else:
                        # Sort by total endpoint load (ascending) to prefer idle endpoints.
                        endpoints_with_free_slot.sort(
                            key=lambda ep: sum(usage_counts.get(ep, {}).values())
                        )
                        if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
                            selected = random.choice(endpoints_with_free_slot)
                        else:
                            selected = endpoints_with_free_slot[0]
                else:
                    # Step 5c: all candidates are saturated; pick the
                    # least-busy one (the request will queue).
                    if config.priority_routing:
                        selected = min(
                            candidate_endpoints,
                            key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
                        )
                    else:
                        selected = min(candidate_endpoints, key=tracking_usage)

        tracking_model = get_tracking_model(selected, model)
        snapshot = None
        if reserve:
            usage_counts[selected][tracking_model] += 1
            snapshot = _capture_snapshot()

    # Distribute outside usage_lock.
    if snapshot is not None:
        await _distribute_snapshot(snapshot)

    # Record / refresh affinity *after* releasing usage_lock.
    if reserve and config.conversation_affinity and affinity_key:
        expires_at = time.monotonic() + config.conversation_affinity_ttl
        async with _affinity_lock:
            _affinity_map[affinity_key] = (selected, model, expires_at)
            if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
                # Over the cap: purge entries that have already expired.
                now = time.monotonic()
                for k in [k for k, v in _affinity_map.items() if v[2] < now]:
                    _affinity_map.pop(k, None)

    return selected, tracking_model
|
||||
14
security.py
14
security.py
|
|
@ -1,14 +0,0 @@
|
|||
"""Secret-masking helpers used when logging or surfacing backend errors."""
|
||||
import re
|
||||
|
||||
|
||||
def _mask_secrets(text: str) -> str:
    """
    Mask common API key patterns to avoid leaking secrets in logs or error payloads.

    Two patterns are redacted:

    * OpenAI-style keys: ``sk-`` followed by at least four alphanumerics.
    * ``api key`` / ``api_key`` / ``api-key`` / ``apikey`` assignments
      (case-insensitive, ``:`` or ``=``), including JSON-style keys with a
      closing quote before the separator. A quoted value is masked up to
      its closing quote; an unquoted value up to the next whitespace.

    Falsy input (``None``, ``""``) is returned unchanged.
    """
    if not text:
        return text
    # OpenAI-style keys (sk-...)
    text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
    # Generic "api key" mentions. The optional quote after the key name
    # covers JSON payloads such as {"api_key": "..."}; the quoted-value
    # alternatives keep a value containing spaces from being half-masked.
    text = re.sub(
        r"""(?i)(api[-_ ]?key["']?\s*[:=]\s*)(?:"[^"]*"|'[^']*'|\S+)""",
        r"\1***redacted***",
        text,
    )
    return text
|
||||
62
sse.py
62
sse.py
|
|
@ -1,62 +0,0 @@
|
|||
"""Server-sent-events plumbing.
|
||||
|
||||
Captures the current ``usage_counts`` / ``token_usage_counts`` snapshot and
|
||||
fans it out to every subscribed asyncio.Queue. Routes that need a live
|
||||
dashboard feed call ``subscribe`` / ``unsubscribe`` to obtain a queue.
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
import orjson
|
||||
|
||||
from state import (
|
||||
usage_counts,
|
||||
token_usage_counts,
|
||||
_subscribers,
|
||||
_subscribers_lock,
|
||||
)
|
||||
|
||||
|
||||
def _capture_snapshot() -> str:
    """Serialize the current usage counters to a JSON string.

    The payload has two top-level keys, ``usage_counts`` and
    ``token_usage_counts``, each a shallow copy of the corresponding shared
    mapping, encoded with sorted keys. Caller must hold at least one of
    usage_lock/token_usage_lock.
    """
    payload = {
        "usage_counts": dict(usage_counts),
        "token_usage_counts": dict(token_usage_counts),
    }
    encoded = orjson.dumps(payload, option=orjson.OPT_SORT_KEYS)
    return encoded.decode("utf-8")
|
||||
|
||||
|
||||
async def _distribute_snapshot(snapshot: str) -> None:
    """Push a pre-captured snapshot to all SSE subscribers.

    Must be called outside any usage lock. When a subscriber's queue is
    full, its oldest pending item is discarded to make room, so a slow
    consumer only ever loses stale snapshots.
    """
    async with _subscribers_lock:
        for q in _subscribers:
            if q.full():
                # Non-blocking variants only: nothing may suspend while
                # _subscribers_lock is held. ``get_nowait`` is also the call
                # that can actually raise QueueEmpty.
                try:
                    q.get_nowait()
                except asyncio.QueueEmpty:
                    pass
            q.put_nowait(snapshot)
|
||||
|
||||
|
||||
async def close_all_sse_queues():
    """Send the ``None`` shutdown sentinel to every subscriber queue.

    The sentinel is enqueued without waiting. If a queue is full (its
    consumer has stalled), the oldest pending item is dropped first so the
    sentinel always fits and shutdown cannot hang on a blocked ``put``.
    """
    for q in list(_subscribers):
        if q.full():
            try:
                q.get_nowait()
            except asyncio.QueueEmpty:
                pass
        # sentinel value that the generator will recognise
        q.put_nowait(None)
|
||||
|
||||
|
||||
async def subscribe() -> asyncio.Queue:
    """
    Register and return a new bounded queue (``maxsize=10``) that will
    receive every snapshot distributed from now on.
    """
    async with _subscribers_lock:
        mailbox: asyncio.Queue = asyncio.Queue(maxsize=10)
        _subscribers.add(mailbox)
    return mailbox
|
||||
|
||||
|
||||
async def unsubscribe(q: asyncio.Queue):
    """Remove ``q`` from the subscriber set; unknown queues are ignored."""
    async with _subscribers_lock:
        if q in _subscribers:
            _subscribers.remove(q)
|
||||
|
||||
|
||||
async def get_usage_counts() -> Dict:
    """Return a shallow copy of ``usage_counts`` (inner per-model mappings are shared)."""
    return {endpoint: per_model for endpoint, per_model in usage_counts.items()}
|
||||
116
state.py
116
state.py
|
|
@ -1,116 +0,0 @@
|
|||
"""Shared mutable router state.
|
||||
|
||||
All process-wide caches, locks, in-flight task maps, queues, counters and
|
||||
buffers used by the router live here. These names are only ever *mutated*
|
||||
(dict/set updates, lock acquisitions, queue put/get) — never rebound — so
|
||||
importing them via ``from state import …`` is safe from every module.
|
||||
|
||||
Rebound singletons (``config``, ``db``, ``token_worker_task``,
|
||||
``flush_task``) intentionally stay in router.py so their reassignment on
|
||||
startup is visible to all callers.
|
||||
"""
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Set
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# In‑memory caches
|
||||
# ------------------------------------------------------------------
|
||||
# Successful results are cached for 300s
|
||||
_models_cache: dict[str, tuple[Set[str], float]] = {}
|
||||
_loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
|
||||
# Transient errors are cached separately per concern so that a failure
|
||||
# in one path does not poison the other.
|
||||
_available_error_cache: dict[str, float] = {}
|
||||
_loaded_error_cache: dict[str, float] = {}
|
||||
# Per-(endpoint, model) completion-path failures. A llama-server in router
|
||||
# mode can keep returning /v1/models 200 OK after its delegated worker for
|
||||
# a specific model dies — the probe-level caches above will not catch this.
|
||||
# We record signals observed during actual completion attempts so
|
||||
# choose_endpoint can avoid the affected (endpoint, model) pair without
|
||||
# poisoning unrelated models on the same backend.
|
||||
_completion_error_cache: dict[tuple[str, str], float] = {}
|
||||
_COMPLETION_ERROR_TTL = 300
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cache locks
|
||||
# ------------------------------------------------------------------
|
||||
_models_cache_lock = asyncio.Lock()
|
||||
_loaded_models_cache_lock = asyncio.Lock()
|
||||
_available_error_cache_lock = asyncio.Lock()
|
||||
_loaded_error_cache_lock = asyncio.Lock()
|
||||
_completion_error_cache_lock = asyncio.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# In-flight request tracking (prevents cache stampede)
|
||||
# ------------------------------------------------------------------
|
||||
_inflight_available_models: dict[str, asyncio.Task] = {}
|
||||
_inflight_loaded_models: dict[str, asyncio.Task] = {}
|
||||
_inflight_lock = asyncio.Lock()
|
||||
_bg_refresh_available: dict[str, asyncio.Task] = {}
|
||||
_bg_refresh_loaded: dict[str, asyncio.Task] = {}
|
||||
_bg_refresh_lock = asyncio.Lock()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Queues
|
||||
# ------------------------------------------------------------------
|
||||
_subscribers: Set[asyncio.Queue] = set()
|
||||
_subscribers_lock = asyncio.Lock()
|
||||
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP client / connector cache
|
||||
# ------------------------------------------------------------------
|
||||
app_state = {
|
||||
"session": None,
|
||||
"connector": None,
|
||||
"probe_session": None, # dedicated session for health/introspection probes
|
||||
"probe_connector": None, # connection pool isolated from proxy traffic
|
||||
"socket_sessions": {}, # endpoint -> aiohttp.ClientSession(UnixConnector) for .sock endpoints
|
||||
"httpx_clients": {}, # endpoint -> httpx.AsyncClient(UDS transport) for .sock endpoints
|
||||
# Long-lived backend clients, reused across requests. Constructing these is
|
||||
# expensive (~40 ms each — every new client builds an SSL context and reloads
|
||||
# the OS trust store via truststore), so building one per request serializes
|
||||
# the event loop and caps throughput. Created once at startup, closed on
|
||||
# shutdown. See backends.sessions.get_ollama_client / _make_openai_client.
|
||||
"ollama_clients": {}, # endpoint -> ollama.AsyncClient
|
||||
"openai_clients": {}, # (endpoint, api_key) -> openai.AsyncOpenAI
|
||||
}
|
||||
|
||||
# Default outbound HTTP headers attached to every backend request.
|
||||
default_headers = {
|
||||
"HTTP-Referer": "https://nomyo.ai",
|
||||
"Referer": "https://nomyo.ai",
|
||||
"X-Title": "NOMYO Router",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Token Count Buffer (for write-behind pattern)
|
||||
# ------------------------------------------------------------------
|
||||
# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
|
||||
token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
|
||||
# Time series buffer with timestamp
|
||||
time_series_buffer: list[dict[str, int | str]] = []
|
||||
# Lock to protect buffer access from race conditions
|
||||
buffer_lock = asyncio.Lock()
|
||||
|
||||
# Configuration for periodic flushing
|
||||
FLUSH_INTERVAL = 10 # seconds
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per‑endpoint per‑model active connection counters
|
||||
# ------------------------------------------------------------------
|
||||
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
usage_lock = asyncio.Lock() # protects access to usage_counts
|
||||
token_usage_lock = asyncio.Lock()
|
||||
|
||||
# Conversation affinity map: fingerprint -> (endpoint, model, expires_at_monotonic).
|
||||
# Keeps the same conversation pinned to the endpoint that already has its
|
||||
# KV-cache prefix warm. Model is stored so the dashboard can aggregate live
|
||||
# entries per (endpoint, model) without recomputing fingerprints.
|
||||
# Never held together with usage_lock.
|
||||
_affinity_map: Dict[str, tuple[str, str, float]] = {}
|
||||
_affinity_lock = asyncio.Lock()
|
||||
_AFFINITY_MAX_ENTRIES = 10000
|
||||
|
|
@ -1171,13 +1171,11 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
|
|||
}
|
||||
};
|
||||
|
||||
source.onerror = (err) => {
|
||||
// EventSource auto-reconnects on transient drops as long as we
|
||||
// don't close it. Don't treat a dropped stream as an auth failure:
|
||||
// auth prompting is handled by loadEndpoints()/authedFetch() on the
|
||||
// REST endpoints. A genuine 401 closes the stream permanently here
|
||||
// (no reconnect loop), and the REST path surfaces the modal.
|
||||
console.error("SSE connection error; awaiting auto-reconnect.", err);
|
||||
source.onerror = async (err) => {
|
||||
console.error("SSE connection error. Retrying...", err);
|
||||
source.close();
|
||||
await showApiKeyModal("Enter the NOMYO Router API key to view live usage.");
|
||||
loadUsage();
|
||||
};
|
||||
window.addEventListener("beforeunload", () => source.close());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,138 +0,0 @@
|
|||
# Load testing the NOMYO Router
|
||||
|
||||
`loadtest.py` is a self-contained load generator (asyncio + httpx) with a built-in
|
||||
**mock backend** so you can measure the router's own concurrency ceiling on a given
|
||||
machine — independent of real GPU/backend compute.
|
||||
|
||||
It answers the question *"how many concurrent connections can the router sustain
|
||||
on this box?"* by hammering it with N concurrent virtual clients and reporting
|
||||
throughput, latency percentiles and (for streaming) time-to-first-token.
|
||||
|
||||
Run everything from the project root with the project venv active:
|
||||
|
||||
```bash
|
||||
source ~/.venv/nomyo-router/bin/activate # whatever venv has the router deps
|
||||
```
|
||||
|
||||
## The three modes
|
||||
|
||||
### 1. `--mock-backend` (recommended) — fully self-contained
|
||||
|
||||
Spawns a fast fake Ollama/OpenAI backend **and** the router (wired to it via a
|
||||
temporary config), drives load against the router, then tears both down. Because
|
||||
the backend is trivial, the numbers reflect the **router's proxy overhead**, not
|
||||
model inference time.
|
||||
|
||||
```bash
|
||||
python test/load/loadtest.py --mock-backend --stream --concurrency 128 --duration 30
|
||||
```
|
||||
|
||||
### 2. Default — drive an already-running router
|
||||
|
||||
```bash
|
||||
python test/load/loadtest.py --url http://127.0.0.1:12434 \
|
||||
--api ollama --stream --concurrency 64 --duration 30 --model llama3
|
||||
```
|
||||
|
||||
### 3. `--serve-mock` — just the mock backend
|
||||
|
||||
Run only the fake backend and point your own router `config.yaml` at it
|
||||
(`endpoints: [http://127.0.0.1:11434]`):
|
||||
|
||||
```bash
|
||||
python test/load/loadtest.py --serve-mock --mock-port 11434 --mock-tokens 64
|
||||
```
|
||||
|
||||
## Finding the concurrency knee
|
||||
|
||||
`--ramp` sweeps several concurrency levels and prints a table. The knee is where
|
||||
`req/s` stops rising and `p99` latency starts climbing sharply:
|
||||
|
||||
```bash
|
||||
python test/load/loadtest.py --mock-backend --stream \
|
||||
--ramp 8,32,64,128,256 --duration 15
|
||||
```
|
||||
|
||||
```
|
||||
conc req ok err req/s p50ms p90ms p99ms maxms ttftP50 ttftP99
|
||||
---------------------------------------------------------------------------------------------
|
||||
8 120 120 0 19.8 404.6 448.3 478.6 501.4 358.4 391.7
|
||||
32 140 140 0 21.5 1487.1 1641.8 2341.8 2397.4 1269.8 1476.3
|
||||
64 148 148 0 21.3 2953.0 4632.5 5204.3 5267.0 1207.8 3031.7
|
||||
128 168 168 0 19.0 6376.4 8608.9 8726.9 8739.8 2843.1 8348.6
|
||||
```
|
||||
|
||||
> Reading the table above: throughput stays flat (~20 req/s) while latency grows
|
||||
> linearly with concurrency — the classic signature of a **single-worker
|
||||
> serialization bottleneck**. Raising `--router-workers` lets throughput scale
|
||||
> across CPU cores; the per-worker ceiling is what each table row measures.
|
||||
|
||||
## Streaming vs non-streaming, Ollama vs OpenAI
|
||||
|
||||
| flag | effect |
|
||||
|------|--------|
|
||||
| `--stream` / `--no-stream` | streamed response (default) vs a single buffered response |
|
||||
| `--api ollama` | drives `POST /api/chat` (default) |
|
||||
| `--api openai` | drives `POST /v1/chat/completions` |
|
||||
|
||||
Streaming runs additionally report **TTFT** (time-to-first-token), which isolates
|
||||
prefill/routing latency from total stream duration.
|
||||
|
||||
## Shaping the mock backend (the "fake GPU")
|
||||
|
||||
The mock's latency is fully configurable, so you can model anything from an
|
||||
instant echo (measure pure proxy overhead) to a slow, long-streaming model
|
||||
(measure how many slow streams the box holds open at once):
|
||||
|
||||
| flag | meaning |
|
||||
|------|---------|
|
||||
| `--mock-ttft-ms` | prefill latency before the first token (ms) |
|
||||
| `--mock-tokens` | number of completion tokens emitted |
|
||||
| `--mock-tok-ms` | per-token decode delay (ms) — inverse of tokens/sec |
|
||||
| `--mock-models` | comma-separated model names advertised in `/api/tags` & `/api/ps` |
|
||||
|
||||
Example — simulate a realistic 40 tok/s model with 300 ms prefill emitting 200
|
||||
tokens, and see how many concurrent such streams the router holds:
|
||||
|
||||
```bash
|
||||
python test/load/loadtest.py --mock-backend --stream --ramp 16,64,256 \
|
||||
--mock-ttft-ms 300 --mock-tokens 200 --mock-tok-ms 25 --duration 20
|
||||
```
|
||||
|
||||
## Load shape & misc flags
|
||||
|
||||
| flag | default | meaning |
|
||||
|------|---------|---------|
|
||||
| `--concurrency N` | 32 | concurrent virtual clients |
|
||||
| `--duration S` | 20 | seconds per stage (ignored if `--requests` set) |
|
||||
| `--requests N` | — | send exactly N requests instead of running for `--duration` seconds |
|
||||
| `--warmup S` | 2 | unmeasured warmup before each stage (hot caches/connections) |
|
||||
| `--timeout S` | 120 | per-request timeout |
|
||||
| `--model NAME` | `mock` | model name requested (must match what the backend advertises) |
|
||||
| `--prompt STR` | … | user prompt sent in every request |
|
||||
| `--json PATH` | — | also write the full results as JSON |
|
||||
|
||||
### `--mock-backend` orchestration knobs
|
||||
|
||||
| flag | default | meaning |
|
||||
|------|---------|---------|
|
||||
| `--router-workers N` | 1 | `uvicorn --workers` for the spawned router |
|
||||
| `--router-max-conc N` | = peak concurrency | `max_concurrent_connections` in the generated config (so the router doesn't queue unless you want it to) |
|
||||
| `--router-port` / `--mock-port` | auto | fix the ports instead of auto-picking free ones |
|
||||
| `--keep-config` | off | keep the generated temp `config.yaml` for inspection |
|
||||
|
||||
## Notes & caveats
|
||||
|
||||
- **Single-machine bias.** With `--mock-backend`, the driver, router and mock all
|
||||
share the same CPU, so they compete for cores. For an upper-bound number, run
|
||||
the driver on a separate machine against a real router (`--url`), or pin
|
||||
processes to different cores.
|
||||
- The generated config sets `conversation_affinity: false` and
|
||||
`cache_enabled: false` to measure the raw proxy path. The temp config and a
|
||||
throwaway token DB (under the system temp dir) are deleted on exit.
|
||||
- To measure the router's *admission* limit instead of raw throughput, set
|
||||
`--router-max-conc` low (e.g. `2`) — requests beyond the limit queue on the
|
||||
least-busy endpoint rather than erroring.
|
||||
- Requires the router's own dependencies (`aiohttp`, `httpx`, `uvicorn`, …); it
|
||||
reuses the project venv, no extra packages needed.
|
||||
```
|
||||
|
|
@ -1,803 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
NOMYO Router load test — asyncio + httpx driver with a built-in mock backend.
|
||||
|
||||
Three modes
|
||||
-----------
|
||||
1. Drive an already-running router (default)::
|
||||
|
||||
python test/load/loadtest.py --url http://127.0.0.1:12434 \
|
||||
--concurrency 64 --duration 30 --stream
|
||||
|
||||
2. Fully self-contained "mock backend" mode — spins up a fast fake Ollama/OpenAI
|
||||
backend AND the router (wired to that backend via a temp config), load-tests
|
||||
them, then tears both down. This isolates the *router's* proxy overhead from
|
||||
real GPU compute, so the numbers tell you how many concurrent connections the
|
||||
router itself can sustain on this machine::
|
||||
|
||||
python test/load/loadtest.py --mock-backend \
|
||||
--concurrency 128 --duration 30 --stream
|
||||
|
||||
3. Run just the mock backend (point your own router config at it)::
|
||||
|
||||
python test/load/loadtest.py --serve-mock --mock-port 11434
|
||||
|
||||
Both streaming and non-streaming are supported (--stream / --no-stream), against
|
||||
either the Ollama API (--api ollama -> POST /api/chat) or the OpenAI-compatible
|
||||
API (--api openai -> POST /v1/chat/completions).
|
||||
|
||||
Finding the concurrency ceiling
|
||||
-------------------------------
|
||||
Use --ramp to sweep concurrency levels and print a table; the "knee" is where
|
||||
p99 latency climbs sharply or req/s stops increasing::
|
||||
|
||||
python test/load/loadtest.py --mock-backend --stream \
|
||||
--ramp 16,32,64,128,256 --duration 15
|
||||
|
||||
The mock backend is a configurable "fake GPU": --mock-ttft-ms (prefill latency),
|
||||
--mock-tokens (completion length) and --mock-tok-ms (per-token decode delay) let
|
||||
you model anything from an instant echo (measure pure proxy overhead) to a slow,
|
||||
long-streaming model (measure how many slow streams the box holds open).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import statistics
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_WORDS = (
    "the quick brown fox jumps over the lazy dog while a router proxies many "
    "concurrent streaming completions across several ollama and openai backends "
    "without dropping a single token under sustained synthetic load testing"
).split()


def _gen_text(n_tokens: int) -> str:
    """Return a deterministic pseudo-completion of ``n_tokens`` words.

    The words of ``_WORDS`` are repeated cyclically and joined with single
    spaces; a non-positive ``n_tokens`` yields the empty string.
    """
    if n_tokens <= 0:
        return ""
    full_cycles, leftover = divmod(n_tokens, len(_WORDS))
    return " ".join(_WORDS * full_cycles + _WORDS[:leftover])
|
||||
|
||||
|
||||
def _rfc3339_now() -> str:
    """Return the current UTC time as an Ollama-style timestamp.

    Example: ``2024-01-01T00:00:00.000000Z`` (microsecond precision, ``Z`` suffix).
    """
    utc_now = datetime.now(timezone.utc)
    stamp = utc_now.strftime("%Y-%m-%dT%H:%M:%S.%f")
    return stamp + "Z"
|
||||
|
||||
|
||||
def _count_prompt_tokens(messages: list) -> int:
    """Approximate the prompt token count as a whitespace word count.

    String ``content`` values are counted directly. For list-style content,
    only dict parts carrying a string ``text`` field are counted. The result
    is never less than 1, even for empty or missing ``messages``.
    """
    word_count = 0
    for message in messages or []:
        content = message.get("content")
        if isinstance(content, str):
            word_count += len(content.split())
            continue
        if not isinstance(content, list):
            continue
        word_count += sum(
            len(part["text"].split())
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return word_count if word_count > 1 else 1
|
||||
|
||||
|
||||
def _free_port() -> int:
    """Return a TCP port on 127.0.0.1 that was free at the time of the call.

    A socket is bound to port 0 so the OS assigns an unused port; the socket
    is closed before returning, so another process could in principle take
    the port before the caller binds it.
    """
    # The ``with`` block closes the socket even if setsockopt/bind raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Mock backend (a fast, configurable fake Ollama + OpenAI-compatible server)
|
||||
# ===========================================================================
|
||||
|
||||
def build_mock_app(models: list[str], ttft_ms: float, tokens: int, tok_ms: float):
    """Construct the aiohttp mock-backend application.

    Serves the native-Ollama surface the router uses for discovery and the
    `/api/chat` path (`/api/version`, `/api/tags`, `/api/ps`, `/api/chat`,
    `/api/generate`) plus the OpenAI-compatible surface used by the
    `/v1/chat/completions` path (`/v1/models`, `/v1/chat/completions`,
    `/v1/completions`).

    Args:
        models: Model names advertised by the listing endpoints; the first one
            is the default when a request omits ``model``.
        ttft_ms: Delay in milliseconds before the first token is emitted.
        tokens: Number of tokens emitted per completion.
        tok_ms: Delay in milliseconds between consecutive tokens.

    Returns:
        The configured ``aiohttp.web.Application``.
    """
    from aiohttp import web  # imported lazily so the driver has no hard aiohttp dep

    # Convert the millisecond knobs to seconds once for asyncio.sleep().
    ttft = ttft_ms / 1000.0
    tok_delay = tok_ms / 1000.0

    def _tag_entry(name: str) -> dict:
        # One Ollama-style model record with fixed placeholder metadata.
        return {
            "name": name,
            "model": name,
            "modified_at": _rfc3339_now(),
            "size": 4_000_000_000,
            "digest": "0" * 64,
            "details": {
                "parent_model": "",
                "format": "gguf",
                "family": "mock",
                "families": ["mock"],
                "parameter_size": "7B",
                "quantization_level": "Q4_0",
            },
        }

    async def version(_req):
        return web.json_response({"version": "0.0.0-nomyo-mock"})

    async def tags(_req):
        return web.json_response({"models": [_tag_entry(m) for m in models]})

    async def ps(_req):
        # Report every advertised model as loaded with VRAM so choose_endpoint
        # treats this endpoint as "loaded + free".
        out = []
        for m in models:
            e = _tag_entry(m)
            e["size_vram"] = e["size"]
            e["expires_at"] = "2999-01-01T00:00:00Z"
            out.append(e)
        return web.json_response({"models": out})

    async def v1_models(_req):
        now = int(time.time())
        return web.json_response({
            "object": "list",
            "data": [{"id": m, "object": "model", "created": now, "owned_by": "mock"} for m in models],
        })

    # ----- Ollama /api/chat -------------------------------------------------
    async def api_chat(req):
        payload = await req.json()
        model = payload.get("model", models[0] if models else "mock")
        stream = payload.get("stream", True)  # this handler streams unless told otherwise
        prompt_tok = _count_prompt_tokens(payload.get("messages", []))
        t0 = time.perf_counter()

        if stream:
            resp = web.StreamResponse(
                status=200, headers={"Content-Type": "application/x-ndjson"}
            )
            await resp.prepare(req)
            if ttft:
                await asyncio.sleep(ttft)
            for i in range(tokens):
                # No per-token delay before the first token; ttft covers that.
                if tok_delay and i:
                    await asyncio.sleep(tok_delay)
                line = {
                    "model": model,
                    "created_at": _rfc3339_now(),
                    "message": {"role": "assistant", "content": _WORDS[i % len(_WORDS)] + " "},
                    "done": False,
                }
                await resp.write(json.dumps(line).encode() + b"\n")
            dur_ns = int((time.perf_counter() - t0) * 1e9)
            # Terminal NDJSON record: empty content plus the usage/timing counters.
            final = {
                "model": model,
                "created_at": _rfc3339_now(),
                "message": {"role": "assistant", "content": ""},
                "done": True,
                "done_reason": "stop",
                "total_duration": dur_ns,
                "load_duration": 0,
                "prompt_eval_count": prompt_tok,
                "prompt_eval_duration": int(ttft * 1e9),
                "eval_count": tokens,
                "eval_duration": dur_ns,
            }
            await resp.write(json.dumps(final).encode() + b"\n")
            await resp.write_eof()
            return resp

        # non-streaming: simulate the whole generation latency, then one object
        await asyncio.sleep(ttft + tokens * tok_delay)
        dur_ns = int((time.perf_counter() - t0) * 1e9)
        return web.json_response({
            "model": model,
            "created_at": _rfc3339_now(),
            "message": {"role": "assistant", "content": _gen_text(tokens)},
            "done": True,
            "done_reason": "stop",
            "total_duration": dur_ns,
            "load_duration": 0,
            "prompt_eval_count": prompt_tok,
            "prompt_eval_duration": int(ttft * 1e9),
            "eval_count": tokens,
            "eval_duration": dur_ns,
        })

    # ----- Ollama /api/generate --------------------------------------------
    async def api_generate(req):
        payload = await req.json()
        model = payload.get("model", models[0] if models else "mock")
        stream = payload.get("stream", True)
        # Prompt size is the word count of the stringified prompt, at least 1.
        prompt_tok = max(1, len(str(payload.get("prompt", "")).split()))
        t0 = time.perf_counter()
        if stream:
            resp = web.StreamResponse(status=200, headers={"Content-Type": "application/x-ndjson"})
            await resp.prepare(req)
            if ttft:
                await asyncio.sleep(ttft)
            for i in range(tokens):
                if tok_delay and i:
                    await asyncio.sleep(tok_delay)
                await resp.write(json.dumps({
                    "model": model, "created_at": _rfc3339_now(),
                    "response": _WORDS[i % len(_WORDS)] + " ", "done": False,
                }).encode() + b"\n")
            dur_ns = int((time.perf_counter() - t0) * 1e9)
            await resp.write(json.dumps({
                "model": model, "created_at": _rfc3339_now(), "response": "", "done": True,
                "done_reason": "stop", "total_duration": dur_ns,
                "prompt_eval_count": prompt_tok, "eval_count": tokens, "eval_duration": dur_ns,
            }).encode() + b"\n")
            await resp.write_eof()
            return resp
        await asyncio.sleep(ttft + tokens * tok_delay)
        dur_ns = int((time.perf_counter() - t0) * 1e9)
        return web.json_response({
            "model": model, "created_at": _rfc3339_now(), "response": _gen_text(tokens),
            "done": True, "done_reason": "stop", "total_duration": dur_ns,
            "prompt_eval_count": prompt_tok, "eval_count": tokens, "eval_duration": dur_ns,
        })

    # ----- OpenAI /v1/chat/completions -------------------------------------
    async def v1_chat(req):
        payload = await req.json()
        model = payload.get("model", models[0] if models else "mock")
        stream = payload.get("stream", False)  # the /v1 handlers do not stream by default
        # A usage chunk is only sent when stream_options.include_usage is truthy.
        want_usage = bool((payload.get("stream_options") or {}).get("include_usage"))
        prompt_tok = _count_prompt_tokens(payload.get("messages", []))
        created = int(time.time())
        cid = "chatcmpl-mock"

        if stream:
            resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
            await resp.prepare(req)
            if ttft:
                await asyncio.sleep(ttft)

            def _sse(obj: dict) -> bytes:
                # Frame one JSON object as a server-sent event.
                return b"data: " + json.dumps(obj).encode() + b"\n\n"

            # first chunk carries the role
            await resp.write(_sse({
                "id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
                "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": None}],
            }))
            for i in range(tokens):
                if tok_delay and i:
                    await asyncio.sleep(tok_delay)
                await resp.write(_sse({
                    "id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
                    "choices": [{"index": 0, "delta": {"content": _WORDS[i % len(_WORDS)] + " "}, "finish_reason": None}],
                }))
            # Closing chunk: empty delta with finish_reason "stop".
            await resp.write(_sse({
                "id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }))
            if want_usage:
                await resp.write(_sse({
                    "id": cid, "object": "chat.completion.chunk", "created": created, "model": model,
                    "choices": [],
                    "usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
                              "total_tokens": prompt_tok + tokens},
                }))
            await resp.write(b"data: [DONE]\n\n")
            await resp.write_eof()
            return resp

        await asyncio.sleep(ttft + tokens * tok_delay)
        return web.json_response({
            "id": cid, "object": "chat.completion", "created": created, "model": model,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": _gen_text(tokens)},
                         "finish_reason": "stop", "logprobs": None}],
            "usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
                      "total_tokens": prompt_tok + tokens},
        })

    # ----- OpenAI /v1/completions ------------------------------------------
    async def v1_completions(req):
        payload = await req.json()
        model = payload.get("model", models[0] if models else "mock")
        stream = payload.get("stream", False)
        prompt_tok = max(1, len(str(payload.get("prompt", "")).split()))
        created = int(time.time())
        cid = "cmpl-mock"
        if stream:
            resp = web.StreamResponse(status=200, headers={"Content-Type": "text/event-stream"})
            await resp.prepare(req)
            if ttft:
                await asyncio.sleep(ttft)
            for i in range(tokens):
                if tok_delay and i:
                    await asyncio.sleep(tok_delay)
                await resp.write(b"data: " + json.dumps({
                    "id": cid, "object": "text_completion", "created": created, "model": model,
                    "choices": [{"index": 0, "text": _WORDS[i % len(_WORDS)] + " ", "finish_reason": None}],
                }).encode() + b"\n\n")
            # Unlike v1_chat, no separate finish_reason="stop" chunk is sent here.
            await resp.write(b"data: [DONE]\n\n")
            await resp.write_eof()
            return resp
        await asyncio.sleep(ttft + tokens * tok_delay)
        return web.json_response({
            "id": cid, "object": "text_completion", "created": created, "model": model,
            "choices": [{"index": 0, "text": _gen_text(tokens), "finish_reason": "stop"}],
            "usage": {"prompt_tokens": prompt_tok, "completion_tokens": tokens,
                      "total_tokens": prompt_tok + tokens},
        })

    # Accept request bodies up to 64 MiB.
    app = web.Application(client_max_size=64 * 1024 * 1024)
    app.add_routes([
        web.get("/api/version", version),
        web.get("/api/tags", tags),
        web.get("/api/ps", ps),
        web.post("/api/chat", api_chat),
        web.post("/api/generate", api_generate),
        web.get("/v1/models", v1_models),
        web.post("/v1/chat/completions", v1_chat),
        web.post("/v1/completions", v1_completions),
    ])
    return app
|
||||
|
||||
|
||||
def serve_mock(args) -> None:
    """Run the mock backend in the foreground using the parsed CLI options.

    Reads ``mock_models`` (comma-separated; blank entries dropped),
    ``mock_ttft_ms``, ``mock_tokens``, ``mock_tok_ms``, ``mock_host`` and
    ``mock_port`` from ``args``.
    """
    from aiohttp import web

    advertised = []
    for raw_name in args.mock_models.split(","):
        cleaned = raw_name.strip()
        if cleaned:
            advertised.append(cleaned)

    mock_app = build_mock_app(
        advertised, args.mock_ttft_ms, args.mock_tokens, args.mock_tok_ms
    )
    banner = (
        f"[mock] serving models={advertised} on http://{args.mock_host}:{args.mock_port} "
        f"(ttft={args.mock_ttft_ms}ms tokens={args.mock_tokens} tok={args.mock_tok_ms}ms)"
    )
    print(banner, flush=True)
    web.run_app(mock_app, host=args.mock_host, port=args.mock_port, print=None)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Load driver
|
||||
# ===========================================================================
|
||||
|
||||
@dataclass
class Sample:
    """Outcome of a single request issued by the load driver."""

    ok: bool  # True when the response status was 2xx
    status: int  # HTTP status code; 0 when the request raised before a response
    latency: float  # full request wall time (s)
    ttft: Optional[float]  # time-to-first-byte for streaming (s), else None
    err: Optional[str] = None  # short failure description, None on success
|
||||
|
||||
|
||||
@dataclass
class Stats:
    """Samples and wall time collected for one concurrency stage."""

    concurrency: int
    wall: float = 0.0
    samples: list = field(default_factory=list)

    @property
    def ok(self) -> list:
        """Samples whose ``ok`` flag is truthy, in recording order."""
        return list(filter(lambda sample: sample.ok, self.samples))

    @property
    def n_total(self) -> int:
        """Number of recorded samples."""
        return len(self.samples)

    @property
    def n_ok(self) -> int:
        """Number of successful samples."""
        return sum(1 for sample in self.samples if sample.ok)
|
||||
|
||||
|
||||
def _pct(values: list[float], p: float) -> float:
|
||||
if not values:
|
||||
return float("nan")
|
||||
s = sorted(values)
|
||||
if len(s) == 1:
|
||||
return s[0]
|
||||
k = (len(s) - 1) * (p / 100.0)
|
||||
lo = math.floor(k)
|
||||
hi = math.ceil(k)
|
||||
if lo == hi:
|
||||
return s[int(k)]
|
||||
return s[lo] + (s[hi] - s[lo]) * (k - lo)
|
||||
|
||||
|
||||
def _build_request(args):
|
||||
"""Return (path, json_payload) for a single request."""
|
||||
messages = [{"role": "user", "content": args.prompt}]
|
||||
if args.api == "openai":
|
||||
path = "/v1/chat/completions"
|
||||
body = {"model": args.model, "messages": messages, "stream": args.stream}
|
||||
else:
|
||||
path = "/api/chat"
|
||||
body = {"model": args.model, "messages": messages, "stream": args.stream}
|
||||
return path, body
|
||||
|
||||
|
||||
async def _one_request(client: httpx.AsyncClient, url: str, body: dict, stream: bool) -> Sample:
    """POST ``body`` to ``url`` once and record the outcome as a ``Sample``.

    In streaming mode the response is drained chunk by chunk and the arrival
    time of the first chunk is stored as ``ttft``. Any exception raised while
    sending or reading is turned into a failed sample with ``status=0``.
    """
    started = time.perf_counter()
    try:
        if not stream:
            reply = await client.post(url, json=body)
            elapsed = time.perf_counter() - started
            code = reply.status_code
            succeeded = 200 <= code < 300
            # touch body so the full response is received
            _ = reply.content
            failure = None if succeeded else f"HTTP {code}"
            return Sample(ok=succeeded, status=code, latency=elapsed, ttft=None,
                          err=failure)

        first_byte = None
        async with client.stream("POST", url, json=body) as reply:
            code = reply.status_code
            async for _chunk in reply.aiter_bytes():
                if first_byte is None:
                    first_byte = time.perf_counter() - started
        # drain complete
        elapsed = time.perf_counter() - started
        succeeded = 200 <= code < 300
        failure = None if succeeded else f"HTTP {code}"
        return Sample(ok=succeeded, status=code, latency=elapsed, ttft=first_byte,
                      err=failure)
    except Exception as exc:  # noqa: BLE001 — record any transport error as a failed sample
        elapsed = time.perf_counter() - started
        return Sample(ok=False, status=0, latency=elapsed, ttft=None,
                      err=f"{type(exc).__name__}: {str(exc)[:120]}")
|
||||
|
||||
|
||||
async def run_stage(args, concurrency: int) -> Stats:
    """Run one load stage with ``concurrency`` workers and return its stats.

    The stage runs for ``args.duration`` seconds when ``args.requests`` is
    None; otherwise the workers share a budget of exactly ``args.requests``
    requests. An optional unmeasured warmup of ``args.warmup`` seconds runs
    first on the same client.
    """
    path, body = _build_request(args)
    url = args.url.rstrip("/") + path
    stats = Stats(concurrency=concurrency)

    # Connection pool sized above the worker count (concurrency + 50).
    limits = httpx.Limits(max_connections=concurrency + 50,
                          max_keepalive_connections=concurrency + 50)
    timeout = httpx.Timeout(args.timeout, connect=15.0)

    async with httpx.AsyncClient(limits=limits, timeout=timeout) as client:
        # warmup (unmeasured): make a few requests so caches/connections are hot
        if args.warmup > 0:
            warm_deadline = time.perf_counter() + args.warmup
            async def _warm():
                while time.perf_counter() < warm_deadline:
                    await _one_request(client, url, body, args.stream)
            # At most 8 warmup workers, regardless of the stage's concurrency.
            await asyncio.gather(*[_warm() for _ in range(min(concurrency, 8))])

        use_duration = args.requests is None
        deadline = time.perf_counter() + args.duration if use_duration else None
        remaining = args.requests if not use_duration else None
        remaining_lock = asyncio.Lock()

        async def worker():
            nonlocal remaining
            while True:
                if use_duration:
                    # The deadline is only checked between requests, so a request
                    # already in flight completes and is still recorded.
                    if time.perf_counter() >= deadline:
                        return
                else:
                    # Claim one unit of the shared request budget before sending.
                    async with remaining_lock:
                        if remaining <= 0:
                            return
                        remaining -= 1
                s = await _one_request(client, url, body, args.stream)
                stats.samples.append(s)

        # Wall time covers only the measured workers, not the warmup.
        wall0 = time.perf_counter()
        await asyncio.gather(*[worker() for _ in range(concurrency)])
        stats.wall = time.perf_counter() - wall0

    return stats
|
||||
|
||||
|
||||
def _print_stage(stats: Stats, args, header: bool) -> None:
    """Print the results of one stage.

    With ``args.ramp`` set, a single table row is printed (preceded by the
    column header when ``header`` is true). Otherwise a detailed multi-line
    report is printed, including a per-error breakdown when requests failed.
    Latencies and TTFTs are reported in milliseconds over successful samples.
    """
    succeeded = stats.ok
    lat_ms = [sample.latency * 1000 for sample in succeeded]
    ttft_ms = [sample.ttft * 1000 for sample in succeeded if sample.ttft is not None]
    if stats.wall:
        rps = stats.n_ok / stats.wall
    else:
        rps = 0.0
    failed = stats.n_total - stats.n_ok

    if args.ramp:
        if header:
            cols = f"{'conc':>5} {'req':>7} {'ok':>7} {'err':>5} {'req/s':>9} " \
                   f"{'p50ms':>8} {'p90ms':>8} {'p99ms':>9} {'maxms':>9}"
            if args.stream:
                cols += f" {'ttftP50':>8} {'ttftP99':>8}"
            print(cols)
            print("-" * len(cols))
        worst = max(lat_ms) if lat_ms else float('nan')
        row = (f"{stats.concurrency:>5} {stats.n_total:>7} {stats.n_ok:>7} {failed:>5} "
               f"{rps:>9.1f} {_pct(lat_ms,50):>8.1f} {_pct(lat_ms,90):>8.1f} "
               f"{_pct(lat_ms,99):>9.1f} {worst:>9.1f}")
        if args.stream:
            row += f" {_pct(ttft_ms,50):>8.1f} {_pct(ttft_ms,99):>8.1f}"
        print(row, flush=True)
        return

    # single-stage detailed report
    mode = 'stream' if args.stream else 'non-stream'
    print(f"\n=== Results (concurrency={stats.concurrency}, "
          f"{mode}, api={args.api}) ===")
    print(f" wall time : {stats.wall:8.2f} s")
    print(f" requests : {stats.n_total} total, {stats.n_ok} ok, {failed} failed")
    print(f" throughput : {rps:8.1f} req/s")
    if lat_ms:
        print(f" latency p50 : {_pct(lat_ms,50):8.1f} ms")
        print(f" p90 : {_pct(lat_ms,90):8.1f} ms")
        print(f" p95 : {_pct(lat_ms,95):8.1f} ms")
        print(f" p99 : {_pct(lat_ms,99):8.1f} ms")
        print(f" max : {max(lat_ms):8.1f} ms")
        print(f" mean : {statistics.mean(lat_ms):8.1f} ms")
    if ttft_ms:
        print(f" TTFT p50 : {_pct(ttft_ms,50):8.1f} ms")
        print(f" p90 : {_pct(ttft_ms,90):8.1f} ms")
        print(f" p99 : {_pct(ttft_ms,99):8.1f} ms")
    if failed:
        # Tally failures by message, most frequent first.
        tally: dict[str, int] = {}
        for sample in stats.samples:
            if sample.ok:
                continue
            reason = sample.err or "unknown"
            tally[reason] = tally.get(reason, 0) + 1
        print(" errors:")
        for reason, count in sorted(tally.items(), key=lambda item: item[1], reverse=True):
            print(f" {count:>6} {reason}")
|
||||
|
||||
|
||||
async def run_driver(args) -> list[Stats]:
    """Run every configured stage in order, printing each as it completes.

    Stages come from the comma-separated ``args.ramp`` when set; otherwise a
    single stage at ``args.concurrency`` is run. Returns one ``Stats`` per
    stage.
    """
    if args.ramp:
        levels = list(map(int, args.ramp.split(",")))
    else:
        levels = [args.concurrency]

    collected: list[Stats] = []
    show_header = True  # only the first stage prints the table header
    for level in levels:
        stage_stats = await run_stage(args, level)
        _print_stage(stage_stats, args, header=show_header)
        show_header = False
        collected.append(stage_stats)
    return collected
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Orchestration: --mock-backend (spawn mock + router, run, tear down)
|
||||
# ===========================================================================
|
||||
|
||||
# Third ancestor directory of this file; used as the working directory when
# the router subprocess is spawned.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
async def _wait_http_ok(url: str, timeout: float, accept=(200,)) -> bool:
    """Poll ``url`` with GET until it answers with a status in ``accept``.

    Request errors are ignored and retried every 0.25 s. Returns True on the
    first accepted status, or False once ``timeout`` seconds have elapsed.
    """
    give_up_at = time.perf_counter() + timeout
    async with httpx.AsyncClient(timeout=5.0) as poller:
        while time.perf_counter() < give_up_at:
            # Best-effort probe: the server may simply not be listening yet.
            with contextlib.suppress(Exception):
                reply = await poller.get(url)
                if reply.status_code in accept:
                    return True
            await asyncio.sleep(0.25)
    return False
|
||||
|
||||
|
||||
def _write_temp_config(mock_url: str, models: list[str], max_conc: int) -> Path:
|
||||
fd, path = tempfile.mkstemp(prefix="nomyo_loadtest_", suffix=".yaml")
|
||||
os.close(fd)
|
||||
cfg = (
|
||||
"# Auto-generated by test/load/loadtest.py --mock-backend. Safe to delete.\n"
|
||||
"endpoints:\n"
|
||||
f" - {mock_url}\n"
|
||||
"llama_server_endpoints: []\n"
|
||||
f"max_concurrent_connections: {max_conc}\n"
|
||||
"priority_routing: false\n"
|
||||
"conversation_affinity: false\n"
|
||||
"cache_enabled: false\n"
|
||||
"nomyo-router-api-key: \"\"\n"
|
||||
"api_keys:\n"
|
||||
f" \"{mock_url}\": \"mock\"\n"
|
||||
)
|
||||
Path(path).write_text(cfg)
|
||||
return Path(path)
|
||||
|
||||
|
||||
async def run_with_mock_backend(args) -> list[Stats]:
    """Spawn the mock backend and a router, drive load at the router, tear down.

    Steps: start this script with ``--serve-mock`` as a subprocess, wait for
    its ``/api/version``; start the router under uvicorn with a generated
    config and a temporary DB path passed through the environment, wait for
    its ``/health``; then run the driver against the router URL.

    Both subprocesses are always stopped in ``finally``: SIGINT first, then
    ``kill()`` if still running after 8 s, followed by a final ``wait()`` so a
    force-killed child is reaped before the event loop shuts down.

    Returns:
        One ``Stats`` per stage, as produced by ``run_driver``.

    Raises:
        RuntimeError: if the mock backend or the router does not become ready
            within its startup timeout.
    """
    mock_port = args.mock_port or _free_port()
    router_port = args.router_port or _free_port()
    mock_url = f"http://127.0.0.1:{mock_port}"
    router_url = f"http://127.0.0.1:{router_port}"
    models = [m.strip() for m in args.mock_models.split(",") if m.strip()]

    # Size the router's per-endpoint admission limit so it does not artificially
    # serialize the load (unless the user explicitly wants to measure that).
    peak = max([int(x) for x in args.ramp.split(",")]) if args.ramp else args.concurrency
    max_conc = args.router_max_conc if args.router_max_conc else max(peak, 1)

    cfg_path = _write_temp_config(mock_url, models, max_conc)
    db_path = Path(tempfile.gettempdir()) / f"nomyo_loadtest_{os.getpid()}.db"

    env = dict(os.environ)
    env["NOMYO_ROUTER_CONFIG_PATH"] = str(cfg_path)
    env["NOMYO_ROUTER_DB_PATH"] = str(db_path)

    mock_proc = None
    router_proc = None
    try:
        # 1. mock backend first, so the router never caches it as "down"
        mock_cmd = [
            sys.executable, str(Path(__file__).resolve()), "--serve-mock",
            "--mock-host", "127.0.0.1", "--mock-port", str(mock_port),
            "--mock-models", args.mock_models,
            "--mock-ttft-ms", str(args.mock_ttft_ms),
            "--mock-tokens", str(args.mock_tokens),
            "--mock-tok-ms", str(args.mock_tok_ms),
        ]
        print(f"[orchestrator] starting mock backend: {mock_url}", flush=True)
        mock_proc = await asyncio.create_subprocess_exec(*mock_cmd)
        if not await _wait_http_ok(f"{mock_url}/api/version", timeout=15):
            raise RuntimeError("mock backend did not become ready")

        # 2. router
        router_cmd = [
            sys.executable, "-m", "uvicorn", "router:app",
            "--host", "127.0.0.1", "--port", str(router_port),
            # Per-request access logging is pure noise (and overhead) under load.
            "--no-access-log",
        ]
        if args.router_workers and args.router_workers > 1:
            router_cmd += ["--workers", str(args.router_workers)]
        print(f"[orchestrator] starting router: {router_url} "
              f"(workers={args.router_workers}, max_concurrent_connections={max_conc})", flush=True)
        router_proc = await asyncio.create_subprocess_exec(
            *router_cmd, cwd=str(PROJECT_ROOT), env=env
        )
        # /health returns 200 only once it can reach the (healthy) mock backend
        if not await _wait_http_ok(f"{router_url}/health", timeout=40, accept=(200,)):
            raise RuntimeError("router did not become healthy")
        print("[orchestrator] router healthy — starting load\n", flush=True)

        # 3. drive load against the router
        args.url = router_url
        return await run_driver(args)

    finally:
        # Stop the router before the mock it depends on.
        for name, proc in (("router", router_proc), ("mock", mock_proc)):
            if proc and proc.returncode is None:
                with contextlib.suppress(ProcessLookupError):
                    proc.send_signal(signal.SIGINT)
                with contextlib.suppress(asyncio.TimeoutError):
                    await asyncio.wait_for(proc.wait(), timeout=8)
                if proc.returncode is None:
                    with contextlib.suppress(ProcessLookupError):
                        proc.kill()
                    # Reap the killed child so no unreaped process/transport
                    # outlives the event loop.
                    await proc.wait()
                print(f"[orchestrator] stopped {name}", flush=True)
        if args.keep_config:
            print(f"[orchestrator] kept config: {cfg_path}", flush=True)
        else:
            with contextlib.suppress(FileNotFoundError):
                cfg_path.unlink()
        # Remove the temporary DB and its -shm / -wal side files if present.
        for suffix in ("", "-shm", "-wal"):
            with contextlib.suppress(FileNotFoundError):
                Path(str(db_path) + suffix).unlink()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# CLI
|
||||
# ===========================================================================
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the load-test tool.

    Options are grouped into: mode, target (driver), load shape, mock backend
    tuning, and router orchestration. Streaming is on by default.
    """
    parser = argparse.ArgumentParser(
        description="NOMYO Router load test (asyncio + httpx) with a built-in mock backend.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    mode_opts = parser.add_argument_group("mode")
    mode_opts.add_argument(
        "--serve-mock", action="store_true",
        help="Run ONLY the mock backend (foreground) and exit on Ctrl-C.",
    )
    mode_opts.add_argument(
        "--mock-backend", action="store_true",
        help="Spawn mock backend + router, load-test them, then tear down.",
    )

    target_opts = parser.add_argument_group("target (driver)")
    target_opts.add_argument(
        "--url", default="http://127.0.0.1:12434",
        help="Router base URL to drive (ignored with --mock-backend).",
    )
    target_opts.add_argument(
        "--api", choices=["ollama", "openai"], default="ollama",
        help="ollama -> POST /api/chat ; openai -> POST /v1/chat/completions",
    )
    target_opts.add_argument("--model", default="mock", help="Model name to request.")
    target_opts.add_argument(
        "--prompt", default="Say hello and count to ten.",
        help="User prompt sent in every request.",
    )
    # --stream / --no-stream write to the same destination and exclude each other.
    stream_toggle = target_opts.add_mutually_exclusive_group()
    stream_toggle.add_argument(
        "--stream", dest="stream", action="store_true",
        help="Stream the response (default).",
    )
    stream_toggle.add_argument(
        "--no-stream", dest="stream", action="store_false",
        help="Request a single non-streamed response.",
    )
    parser.set_defaults(stream=True)

    load_opts = parser.add_argument_group("load shape")
    load_opts.add_argument(
        "--concurrency", type=int, default=32,
        help="Number of concurrent virtual clients.",
    )
    load_opts.add_argument(
        "--duration", type=float, default=20.0,
        help="Seconds to run each stage (ignored if --requests given).",
    )
    load_opts.add_argument(
        "--requests", type=int, default=None,
        help="Send exactly N requests instead of running for --duration.",
    )
    load_opts.add_argument(
        "--ramp", default=None,
        help="Comma-separated concurrency stages, e.g. 16,32,64,128 "
             "(prints a table to find the knee).",
    )
    load_opts.add_argument(
        "--warmup", type=float, default=2.0,
        help="Seconds of unmeasured warmup before each stage.",
    )
    load_opts.add_argument(
        "--timeout", type=float, default=120.0,
        help="Per-request timeout (seconds).",
    )
    load_opts.add_argument(
        "--json", dest="json_out", default=None,
        help="Also write the results as JSON to this path.",
    )

    mock_opts = parser.add_argument_group("mock backend tuning")
    mock_opts.add_argument("--mock-host", default="127.0.0.1")
    mock_opts.add_argument(
        "--mock-port", type=int, default=0,
        help="Mock backend port (0 = auto-pick a free port).",
    )
    mock_opts.add_argument(
        "--mock-models", default="mock",
        help="Comma-separated model names the mock advertises.",
    )
    mock_opts.add_argument(
        "--mock-ttft-ms", type=float, default=0.0,
        help="Simulated prefill latency before the first token (ms).",
    )
    mock_opts.add_argument(
        "--mock-tokens", type=int, default=64,
        help="Completion length in tokens the mock emits.",
    )
    mock_opts.add_argument(
        "--mock-tok-ms", type=float, default=0.0,
        help="Simulated per-token decode delay (ms) = inverse of tok/s.",
    )

    router_opts = parser.add_argument_group("router orchestration (--mock-backend only)")
    router_opts.add_argument(
        "--router-port", type=int, default=0,
        help="Router port (0 = auto-pick a free port).",
    )
    router_opts.add_argument(
        "--router-workers", type=int, default=1,
        help="uvicorn --workers for the spawned router.",
    )
    router_opts.add_argument(
        "--router-max-conc", type=int, default=0,
        help="max_concurrent_connections in the generated config "
             "(0 = match peak concurrency so the router does not queue).",
    )
    router_opts.add_argument(
        "--keep-config", action="store_true",
        help="Do not delete the generated temp config on exit.",
    )
    return parser
|
||||
|
||||
|
||||
def _dump_json(path: str, args, results: list[Stats]) -> None:
    """Write the run configuration and per-stage results to ``path`` as JSON.

    Latency and TTFT figures are in milliseconds and computed over successful
    samples only. When a stage has no successful samples, ``latency_ms`` is an
    empty object (matching ``ttft_ms``) rather than NaN percentiles, which
    ``json.dumps`` would serialize as bare ``NaN`` and break strict JSON
    parsers.

    Args:
        path: Destination file path.
        args: Parsed CLI options; a fixed subset is echoed under ``config``.
        results: One ``Stats`` per stage.
    """
    out = {
        "config": {k: getattr(args, k) for k in (
            "api", "model", "stream", "duration", "requests", "warmup", "timeout",
            "mock_tokens", "mock_ttft_ms", "mock_tok_ms")},
        "stages": [],
    }
    for st in results:
        lat = [s.latency * 1000 for s in st.ok]
        ttfts = [s.ttft * 1000 for s in st.ok if s.ttft is not None]
        if lat:
            latency_ms = {p: _pct(lat, p) for p in (50, 90, 95, 99)} | {
                "max": max(lat), "mean": statistics.mean(lat)}
        else:
            # No successful samples: emit nothing instead of NaN values.
            latency_ms = {}
        out["stages"].append({
            "concurrency": st.concurrency,
            "wall_s": st.wall,
            "requests": st.n_total,
            "ok": st.n_ok,
            "errors": st.n_total - st.n_ok,
            "rps": (st.n_ok / st.wall) if st.wall else 0.0,
            "latency_ms": latency_ms,
            "ttft_ms": {p: _pct(ttfts, p) for p in (50, 90, 99)} if ttfts else {},
        })
    Path(path).write_text(json.dumps(out, indent=2))
    print(f"\n[driver] wrote JSON results to {path}", flush=True)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments and dispatch to the selected mode.

    ``--serve-mock`` runs only the mock backend (Ctrl-C exits quietly).
    Otherwise the driver runs, either against ``--url`` or, with
    ``--mock-backend``, against a freshly spawned mock + router pair. Results
    are written as JSON when ``--json`` was given and there is at least one
    stage result. Exits with status 2 when ``--requests`` is not positive.
    """
    args = build_parser().parse_args()

    if args.serve_mock:
        with contextlib.suppress(KeyboardInterrupt):
            serve_mock(args)
        return

    bad_request_count = args.requests is not None and args.requests <= 0
    if bad_request_count:
        print("--requests must be > 0", file=sys.stderr)
        sys.exit(2)

    if not args.mock_backend:
        results = asyncio.run(run_driver(args))
    else:
        results = asyncio.run(run_with_mock_backend(args)) or []

    if args.json_out and results:
        _dump_json(args.json_out, args, results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
pytest>=8.0
|
||||
pytest-asyncio>=0.24
|
||||
pytest-cov>=5.0
|
||||
aioresponses>=0.7
|
||||
|
|
|
|||
|
|
@ -1,19 +1,11 @@
|
|||
"""Tests for fetch.available_models and fetch.loaded_models.
|
||||
|
||||
The backend probes obtain their HTTP client via ``backends.probe.get_probe_session``
|
||||
and only ever call ``async with client.get(url, headers=...) as resp``. We patch that
|
||||
seam with a tiny fake session instead of mocking aiohttp's internals (aioresponses),
|
||||
so the suite stays independent of aiohttp's private ClientResponse/ConnectionKey
|
||||
structure across version bumps.
|
||||
"""
|
||||
"""Tests for fetch.available_models and fetch.loaded_models using aioresponses mocking."""
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from aioresponses import aioresponses
|
||||
|
||||
import router
|
||||
import backends.probe as probe
|
||||
from conftest import TEST_OLLAMA, TEST_LLAMA
|
||||
|
||||
MOCK_OLLAMA_EP = "http://mock-ollama:11434"
|
||||
|
|
@ -30,73 +22,6 @@ def _make_cfg(ollama_eps=None, llama_eps=None, api_keys=None):
|
|||
return cfg
|
||||
|
||||
|
||||
# ── Fake probe session ────────────────────────────────────────────────────────
|
||||
|
||||
class _MockResponse:
|
||||
"""Minimal stand-in for the aiohttp response used by the probes."""
|
||||
|
||||
def __init__(self, *, status=200, payload=None, text=None):
|
||||
self.status = status
|
||||
self._payload = payload
|
||||
self._text = text if text is not None else ""
|
||||
|
||||
async def json(self):
|
||||
return self._payload
|
||||
|
||||
async def text(self):
|
||||
return self._text
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
return False
|
||||
|
||||
|
||||
class _RaisingCtx:
|
||||
"""``async with client.get(...)`` that raises on entry — mimics a failed connection."""
|
||||
|
||||
def __init__(self, exc):
|
||||
self._exc = exc
|
||||
|
||||
async def __aenter__(self):
|
||||
raise self._exc
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
return False
|
||||
|
||||
|
||||
class _MockProbeSession:
|
||||
"""Stand-in for the aiohttp ClientSession returned by ``get_probe_session``.
|
||||
|
||||
Routes are registered by exact URL via :meth:`add_get`. A registered exception
|
||||
is raised when the route is entered; otherwise a :class:`_MockResponse` is yielded.
|
||||
An unregistered GET fails loudly so tests can't silently pass on a wrong URL.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._routes = {}
|
||||
|
||||
def add_get(self, url, *, status=200, payload=None, text=None, exception=None):
|
||||
self._routes[url] = exception if exception is not None else _MockResponse(
|
||||
status=status, payload=payload, text=text
|
||||
)
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
if url not in self._routes:
|
||||
raise AssertionError(f"unexpected probe GET {url}")
|
||||
entry = self._routes[url]
|
||||
return _RaisingCtx(entry) if isinstance(entry, Exception) else entry
|
||||
|
||||
|
||||
@contextmanager
def mock_probe():
    """Swap the probe's session factory for one returning a fresh fake session.

    Yields the :class:`_MockProbeSession` so the test can register routes on it.
    """
    fake_session = _MockProbeSession()
    patcher = patch.object(probe, "get_probe_session", lambda endpoint: fake_session)
    with patcher:
        yield fake_session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_caches(aio_session):
    """Autouse hook that pulls the ``aio_session`` fixture into every test here.

    The function body is intentionally empty: per the original note, the
    ``aio_session`` fixture already clears caches and sets up app_state, so
    depending on it is all this fixture needs to do.
    """
|
||||
|
|
@ -106,8 +31,8 @@ def clear_caches(aio_session):
|
|||
class TestFetchAvailableModels:
|
||||
async def test_ollama_tags(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [
|
||||
{"name": "llama3.2:latest"},
|
||||
|
|
@ -119,8 +44,8 @@ class TestFetchAvailableModels:
|
|||
|
||||
async def test_openai_compatible_models_endpoint(self):
|
||||
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_LLAMA_EP}/models",
|
||||
payload={"data": [{"id": "unsloth/model:Q8_0"}]},
|
||||
)
|
||||
|
|
@ -129,8 +54,8 @@ class TestFetchAvailableModels:
|
|||
|
||||
async def test_caches_successful_result(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
|
|
@ -141,19 +66,20 @@ class TestFetchAvailableModels:
|
|||
|
||||
async def test_returns_empty_on_http_500(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(f"{MOCK_OLLAMA_EP}/api/tags", status=500, payload={"error": "oops"})
|
||||
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
assert models == set()
|
||||
|
||||
async def test_returns_empty_on_connection_error(self):
|
||||
import aiohttp
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
import aiohttp
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
exception=aiohttp.ClientConnectionError(
|
||||
"Cannot connect to host mock-ollama:11434 [Connection refused]"
|
||||
exception=aiohttp.ClientConnectorError(
|
||||
connection_key=MagicMock(host="mock-ollama", port=11434),
|
||||
os_error=OSError(111, "refused"),
|
||||
),
|
||||
)
|
||||
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
|
|
@ -161,8 +87,8 @@ class TestFetchAvailableModels:
|
|||
|
||||
async def test_stale_cache_returned_while_refresh_runs(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
|
|
@ -173,8 +99,8 @@ class TestFetchAvailableModels:
|
|||
models, _ = router._models_cache[MOCK_OLLAMA_EP]
|
||||
router._models_cache[MOCK_OLLAMA_EP] = (models, time.time() - 400)
|
||||
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/tags",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
|
|
@ -188,8 +114,8 @@ class TestFetchAvailableModels:
|
|||
async with router._available_error_cache_lock:
|
||||
router._available_error_cache[MOCK_OLLAMA_EP] = time.time()
|
||||
|
||||
with patch.object(router, "config", cfg), mock_probe():
|
||||
# No route registered — if a call happens it raises AssertionError
|
||||
with patch.object(router, "config", cfg), aioresponses():
|
||||
# No HTTP mock registered — if a call happens it will raise
|
||||
models = await router.fetch.available_models(MOCK_OLLAMA_EP)
|
||||
assert models == set()
|
||||
|
||||
|
|
@ -197,8 +123,8 @@ class TestFetchAvailableModels:
|
|||
class TestFetchLoadedModels:
|
||||
async def test_ollama_ps(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/ps",
|
||||
payload={"models": [{"name": "llama3.2:latest"}]},
|
||||
)
|
||||
|
|
@ -207,8 +133,8 @@ class TestFetchLoadedModels:
|
|||
|
||||
async def test_llama_server_filters_loaded(self):
|
||||
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_LLAMA_EP}/models",
|
||||
payload={"data": [
|
||||
{"id": "model-a", "status": {"value": "loaded"}},
|
||||
|
|
@ -220,8 +146,8 @@ class TestFetchLoadedModels:
|
|||
|
||||
async def test_llama_server_no_status_field_always_loaded(self):
|
||||
cfg = _make_cfg(llama_eps=[MOCK_LLAMA_EP])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_LLAMA_EP}/models",
|
||||
payload={"data": [{"id": "always-on-model"}]},
|
||||
)
|
||||
|
|
@ -230,8 +156,8 @@ class TestFetchLoadedModels:
|
|||
|
||||
async def test_returns_empty_on_error(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=503, payload={})
|
||||
models = await router.fetch.loaded_models(MOCK_OLLAMA_EP)
|
||||
assert models == set()
|
||||
|
||||
|
|
@ -244,8 +170,8 @@ class TestFetchLoadedModels:
|
|||
|
||||
async def test_caches_result(self):
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/ps",
|
||||
payload={"models": [{"name": "qwen:7b"}]},
|
||||
)
|
||||
|
|
@ -257,15 +183,15 @@ class TestFetchLoadedModels:
|
|||
# Regression: issue #83 — /api/ps failures must be recorded so
|
||||
# `choose_endpoint` can exclude unhealthy backends from routing.
|
||||
cfg = _make_cfg(ollama_eps=[MOCK_OLLAMA_EP], llama_eps=[])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={})
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(f"{MOCK_OLLAMA_EP}/api/ps", status=502, payload={})
|
||||
await router.fetch.loaded_models(MOCK_OLLAMA_EP)
|
||||
assert MOCK_OLLAMA_EP in router._loaded_error_cache
|
||||
|
||||
async def test_records_error_for_llama_server_on_failure(self):
|
||||
cfg = _make_cfg(ollama_eps=[], llama_eps=[MOCK_LLAMA_EP])
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(f"{MOCK_LLAMA_EP}/models", status=502, payload={})
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(f"{MOCK_LLAMA_EP}/models", status=502, payload={})
|
||||
await router.fetch.loaded_models(MOCK_LLAMA_EP)
|
||||
assert MOCK_LLAMA_EP in router._loaded_error_cache
|
||||
|
||||
|
|
@ -275,8 +201,8 @@ class TestFetchLoadedModels:
|
|||
# network probe instead of short-circuiting on the error cache.
|
||||
async with router._loaded_error_cache_lock:
|
||||
router._loaded_error_cache[MOCK_OLLAMA_EP] = time.time() - 301
|
||||
with patch.object(router, "config", cfg), mock_probe() as m:
|
||||
m.add_get(
|
||||
with patch.object(router, "config", cfg), aioresponses() as m:
|
||||
m.get(
|
||||
f"{MOCK_OLLAMA_EP}/api/ps",
|
||||
payload={"models": [{"name": "qwen:7b"}]},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import pytest
|
|||
from fastapi import HTTPException
|
||||
|
||||
import router
|
||||
from api import openai as api_openai
|
||||
|
||||
|
||||
_BYPASS = HTTPException(status_code=599, detail="bypassed")
|
||||
|
|
@ -48,8 +47,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
# Patch the route's references to both helpers — they're imported by name
|
||||
# into router's namespace at module load time.
|
||||
with (
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -71,8 +70,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -99,8 +98,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
"""When nomyo.cache=False, get_chat is never called even if a cache exists."""
|
||||
fake = _FakeCache(b"") # has a response, but should never be consulted
|
||||
with (
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -118,8 +117,8 @@ class TestOpenAIChatCompletionsCacheHit:
|
|||
async def test_no_cache_configured_bypasses_cache_check(self, client):
|
||||
"""get_llm_cache() returning None should not break the route."""
|
||||
with (
|
||||
patch.object(api_openai, "get_llm_cache", return_value=None),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
patch.object(router, "get_llm_cache", return_value=None),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=_BYPASS)),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -141,8 +140,8 @@ class TestOpenAICompletionsCacheHit:
|
|||
async def test_nonstream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
@ -164,8 +163,8 @@ class TestOpenAICompletionsCacheHit:
|
|||
async def test_stream_cache_hit(self, client, cache_hit_payload):
|
||||
fake = _FakeCache(cache_hit_payload)
|
||||
with (
|
||||
patch.object(api_openai, "get_llm_cache", return_value=fake),
|
||||
patch.object(api_openai, "choose_endpoint",
|
||||
patch.object(router, "get_llm_cache", return_value=fake),
|
||||
patch.object(router, "choose_endpoint",
|
||||
AsyncMock(side_effect=AssertionError("backend must not be reached"))),
|
||||
):
|
||||
resp = await client.post(
|
||||
|
|
|
|||
|
|
@ -1,151 +0,0 @@
|
|||
"""
|
||||
Unit tests for transient backend-error handling in the four Ollama-native
|
||||
streaming generators (``/api/generate``, ``/api/chat``, ``/api/embeddings``,
|
||||
``/api/embed``).
|
||||
|
||||
These reproduce the reported failure mode: a backend (nginx in front of ollama)
|
||||
returns a 504 Gateway Time-out *while the response is being streamed*, so the
|
||||
``ollama`` client raises ``ResponseError`` from inside the StreamingResponse
|
||||
generator. Before the fix this escaped as an opaque "Exception in ASGI
|
||||
application" traceback; now ``_handle_stream_error`` logs the endpoint/model and
|
||||
emits a terminal Ollama-format ``{"error": ..., "status_code": ...}`` line.
|
||||
|
||||
No real backend required — the ollama client and routing are mocked.
|
||||
"""
|
||||
import json
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import ollama
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from conftest import TEST_OLLAMA
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# ── Fakes ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class _Chunk:
    """Bare-bones stand-in for one Ollama-native streaming chunk.

    Exposes only the attributes set below plus ``model_dump_json``; every
    field keeps a neutral value (zero counters, not done, no content).
    """

    # Content fields: this chunk carries no message or response text.
    message = None
    response = None

    # Completion state.
    done = False
    done_reason = None

    # Token counters.
    prompt_eval_count = 0
    eval_count = 0

    def model_dump_json(self):
        """Return the fixed JSON line that represents this chunk on the wire."""
        return '{"model": "fake", "done": false}'
|
||||
|
||||
|
||||
def _one_then_raise(exc):
    """Build an async generator that emits a single ``_Chunk`` and then raises ``exc``.

    This models a backend that starts streaming successfully and fails afterwards.
    """
    async def _stream():
        first_chunk = _Chunk()
        yield first_chunk
        raise exc

    stream = _stream()
    return stream
|
||||
|
||||
|
||||
class _FakeAsyncClient:
    """Replacement for ``ollama.AsyncClient`` whose every call ends in ``exc``.

    ``chat`` and ``generate`` return a stream that delivers one chunk before
    raising (a mid-stream failure such as a 504); ``embeddings`` and ``embed``
    raise as soon as they are awaited.
    """

    def __init__(self, exc, *args, **kwargs):
        # Extra constructor arguments are accepted and ignored.
        self._exc = exc

    def _broken_stream(self):
        """Return a fresh one-chunk-then-fail stream for the stored exception."""
        return _one_then_raise(self._exc)

    async def chat(self, **kwargs):
        """Return a stream that fails after its first chunk."""
        return self._broken_stream()

    async def generate(self, **kwargs):
        """Return a stream that fails after its first chunk."""
        return self._broken_stream()

    async def embeddings(self, **kwargs):
        """Fail immediately with the stored exception."""
        raise self._exc

    async def embed(self, **kwargs):
        """Fail immediately with the stored exception."""
        raise self._exc
|
||||
|
||||
|
||||
def _patches(exc, mark_unhealthy):
    """Patch routing + the ollama client so the native path hits ``exc``.

    Args:
        exc: Exception every fake client call ends in.
        mark_unhealthy: Replacement for ``api.ollama._mark_backend_unhealthy``
            (the tests pass a mock so they can assert on it).

    Returns:
        An ``ExitStack`` holding all active patches; use it as a context
        manager (``with _patches(...):``) so the patches are undone on exit.
    """
    replacements = {
        "api.ollama.choose_endpoint": AsyncMock(return_value=(TEST_OLLAMA, "fake")),
        "api.ollama.is_openai_compatible": lambda ep: False,
        "api.ollama.decrement_usage": AsyncMock(),
        "api.ollama._mark_backend_unhealthy": mark_unhealthy,
        # The native path now fetches a cached client via get_ollama_client() rather
        # than constructing ollama.AsyncClient inline, so patch that seam.
        "api.ollama.get_ollama_client": lambda *a, **k: _FakeAsyncClient(exc),
    }
    # Enter the patches inside a ``with`` so that, if one of them fails to
    # apply, the ones already entered are rolled back instead of leaking into
    # later tests. On success ``pop_all()`` transfers ownership of the active
    # patches to a fresh stack that is handed to the caller.
    with ExitStack() as stack:
        for target, replacement in replacements.items():
            stack.enter_context(patch(target, replacement))
        return stack.pop_all()
|
||||
|
||||
|
||||
# Route → request payload: maps each Ollama-native path under test to the JSON
# body the tests POST to it. stream=True only matters for chat/generate.
_ROUTES = {
    "/api/chat": {"model": "fake", "stream": True, "messages": [{"role": "user", "content": "hi"}]},
    "/api/generate": {"model": "fake", "stream": True, "prompt": "hi"},
    "/api/embeddings": {"model": "fake", "prompt": "hi"},
    "/api/embed": {"model": "fake", "input": "hi"},
}
|
||||
|
||||
|
||||
def _last_json_line(text):
    """Decode the final non-blank line of a newline-delimited JSON body.

    Fails with ``AssertionError`` when ``text`` contains no non-blank line.
    """
    records = list(filter(str.strip, text.strip().split("\n")))
    assert records, "expected at least one ndjson line in the response body"
    final_record = records[-1]
    return json.loads(final_record)
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_504_surfaces_as_error_line(client, route, payload):
    """A backend 504 ends the body with an in-band {"error", "status_code"} record."""
    gateway_timeout = ollama.ResponseError("<html>504 Gateway Time-out</html>", 504)
    unhealthy_marker = AsyncMock()
    with _patches(gateway_timeout, unhealthy_marker):
        response = await client.post(route, json=payload)

    # The failure is reported inside the body, so the HTTP status itself is
    # still 200 (streaming had started, or the call was single-shot) rather
    # than a 5xx crash.
    assert response.status_code == 200
    final_record = _last_json_line(response.text)
    assert "error" in final_record
    assert "504" in final_record["error"]
    assert final_record["status_code"] == 504
    # A plain 504 is not a connection-class failure, so the endpoint must not
    # have been flagged unhealthy.
    unhealthy_marker.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("route, payload", list(_ROUTES.items()))
async def test_no_asgi_500_on_backend_failure(client, route, payload):
    """A backend error must be absorbed by the route, never surfacing as HTTP 500."""
    backend_failure = ollama.ResponseError("boom", 502)
    unhealthy_marker = AsyncMock()
    with _patches(backend_failure, unhealthy_marker):
        response = await client.post(route, json=payload)
    status = response.status_code
    assert status == 200
    # Kept alongside the check above to spell out the regression being guarded.
    assert status != 500
|
||||
|
||||
|
||||
async def test_connection_error_marks_backend_unhealthy(client):
    """A connection-class error during /api/chat streaming flags the backend unhealthy."""
    connection_failure = openai.APIConnectionError(
        request=httpx.Request("POST", "http://x")
    )
    unhealthy_marker = AsyncMock()
    chat_payload = _ROUTES["/api/chat"]
    with _patches(connection_failure, unhealthy_marker):
        response = await client.post("/api/chat", json=chat_payload)

    assert response.status_code == 200
    final_record = _last_json_line(response.text)
    assert "error" in final_record
    unhealthy_marker.assert_awaited_once()
    # The marker must have received the routed endpoint and model, in that order.
    positional = unhealthy_marker.await_args.args
    called_ep = positional[0]
    called_model = positional[1]
    assert called_ep == TEST_OLLAMA
    assert called_model == "fake"
|
||||
177
tokens.py
177
tokens.py
|
|
@ -1,177 +0,0 @@
|
|||
"""Token-count write-behind pipeline.
|
||||
|
||||
``token_worker`` drains ``token_queue`` into the in-memory buffer (and into
|
||||
``token_usage_counts`` for immediate SSE reporting). ``flush_buffer``
|
||||
periodically persists the buffer to SQLite via ``TokenDatabase``.
|
||||
``flush_remaining_buffers`` is invoked on shutdown to drain whatever is left.
|
||||
|
||||
The lock order is ``buffer_lock`` then ``token_usage_lock`` — see
|
||||
choose_endpoint for why we never combine them with usage_lock.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from state import (
|
||||
token_queue,
|
||||
token_buffer,
|
||||
time_series_buffer,
|
||||
buffer_lock,
|
||||
token_usage_counts,
|
||||
token_usage_lock,
|
||||
FLUSH_INTERVAL,
|
||||
)
|
||||
from sse import _capture_snapshot, _distribute_snapshot
|
||||
from db import get_db
|
||||
|
||||
|
||||
def _current_minute_timestamp() -> int:
    """Return the current UTC time, truncated to the whole minute, as a Unix timestamp."""
    now = datetime.now(tz=timezone.utc)
    return int(now.replace(second=0, microsecond=0).timestamp())


async def _record_token_usage(endpoint, model, prompt, comp) -> None:
    """Fold one queue item into the in-memory buffers and publish a snapshot.

    Adds ``prompt``/``comp`` to the running ``(prompt, completion)`` pair in
    ``token_buffer``, appends one row to ``time_series_buffer``, bumps
    ``token_usage_counts`` by the combined total, then distributes a freshly
    captured snapshot.
    """
    # Compute the timestamp before taking any lock.
    timestamp = _current_minute_timestamp()

    # Accumulate counts in the write-behind buffers (protected by buffer_lock).
    async with buffer_lock:
        prev_prompt, prev_comp = token_buffer[endpoint].get(model, (0, 0))
        token_buffer[endpoint][model] = (prev_prompt + prompt, prev_comp + comp)
        time_series_buffer.append({
            'endpoint': endpoint,
            'model': model,
            'input_tokens': prompt,
            'output_tokens': comp,
            'total_tokens': prompt + comp,
            'timestamp': timestamp
        })

    # Update the live counters and capture the snapshot under token_usage_lock;
    # distribution happens after the lock is released.
    async with token_usage_lock:
        token_usage_counts[endpoint][model] += (prompt + comp)
        snapshot = _capture_snapshot()
    await _distribute_snapshot(snapshot)


async def token_worker() -> None:
    """Consume ``token_queue`` forever, recording each usage item as it arrives.

    Each item is an ``(endpoint, model, prompt, comp)`` tuple. On cancellation
    the worker drains whatever is still queued, using the same recording path,
    and then re-raises ``asyncio.CancelledError``.
    """
    try:
        while True:
            endpoint, model, prompt, comp = await token_queue.get()
            await _record_token_usage(endpoint, model, prompt, comp)
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown: process any
        # remaining items in the queue before exiting.
        print("[token_worker] Task cancelled, processing remaining queue items...")
        while not token_queue.empty():
            try:
                endpoint, model, prompt, comp = token_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await _record_token_usage(endpoint, model, prompt, comp)
        print("[token_worker] Task cancelled, remaining items processed.")
        raise
|
||||
|
||||
|
||||
async def _drain_buffers():
    """Copy and clear both in-memory buffers while holding ``buffer_lock``.

    Returns:
        ``(buffer_copy, ts_copy)``. ``buffer_copy`` is a ``{endpoint: {model:
        value}}`` copy of ``token_buffer`` and ``ts_copy`` a list copy of
        ``time_series_buffer``; each is ``None`` when its buffer was empty.
    """
    async with buffer_lock:
        if token_buffer:
            buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
            token_buffer.clear()
        else:
            buffer_copy = None

        if time_series_buffer:
            ts_copy = list(time_series_buffer)
            time_series_buffer.clear()
        else:
            ts_copy = None
    return buffer_copy, ts_copy


async def _persist_buffers(buffer_copy, ts_copy) -> None:
    """Write drained copies to the database; ``None``/empty copies are skipped."""
    db = get_db()
    if buffer_copy:
        await db.update_batched_counts(buffer_copy)
    if ts_copy:
        await db.add_batched_time_series(ts_copy)


async def flush_buffer() -> None:
    """Periodically flush accumulated token counts to the database.

    Sleeps ``FLUSH_INTERVAL`` between rounds. On cancellation it attempts one
    last flush (errors from that attempt are printed, not raised) and then
    re-raises ``asyncio.CancelledError``.
    """
    try:
        while True:
            await asyncio.sleep(FLUSH_INTERVAL)

            # Snapshot under the lock, then do the DB work outside it so the
            # lock is not held during database writes.
            buffer_copy, ts_copy = await _drain_buffers()
            await _persist_buffers(buffer_copy, ts_copy)
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown.
        print("[flush_buffer] Task cancelled, flushing remaining buffers...")
        try:
            buffer_copy, ts_copy = await _drain_buffers()
            await _persist_buffers(buffer_copy, ts_copy)
            print("[flush_buffer] Task cancelled, remaining buffers flushed.")
        except Exception as e:
            print(f"[flush_buffer] Error during shutdown flush: {e}")
        raise


async def flush_remaining_buffers() -> None:
    """
    Flush any in-memory buffers to the database on shutdown.
    This is designed to be safely invoked during shutdown and should not raise.
    """
    try:
        buffer_copy, ts_copy = await _drain_buffers()

        # Count what was drained: one entry per (endpoint, model) pair plus
        # one per time-series row.
        flushed_entries = 0
        if buffer_copy:
            flushed_entries += sum(len(models) for models in buffer_copy.values())
        if ts_copy:
            flushed_entries += len(ts_copy)

        # Perform DB operations outside the lock
        await _persist_buffers(buffer_copy, ts_copy)

        if flushed_entries:
            print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
        else:
            print("[shutdown] No in-memory entries to flush on shutdown.")
    except Exception as e:
        # Do not raise during shutdown – log and continue teardown
        print(f"[shutdown] Error flushing remaining buffers: {e}")
|
||||
100256
vendor/tiktoken/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
vendored
100256
vendor/tiktoken/9b5ad71b2ce5302211f9c61530b329a4922fc6a4
vendored
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue