nomyo-router/router.py

3380 lines
153 KiB
Python
Raw Normal View History

2025-08-26 18:19:43 +02:00
"""
2026-03-05 11:09:20 +01:00
title: NOMYO Router - an (O)llama and OpenAI API v1 Proxy with Endpoint:Model aware routing
2025-08-26 18:19:43 +02:00
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
2026-05-13 14:59:05 +02:00
version: 0.9
2025-08-26 18:19:43 +02:00
license: AGPL
"""
# -------------------------------------------------------------
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx, hashlib
try:
import truststore; truststore.inject_into_ssl()
except ImportError:
pass
from datetime import datetime, timezone
2025-08-26 18:19:43 +02:00
from pathlib import Path
# Directory containing static files (relative to this script)
STATIC_DIR = Path(__file__).parent / "static"
2025-09-05 12:11:31 +02:00
from typing import Dict, Set, List, Optional
from urllib.parse import urlparse, parse_qsl, urlencode
2025-08-26 18:19:43 +02:00
from fastapi import FastAPI, Request, HTTPException
2025-09-05 12:11:31 +02:00
from fastapi_sse import sse_handler
from fastapi.staticfiles import StaticFiles
2025-09-11 09:46:19 +02:00
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
2025-08-26 18:19:43 +02:00
from pydantic import Field
from pydantic_settings import BaseSettings
from collections import defaultdict
from PIL import Image
2026-05-19 10:05:27 +02:00
from security import _mask_secrets
from context_window import (
_count_message_tokens,
_trim_messages_for_context,
_calibrated_trim_target,
_endpoint_nctx,
_CTX_TRIM_SMALL_LIMIT,
)
2026-05-19 11:18:06 +02:00
from state import (
_models_cache,
_loaded_models_cache,
_available_error_cache,
_loaded_error_cache,
_completion_error_cache,
_COMPLETION_ERROR_TTL,
_models_cache_lock,
_loaded_models_cache_lock,
_available_error_cache_lock,
_loaded_error_cache_lock,
_completion_error_cache_lock,
_inflight_available_models,
_inflight_loaded_models,
_inflight_lock,
_bg_refresh_available,
_bg_refresh_loaded,
_bg_refresh_lock,
_subscribers,
_subscribers_lock,
token_queue,
app_state,
token_buffer,
time_series_buffer,
buffer_lock,
FLUSH_INTERVAL,
)
2026-05-19 11:18:06 +02:00
# Rebound on startup — must stay in router.py module namespace.
token_worker_task: asyncio.Task | None = None
flush_task: asyncio.Task | None = None
2026-05-19 11:00:50 +02:00
from config import Config, _config_path_from_env
from ollama._types import TokenLogprob, Logprob
from db import TokenDatabase
2026-03-08 09:12:09 +01:00
from cache import init_llm_cache, get_llm_cache, openai_nonstream_to_sse
2026-05-19 12:05:51 +02:00
# Create the global config object; it will be overwritten on startup.
# Submodules read it lazily via config.get_config().
config = Config.from_yaml(_config_path_from_env())
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
# 2. FastAPI application
# -------------------------------------------------------------
app = FastAPI()
2025-09-05 12:11:31 +02:00
sse_handler.app = app
2025-09-11 09:46:19 +02:00
# Register Starlette's CORS middleware on the application.
# NOTE(review): wildcard origins are combined with allow_credentials=True here —
# confirm this combination is intended for deployments that use cookies/credentials.
app.add_middleware(
CORSMiddleware,
# Any origin may call the router.
allow_origins=["*"],
allow_credentials=True,
# Only these HTTP methods are advertised to browsers by this middleware.
allow_methods=["GET", "POST", "DELETE"],
# Request headers browsers are allowed to send cross-origin.
allow_headers=["Authorization", "Content-Type"],
)
2026-05-19 12:05:51 +02:00
from state import default_headers
# -------------------------------------------------------------
# Router-level authentication (optional)
# -------------------------------------------------------------
def _extract_router_api_key(request: Request) -> Optional[str]:
"""
Extract the provided router API key from the Authorization header or `api_key`
query parameter. The middleware uses this to gate access to API routes when
a router_api_key is configured.
"""
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
key = auth_header.split(" ", 1)[1].strip()
if key: # Ensure key is not empty
return key
query_key = request.query_params.get("api_key")
if query_key:
return query_key
return None
def _strip_api_key_from_scope(request: Request) -> None:
"""
Remove api_key from the ASGI scope query string to avoid leaking it in logs.
"""
scope = request.scope
raw_qs = scope.get("query_string", b"")
if not raw_qs:
return
params = parse_qsl(raw_qs.decode("utf-8"), keep_blank_values=True)
filtered = [(k, v) for (k, v) in params if k != "api_key"]
scope["query_string"] = urlencode(filtered).encode("utf-8")
# File extensions under /static/ that may be fetched without a router API key.
# HTML is deliberately excluded so /static/index.html cannot be used to bypass auth.
_STATIC_ASSET_EXTS = frozenset({
    ".css", ".js", ".ico", ".png", ".jpg", ".jpeg", ".svg",
    ".woff", ".woff2", ".ttf", ".map",
})

# CORS headers attached by the key-enforcement middleware to API responses.
_ROUTER_API_CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Headers": "Authorization, Content-Type",
    "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
}


def _router_keys_match(provided, expected) -> bool:
    """Constant-time comparison of two keys.

    Both values are compared as UTF-8 bytes: ``secrets.compare_digest`` raises
    ``TypeError`` for ``str`` arguments containing non-ASCII characters, which
    would turn a wrong key into a 500 instead of a 403.
    """
    return secrets.compare_digest(
        str(provided).encode("utf-8"),
        str(expected).encode("utf-8"),
    )


@app.middleware("http")
async def enforce_router_api_key(request: Request, call_next):
    """
    Enforce the optional NOMYO Router API key for all non-static requests.

    When `config.router_api_key` is set, clients must supply the key either in
    the Authorization header (`Bearer <key>`) or as `api_key` query parameter.

    Responses:
      * 401 JSON when no key is supplied,
      * 403 JSON when the supplied key does not match,
      * otherwise the downstream response, with CORS headers added for paths
        containing `/api/` (except `/api/usage-stream`).

    OPTIONS requests, `/`, `/favicon.ico` and whitelisted static asset types
    under `/static/` are passed through without a key.
    """
    expected_key = config.router_api_key
    if not expected_key or request.method == "OPTIONS":
        return await call_next(request)

    path = request.url.path

    # Allow static assets (CSS, JS, images, fonts) but NOT HTML pages,
    # which would bypass auth by accessing /static/index.html directly.
    # The trailing slash keeps look-alike prefixes such as "/staticfoo.js" gated.
    is_static_asset = (
        path.startswith("/static/")
        and Path(path).suffix.lower() in _STATIC_ASSET_EXTS
    )
    if is_static_asset or path in {"/", "/favicon.ico"}:
        return await call_next(request)

    provided_key = _extract_router_api_key(request)
    # Strip the api_key query param from scope so access logs do not leak it
    _strip_api_key_from_scope(request)

    wants_cors = "/api/" in path and path != "/api/usage-stream"

    if provided_key is None:
        # No key provided but authentication is required - return 401
        return JSONResponse(
            content={"detail": "Missing NOMYO Router API key"},
            status_code=401,
            headers=dict(_ROUTER_API_CORS_HEADERS) if wants_cors else {},
        )

    if not _router_keys_match(provided_key, expected_key):
        return JSONResponse(
            content={"detail": "Invalid NOMYO Router API key"},
            status_code=403,
        )

    response = await call_next(request)
    # Add CORS headers for authenticated API requests
    if wants_cors:
        for header_name, header_value in _ROUTER_API_CORS_HEADERS.items():
            response.headers[header_name] = header_value
    return response
@app.exception_handler(openai.APIStatusError)
async def _openai_api_status_error_handler(request: Request, exc: openai.APIStatusError):
    """Relay an upstream OpenAI-SDK status error to the client.

    The upstream status code is reused. The upstream body is forwarded when the
    exception carries one; otherwise a minimal ``{"error": {...}}`` payload is
    synthesized from the exception text and status code. Without this handler
    such errors would surface as generic 500 responses.
    """
    payload = exc.body
    if payload is None:
        payload = {"error": {"message": str(exc), "code": exc.status_code}}
    return JSONResponse(status_code=exc.status_code, content=payload)
2026-05-19 11:18:06 +02:00
from state import (
usage_counts,
token_usage_counts,
usage_lock,
token_usage_lock,
_affinity_map,
_affinity_lock,
_AFFINITY_MAX_ENTRIES,
)
2026-05-19 10:05:27 +02:00
from fingerprint import _conversation_fingerprint
# Database instance
db: "TokenDatabase" = None
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
2026-05-19 12:05:51 +02:00
# 4. Helper functions
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
2026-05-19 12:05:51 +02:00
from backends.normalize import (
_normalize_llama_model_name,
_extract_llama_quant,
ep2base,
dedupe_on_keys,
)
from backends.sessions import (
_is_unix_socket_endpoint,
_get_socket_path,
get_session,
_make_openai_client,
)
from backends.health import (
_is_fresh,
_ensure_success,
_format_connection_issue,
_is_backend_connection_error,
_mark_backend_unhealthy,
_is_llama_model_loaded,
_is_llama_model_loaded_or_sleeping,
)
2026-05-19 12:05:51 +02:00
from backends.normalize import (
is_ext_openai_endpoint,
is_openai_compatible,
get_tracking_model,
)
async def _record_token_usage(endpoint, model, prompt, comp) -> None:
    """Fold one ``(endpoint, model, prompt, comp)`` queue item into the in-memory stats.

    Updates the batched token buffer and the time-series buffer under
    ``buffer_lock``, bumps ``token_usage_counts`` under ``token_usage_lock``,
    then pushes a fresh snapshot to SSE subscribers outside of both locks.
    """
    # Time-series entries are bucketed per minute (UTC): drop seconds/microseconds.
    now = datetime.now(tz=timezone.utc)
    timestamp = int(now.replace(second=0, microsecond=0).timestamp())

    # Accumulate counts in memory buffer (protected by lock)
    async with buffer_lock:
        prev_prompt, prev_comp = token_buffer[endpoint].get(model, (0, 0))
        token_buffer[endpoint][model] = (prev_prompt + prompt, prev_comp + comp)
        # Add to time series buffer with timestamp (UTC)
        time_series_buffer.append({
            'endpoint': endpoint,
            'model': model,
            'input_tokens': prompt,
            'output_tokens': comp,
            'total_tokens': prompt + comp,
            'timestamp': timestamp
        })

    # Update in-memory counts for immediate reporting
    async with token_usage_lock:
        token_usage_counts[endpoint][model] += (prompt + comp)
        snapshot = _capture_snapshot()
    await _distribute_snapshot(snapshot)


async def token_worker() -> None:
    """Consume ``token_queue`` forever, recording each usage item.

    On cancellation the items still waiting in the queue are processed before
    the ``CancelledError`` is re-raised, so queued usage is not dropped at shutdown.
    """
    try:
        while True:
            endpoint, model, prompt, comp = await token_queue.get()
            await _record_token_usage(endpoint, model, prompt, comp)
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[token_worker] Task cancelled, processing remaining queue items...")
        # Process any remaining items in the queue before exiting
        while not token_queue.empty():
            try:
                endpoint, model, prompt, comp = token_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await _record_token_usage(endpoint, model, prompt, comp)
        print("[token_worker] Task cancelled, remaining items processed.")
        raise
async def _flush_buffers_once() -> None:
    """Move the buffered token counts and time-series rows into the database.

    The buffers are copied and cleared while holding ``buffer_lock``; the
    database writes happen after the lock is released so producers are not
    blocked on I/O.
    """
    async with buffer_lock:
        if token_buffer:
            # Copy buffer before releasing lock for DB operation
            buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
            token_buffer.clear()
        else:
            buffer_copy = None
        if time_series_buffer:
            ts_copy = list(time_series_buffer)
            time_series_buffer.clear()
        else:
            ts_copy = None

    # Perform DB operations outside the lock to avoid blocking
    if buffer_copy:
        await db.update_batched_counts(buffer_copy)
    if ts_copy:
        await db.add_batched_time_series(ts_copy)


async def flush_buffer() -> None:
    """Periodically flush accumulated token counts to the database.

    Runs until cancelled, flushing every ``FLUSH_INTERVAL`` seconds. A failing
    flush is logged and the loop keeps running, so one database error does not
    permanently stop persistence. On cancellation one last flush is attempted
    before the ``CancelledError`` is re-raised.
    """
    try:
        while True:
            await asyncio.sleep(FLUSH_INTERVAL)
            try:
                await _flush_buffers_once()
            except Exception as e:
                # CancelledError is a BaseException and still propagates from here.
                print(f"[flush_buffer] Error during periodic flush: {e}")
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[flush_buffer] Task cancelled, flushing remaining buffers...")
        # Flush any remaining data before exiting
        try:
            await _flush_buffers_once()
            print("[flush_buffer] Task cancelled, remaining buffers flushed.")
        except Exception as e:
            print(f"[flush_buffer] Error during shutdown flush: {e}")
        raise
async def flush_remaining_buffers() -> None:
    """
    Persist whatever is still held in the in-memory buffers at shutdown.

    Any exception is caught and printed rather than propagated, so calling
    this during teardown never aborts the shutdown sequence.
    """
    try:
        pending_counts = None
        pending_series = None
        entry_total = 0

        async with buffer_lock:
            if token_buffer:
                pending_counts = {ep: dict(per_model) for ep, per_model in token_buffer.items()}
                entry_total += sum(len(per_model) for per_model in token_buffer.values())
                token_buffer.clear()
            if time_series_buffer:
                pending_series = list(time_series_buffer)
                entry_total += len(time_series_buffer)
                time_series_buffer.clear()

        # Database writes are done after the lock has been released.
        if pending_counts:
            await db.update_batched_counts(pending_counts)
        if pending_series:
            await db.add_batched_time_series(pending_series)

        if entry_total:
            print(f"[shutdown] Flushed {entry_total} in-memory entries to DB on shutdown.")
        else:
            print("[shutdown] No in-memory entries to flush on shutdown.")
    except Exception as e:
        # Never raise during shutdown: log the problem and let teardown continue.
        print(f"[shutdown] Error flushing remaining buffers: {e}")
2026-05-19 12:05:51 +02:00
from backends.probe import fetch
2025-08-26 18:19:43 +02:00
async def increment_usage(endpoint: str, model: str) -> None:
    """Add one in-flight request for ``model`` on ``endpoint`` and notify subscribers.

    The counter update and snapshot capture happen under ``usage_lock``; the
    snapshot is distributed after the lock is released.
    """
    async with usage_lock:
        per_model = usage_counts[endpoint]
        per_model[model] = per_model[model] + 1
        payload = _capture_snapshot()
    await _distribute_snapshot(payload)
async def decrement_usage(endpoint: str, model: str) -> None:
    """Release one in-flight request for ``model`` on ``endpoint`` and notify subscribers.

    The counter never goes below zero, and a model whose counter reaches zero
    is removed from the endpoint's mapping. The (possibly empty) endpoint entry
    itself is kept.
    """
    async with usage_lock:
        per_model = usage_counts[endpoint]
        remaining = per_model.get(model, 0)
        if remaining > 0:
            # Only positive counters are decremented, so counts stay non-negative.
            remaining -= 1
            per_model[model] = remaining
        if remaining == 0:
            # Drop zeroed entries so snapshots only list active models.
            per_model.pop(model, None)
        payload = _capture_snapshot()
    await _distribute_snapshot(payload)
2025-09-05 12:11:31 +02:00
def _openai_usage_tokens(obj) -> tuple[int, int]:
    """Return ``(prompt_tokens, completion_tokens)`` for an OpenAI-style response or chunk.

    Uses ``obj.usage`` when it is set, otherwise falls back to
    ``rechunk.extract_usage_from_llama_timings``. Yields ``(0, 0)`` when
    neither source provides counts.
    """
    if obj.usage is not None:
        return obj.usage.prompt_tokens or 0, obj.usage.completion_tokens or 0
    llama_usage = rechunk.extract_usage_from_llama_timings(obj)
    if llama_usage:
        prompt_tok, comp_tok = llama_usage
        return prompt_tok, comp_tok
    return 0, 0


async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """
    Helper function to make a chat request to a specific endpoint.
    Handles endpoint selection, client creation, usage tracking, and request execution.

    The endpoint is chosen (and a usage slot reserved) via ``choose_endpoint``.
    Everything after that reservation runs inside ``try``/``finally`` so the
    slot is released through ``decrement_usage`` even when request preparation
    fails (for example when image conversion raises).

    For OpenAI-compatible endpoints the Ollama-style arguments are translated
    to chat-completions parameters and the reply is converted back with
    ``rechunk.openai_chat_completion2ollama``. Two upstream failures are retried:
    a context-size overflow (messages are trimmed, then tools are dropped if it
    still overflows) and "image input is not supported" (images are stripped).

    When ``stream`` is true the stream is fully consumed here and only the last
    chunk is returned. Token counts are pushed onto ``token_queue``.
    """
    endpoint, tracking_model = await choose_endpoint(model)  # selects and atomically reserves
    try:
        if not is_openai_compatible(endpoint):
            # ---- native Ollama endpoint ----
            client = ollama.AsyncClient(host=endpoint)
            response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
            if stream:
                # For streaming, collect all chunks
                chunks = []
                async for chunk in response:
                    chunks.append(chunk)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if chunks:
                    response = chunks[-1]
            else:
                prompt_tok = response.prompt_eval_count or 0
                comp_tok = response.eval_count or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
            return response

        # ---- OpenAI-compatible endpoint ----
        if ":latest" in model:
            model = model.split(":latest")[0]
        if messages:
            if any("images" in m for m in messages):
                # Image conversion is run in a worker thread.
                messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
            messages = transform_tool_calls_to_openai(messages)
            messages = _strip_assistant_prefill(messages)
        params = {
            "messages": messages,
            "model": model,
        }
        opts = options or {}
        optional_params = {
            "tools": tools,
            "stream": stream,
            "stream_options": {"include_usage": True} if stream else None,
            "max_tokens": opts.get("num_predict"),
            "frequency_penalty": opts.get("frequency_penalty"),
            "presence_penalty": opts.get("presence_penalty"),
            "seed": opts.get("seed"),
            "stop": opts.get("stop"),
            "top_p": opts.get("top_p"),
            "temperature": opts.get("temperature"),
            "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
        }
        # Only forward parameters that were actually provided.
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))

        start_ts = time.perf_counter()
        try:
            response = await oclient.chat.completions.create(**params)
        except Exception as e:
            _e_str = str(e)
            print(f"[_make_chat_request] caught {type(e).__name__}: {_e_str[:200]}")
            if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
                # Prefer the structured error body; fall back to parsing the message text.
                err_body = getattr(e, "body", {}) or {}
                err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
                n_ctx_limit = err_detail.get("n_ctx", 0)
                actual_tokens = err_detail.get("n_prompt_tokens", 0)
                if not n_ctx_limit:
                    _m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
                    if _m:
                        n_ctx_limit = int(_m.group(1))
                    _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                    if _m:
                        actual_tokens = int(_m.group(1))
                if not n_ctx_limit:
                    # Without a known context limit there is nothing to trim against.
                    raise
                msgs_to_trim = params.get("messages", [])
                cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                print(f"[_make_chat_request] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
                try:
                    response = await oclient.chat.completions.create(**{**params, "messages": trimmed})
                except Exception as e2:
                    if "exceed_context_size_error" in str(e2) or "exceeds the available context size" in str(e2):
                        print("[_make_chat_request] Context still exceeded after trimming, also stripping tools")
                        params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                        response = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
                    else:
                        raise
            elif "image input is not supported" in _e_str:
                print(f"[_make_chat_request] Model {model} doesn't support images, retrying with text-only messages")
                params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
                response = await oclient.chat.completions.create(**params)
            else:
                raise

        if stream:
            # For streaming, we need to collect all chunks
            chunks = []
            tc_acc = {}  # accumulate tool-call deltas
            async for chunk in response:
                chunks.append(chunk)
                _accumulate_openai_tc_delta(chunk, tc_acc)
                prompt_tok, comp_tok = _openai_usage_tokens(chunk)
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
            # Convert to Ollama format
            if chunks:
                response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
                # Inject fully-accumulated tool calls into the final response
                if tc_acc and response.message:
                    response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
        else:
            prompt_tok, comp_tok = _openai_usage_tokens(response)
            if prompt_tok != 0 or comp_tok != 0:
                await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
            response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
        return response
    finally:
        # Always release the slot reserved by choose_endpoint.
        await decrement_usage(endpoint, tracking_model)
def get_last_user_content(messages):
    """
    Return the 'content' of the most recent message whose 'role' is 'user'.

    `messages` is a list of dicts (e.g. chat messages from an API request).
    None is returned when there is no user message, or when the most recent
    user message has no 'content' key.
    """
    # Scan from the end; the first user message found is the latest one.
    return next(
        (entry.get("content") for entry in reversed(messages) if entry.get("role") == "user"),
        None,
    )
async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """
    Helper function to make MOE (Multiple Opinions Ensemble) requests.
    Generates 3 responses, 3 critiques, and returns the final selected response.

    The temperature is forced to 1 for every sub-request. The caller's
    ``options`` dict is copied rather than modified in place.

    Raises:
        ValueError: if ``messages`` contains no user message with content.
    """
    query = get_last_user_content(messages)
    if not query:
        raise ValueError("No user query found in messages")

    # Work on a copy so the caller's options are not mutated.
    options = {**(options or {}), "temperature": 1}

    def _ask(msgs: list):
        # All sub-requests share the same settings and are never streamed.
        return _make_chat_request(model, msgs, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)

    # Generate 3 responses — choose_endpoint is called inside _make_chat_request and
    # atomically reserves a slot, so all 3 tasks see each other's load immediately.
    responses = await asyncio.gather(*(_ask(messages) for _ in range(3)))

    moe_reqs = [enhance.moe(query, n, r.message.content) for n, r in enumerate(responses)]

    # Generate 3 critiques
    critiques = await asyncio.gather(
        *(_ask([{"role": "user", "content": moe_req}]) for moe_req in moe_reqs)
    )

    # Select final response
    m = enhance.moe_select_candidate(query, critiques)

    # Generate final response
    return await _ask([{"role": "user", "content": m}])
2026-05-19 10:05:27 +02:00
from images import iso8601_ns, is_base64, resize_image_if_needed
def _strip_assistant_prefill(messages: list) -> list:
"""Remove a trailing assistant message used as prefill.
OpenAI-compatible endpoints (including Claude) do not support prefill and
will reject requests where the last message has role 'assistant'."""
if messages and messages[-1].get("role") == "assistant":
return messages[:-1]
return messages
def transform_tool_calls_to_openai(message_list):
    """
    Ensure tool_calls in assistant messages conform to the OpenAI format:
    - Each tool call must have "type": "function"
    - Each tool call must have an "id"
    - arguments must be a JSON string, not a dict
    Also ensure tool-role messages have a tool_call_id.

    Messages are modified in place and the same list is returned.

    Tool-role messages lacking a ``tool_call_id`` are matched by function name
    ("name" or "tool_name") against the preceding assistant tool calls. Ids are
    handed out in call order, so several calls to the same function within one
    assistant turn each receive their own id.
    """
    # name -> ids (in call order) not yet claimed by a tool-role message
    pending_ids = {}

    for msg in message_list:
        role = msg.get("role")
        if role == "assistant" and "tool_calls" in msg:
            turn_ids = {}
            # "tool_calls" may be present but null; treat that as no calls.
            for tc in msg["tool_calls"] or []:
                if "type" not in tc:
                    tc["type"] = "function"
                if not tc.get("id"):
                    tc["id"] = f"call_{secrets.token_hex(16)}"
                func = tc.get("function") or {}
                if isinstance(func.get("arguments"), dict):
                    func["arguments"] = orjson.dumps(func["arguments"]).decode("utf-8")
                # Remember the id for the following tool-role message(s)
                name = func.get("name")
                if name:
                    turn_ids.setdefault(name, []).append(tc["id"])
            # A newer turn replaces any unclaimed ids left over for the same names.
            pending_ids.update(turn_ids)
        elif role == "tool" and "tool_call_id" not in msg:
            # Try to match by name from a preceding assistant tool_call
            name = msg.get("name") or msg.get("tool_name")
            queue = pending_ids.get(name) if name else None
            if queue:
                msg["tool_call_id"] = queue.pop(0)
    return message_list
2025-09-23 17:33:15 +02:00
def transform_images_to_data_urls(message_list):
    """Convert Ollama-style ``images`` entries into OpenAI ``image_url`` content parts.

    Each message carrying an ``images`` list has that key removed and its
    ``content`` replaced by a list of parts: the existing text (or existing
    parts) first, followed by one ``image_url`` part per image, encoded as a
    ``data:image/png;base64,...`` URL. Messages are modified in place and the
    same list is returned.

    A message whose ``images`` value is not a list, or that yields no usable
    image, keeps its original ``content``.

    Raises:
        ValueError: if an image entry is not a valid base64 string.
    """
    for message in message_list:
        if "images" not in message:
            continue
        images = message.pop("images")
        if not isinstance(images, list):
            continue

        image_parts = []
        for image in images:  # TODO: quality downsize if images are too big to fit into model context window size
            if not is_base64(image):
                raise ValueError("Image string is not a valid base64 encoded string.")
            resized_image = resize_image_if_needed(image)
            if resized_image:
                image_parts.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{resized_image}"
                    }
                })

        if not image_parts:
            # Nothing to attach: leave the message text as it was.
            continue

        # Keep the prompt text alongside the images instead of discarding it.
        existing = message.get("content")
        parts = []
        if isinstance(existing, str):
            if existing:
                parts.append({"type": "text", "text": existing})
        elif isinstance(existing, list):
            parts.extend(existing)
        parts.extend(image_parts)
        message["content"] = parts
    return message_list
def _strip_images_from_messages(messages: list) -> list:
"""Remove image_url parts from message content, keeping only text."""
result = []
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
text_only = [p for p in content if p.get("type") != "image_url"]
if len(text_only) == 1 and text_only[0].get("type") == "text":
content = text_only[0]["text"]
else:
content = text_only
result.append({**msg, "content": content})
else:
result.append(msg)
return result
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
``accumulator`` is a dict mapping tool-call *index* to
``{"id": str, "name": str, "arguments": str}`` where ``arguments``
is the concatenation of all JSON fragments seen so far.
"""
if not chunk.choices:
return
delta = chunk.choices[0].delta
tc_deltas = getattr(delta, "tool_calls", None)
if not tc_deltas:
return
for tc in tc_deltas:
idx = tc.index
if idx not in accumulator:
accumulator[idx] = {
"id": getattr(tc, "id", None) or f"call_{secrets.token_hex(16)}",
"name": tc.function.name if tc.function else None,
"arguments": "",
}
else:
if getattr(tc, "id", None):
accumulator[idx]["id"] = tc.id
if tc.function and tc.function.name:
accumulator[idx]["name"] = tc.function.name
if tc.function and tc.function.arguments:
accumulator[idx]["arguments"] += tc.function.arguments
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
"""Convert accumulated tool-call data into Ollama-format tool_calls list."""
if not accumulator:
return None
result = []
for idx in sorted(accumulator.keys()):
tc = accumulator[idx]
try:
args = orjson.loads(tc["arguments"]) if tc["arguments"] else {}
except (orjson.JSONDecodeError, TypeError):
args = {}
result.append(ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=args)
))
return result
def _convert_openai_logprobs(choice) -> list | None:
"""Convert OpenAI logprobs from a choice into Ollama Logprob objects."""
lp = getattr(choice, "logprobs", None)
if lp is None:
return None
content = getattr(lp, "content", None)
if not content:
return None
result = []
for entry in content:
top = [
TokenLogprob(token=alt.token, logprob=alt.logprob)
for alt in (entry.top_logprobs or [])
]
result.append(Logprob(
token=entry.token,
logprob=entry.logprob,
top_logprobs=top or None,
))
return result
class rechunk:
def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
now = time.perf_counter()
2025-09-23 12:51:37 +02:00
if chunk.choices == [] and chunk.usage is not None:
return ollama.ChatResponse(
model=chunk.model,
created_at=iso8601_ns(),
done=True,
done_reason='stop',
total_duration=int((now - start_ts) * 1_000_000_000),
load_duration=100000,
2025-09-23 12:51:37 +02:00
prompt_eval_count=int(chunk.usage.prompt_tokens),
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)),
2025-09-23 12:51:37 +02:00
eval_count=int(chunk.usage.completion_tokens),
eval_duration=int((now - start_ts) * 1_000_000_000),
message=ollama.Message(role="assistant", content=""),
2025-09-23 12:51:37 +02:00
)
with_thinking = chunk.choices[0] if chunk.choices[0] else None
if stream == True:
thinking = (getattr(with_thinking.delta, "reasoning_content", None) or getattr(with_thinking.delta, "reasoning", None)) if with_thinking else None
2025-09-21 16:33:43 +02:00
role = chunk.choices[0].delta.role or "assistant"
2025-09-23 12:51:37 +02:00
content = chunk.choices[0].delta.content or ''
else:
thinking = (getattr(with_thinking.message, "reasoning_content", None) or getattr(with_thinking.message, "reasoning", None)) if with_thinking else None
2025-09-21 16:33:43 +02:00
role = chunk.choices[0].message.role or "assistant"
2025-09-23 12:51:37 +02:00
content = chunk.choices[0].message.content or ''
# Convert OpenAI tool_calls to Ollama format
# In streaming mode, tool_calls arrive as partial deltas across multiple chunks
# (name only in first delta, arguments as incremental JSON fragments).
# Callers must accumulate deltas and inject the final result; skip here.
ollama_tool_calls = None
if not stream:
raw_tool_calls = getattr(with_thinking.message, "tool_calls", None) if with_thinking else None
if raw_tool_calls:
ollama_tool_calls = []
for tc in raw_tool_calls:
try:
args = orjson.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else (tc.function.arguments or {})
except (orjson.JSONDecodeError, TypeError):
args = {}
ollama_tool_calls.append(ollama.Message.ToolCall(
function=ollama.Message.ToolCall.Function(name=tc.function.name, arguments=args)
))
# Convert OpenAI logprobs to Ollama format
ollama_logprobs = _convert_openai_logprobs(with_thinking) if with_thinking else None
2025-09-21 16:33:43 +02:00
assistant_msg = ollama.Message(
role=role,
content=content,
thinking=thinking,
2025-09-21 16:33:43 +02:00
images=None,
tool_name=None,
tool_calls=ollama_tool_calls)
2025-09-21 16:33:43 +02:00
rechunk = ollama.ChatResponse(
model=chunk.model,
2025-09-21 16:33:43 +02:00
created_at=iso8601_ns(),
2025-09-23 12:51:37 +02:00
done=True if chunk.usage is not None else False,
done_reason=chunk.choices[0].finish_reason, #if chunk.choices[0].finish_reason is not None else None,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=100000,
2025-09-22 19:01:14 +02:00
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
2025-09-22 19:01:14 +02:00
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
message=assistant_msg,
logprobs=ollama_logprobs)
2025-09-13 12:38:13 +02:00
return rechunk
def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
now = time.perf_counter()
with_thinking = chunk.choices[0] if chunk.choices[0] else None
thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
2025-09-21 16:33:43 +02:00
rechunk = ollama.GenerateResponse(
model=chunk.model,
created_at=iso8601_ns(),
2025-09-23 12:51:37 +02:00
done=True if chunk.usage is not None else False,
2025-09-21 16:33:43 +02:00
done_reason=chunk.choices[0].finish_reason,
total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
load_duration=10000,
2025-09-23 12:51:37 +02:00
prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
2025-09-23 12:51:37 +02:00
eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
2025-09-23 12:51:37 +02:00
response=chunk.choices[0].text or '',
thinking=thinking)
return rechunk
def openai_embeddings2ollama(chunk: dict) -> ollama.EmbeddingsResponse:
rechunk = ollama.EmbeddingsResponse(embedding=chunk.data[0].embedding)
return rechunk
def openai_embed2ollama(chunk: dict, model: str) -> ollama.EmbedResponse:
    """Convert an OpenAI embeddings response to an Ollama ``EmbedResponse``.

    Args:
        chunk: OpenAI embeddings response; ``chunk.data`` is a list of items
            each exposing ``.embedding``.
        model: Model name to report back to the client.

    Returns:
        An ``ollama.EmbedResponse`` whose ``embeddings`` holds one vector per
        item in ``chunk.data``, in the order received. Timing and token
        fields are left as ``None``.
    """
    # Return every vector, not just data[0]: a batch request yields one entry
    # per input and dropping the rest silently loses results.
    vectors = [item.embedding for item in chunk.data]
    rechunk = ollama.EmbedResponse(
        model=model,
        created_at=iso8601_ns(),
        done=None,
        done_reason=None,
        total_duration=None,
        load_duration=None,
        prompt_eval_count=None,
        prompt_eval_duration=None,
        eval_count=None,
        eval_duration=None,
        embeddings=vectors)
    return rechunk
def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
"""Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
llama-server returns a ``timings`` dict instead of the standard OpenAI
``usage`` field::
"timings": {
"cache_n": 236, // prompt tokens reused from cache
"prompt_n": 1, // prompt tokens processed
"predicted_n": 35 // predicted (completion) tokens
}
prompt_tokens = prompt_n + cache_n
completion_tokens = predicted_n
Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
timings are found.
"""
timings = getattr(obj, "timings", None)
if timings is None:
return None
if isinstance(timings, dict):
prompt_n = timings.get("prompt_n", 0) or 0
cache_n = timings.get("cache_n", 0) or 0
predicted_n = timings.get("predicted_n", 0) or 0
return (prompt_n + cache_n, predicted_n)
return None
2025-09-05 12:11:31 +02:00
# ------------------------------------------------------------------
# SSE Helpers
# ------------------------------------------------------------------
def _capture_snapshot() -> str:
    """Serialise the current usage counters to a JSON string.

    Caller must hold at least one of usage_lock/token_usage_lock.
    Keys are sorted (``OPT_SORT_KEYS``).
    """
    state = {
        "usage_counts": dict(usage_counts),
        "token_usage_counts": dict(token_usage_counts),
    }
    encoded = orjson.dumps(state, option=orjson.OPT_SORT_KEYS)
    return encoded.decode("utf-8")
async def _distribute_snapshot(snapshot: str) -> None:
    """Push a pre-captured snapshot to all SSE subscribers.

    Must be called outside any usage lock. When a subscriber's queue is full
    the oldest queued snapshot is dropped to make room for the new one.

    Args:
        snapshot: JSON string produced by ``_capture_snapshot``.
    """
    async with _subscribers_lock:
        for q in _subscribers:
            if q.full():
                # Use the non-blocking variant: ``await q.get()`` never raises
                # QueueEmpty (it waits instead), which would stall every
                # subscriber while the lock is held.
                try:
                    q.get_nowait()
                except asyncio.QueueEmpty:
                    pass
            # A slot is guaranteed free here: either the queue was not full,
            # or one item was just removed without yielding to the loop.
            q.put_nowait(snapshot)
async def close_all_sse_queues():
    """Send the ``None`` shutdown sentinel to every SSE subscriber queue.

    Never blocks: the queues are bounded, so if one is full (stalled or
    vanished consumer) the oldest snapshot is dropped to make room. Awaiting
    ``put`` on a full queue could otherwise hang shutdown indefinitely.
    """
    # Iterate over a copy so subscribers may unsubscribe concurrently.
    for q in list(_subscribers):
        if q.full():
            try:
                q.get_nowait()
            except asyncio.QueueEmpty:
                pass
        # sentinel value that the generator will recognise
        q.put_nowait(None)
2025-09-05 12:11:31 +02:00
# ------------------------------------------------------------------
# Subscriber helpers
# ------------------------------------------------------------------
async def subscribe() -> asyncio.Queue:
    """
    Register and return a new bounded queue (10 entries) that will receive
    every snapshot.
    """
    queue: asyncio.Queue = asyncio.Queue(maxsize=10)
    async with _subscribers_lock:
        _subscribers.add(queue)
        return queue
async def unsubscribe(q: asyncio.Queue):
    """Remove ``q`` from the subscriber set; unknown queues are ignored."""
    async with _subscribers_lock:
        # discard() (unlike remove()) does not raise for a missing element
        _subscribers.discard(q)
# ------------------------------------------------------------------
# Convenience wrapper – returns the current snapshot (for the proxy)
# ------------------------------------------------------------------
async def get_usage_counts() -> Dict:
    """Return a shallow copy of the global ``usage_counts`` mapping.

    Only the outer mapping is copied; the values are the same objects held
    by ``usage_counts``.
    """
    counts_copy = dict(usage_counts)
    return counts_copy
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
# 5. Endpoint selection logic (respecting the configurable limit)
# -------------------------------------------------------------
2026-04-22 17:27:34 +02:00
def get_max_connections(ep: str) -> int:
    """Return the connection limit for ``ep``.

    Uses the endpoint's own ``max_concurrent_connections`` entry from
    ``config.endpoint_config`` when present, otherwise the global
    ``config.max_concurrent_connections``.
    """
    global_limit = config.max_concurrent_connections
    per_endpoint = config.endpoint_config.get(ep, {})
    return per_endpoint.get("max_concurrent_connections", global_limit)
async def choose_endpoint(model: str, reserve: bool = True,
                          affinity_key: Optional[str] = None) -> tuple[str, str]:
    """
    Determine which endpoint to use for the given model while respecting
    the `max_concurrent_connections` per endpoint-model pair **and**
    ensuring that the chosen endpoint actually *advertises* the model.

    The selection algorithm:

    1.   Query every endpoint for its advertised models.
    2.   Build a list of endpoints that contain the requested model.
    2.5  If conversation affinity is enabled and the caller passes
         ``affinity_key``, prefer the endpoint that previously served the
         same conversation, but only when it still has the model loaded
         and a free slot. Otherwise fall through to the standard logic.
    3.   For those endpoints, find those that have the model loaded
         *and* still have a free slot.
    4.   If none are both loaded and free, fall back to any endpoint
         from the filtered list that simply has a free slot.
    5.   If all are saturated, pick the least-busy endpoint from the
         filtered list (the request will queue on that endpoint).
    6.   If no endpoint advertises the model at all, raise an error.

    Args:
        model: Model name as requested by the client.
        reserve: When true, increment the usage counter for the selected
            (endpoint, tracking model) pair and publish a snapshot.
        affinity_key: Optional conversation fingerprint for sticky routing.

    Returns:
        ``(endpoint, tracking_model)`` where ``tracking_model`` is the
        normalised name used as key in ``usage_counts``.

    Raises:
        RuntimeError: No configured endpoint advertises the model.
    """
    # Step 1: gather advertised-model sets for all endpoints concurrently.
    # Include both config.endpoints and config.llama_server_endpoints.
    llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    all_endpoints = config.endpoints + llama_eps_extra
    # The tasks MUST be created in exactly the order of ``all_endpoints``:
    # asyncio.gather returns results in awaitable order and they are zipped
    # with ``all_endpoints`` below. Grouping by endpoint type would attribute
    # model sets to the wrong endpoint when the types are interleaved.
    tag_tasks = [
        fetch.available_models(ep, config.api_keys.get(ep))
        if is_openai_compatible(ep)
        else fetch.available_models(ep)
        for ep in config.endpoints
    ]
    tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra]
    advertised_sets = await asyncio.gather(*tag_tasks)

    # Step 2: filter endpoints that advertise the requested model
    candidate_endpoints = [
        ep for ep, models in zip(all_endpoints, advertised_sets)
        if model in models
    ]

    # Step 6: name fallbacks, then give up
    if not candidate_endpoints:
        if ":latest" in model:  # ollama naming convention not applicable to openai/llama-server
            model_without_latest = model.split(":latest")[0]
            candidate_endpoints = [
                ep for ep, models in zip(all_endpoints, advertised_sets)
                if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints)
            ]
        if not candidate_endpoints:
            # Only add :latest suffix if model doesn't already have a version suffix
            if ":" not in model:
                model = model + ":latest"
                candidate_endpoints = [
                    ep for ep, models in zip(all_endpoints, advertised_sets)
                    if model in models
                ]
        if not candidate_endpoints:
            raise RuntimeError(
                f"None of the configured endpoints ({', '.join(all_endpoints)}) "
                f"advertise the model '{model}'."
            )

    # Step 3: among the candidates, find those that have the model *loaded*
    # (concurrently, but only for the filtered list)
    load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
    loaded_sets = await asyncio.gather(*load_tasks)

    # Step 3.5: exclude endpoints whose loaded-model probe has been failing
    # recently. Without this filter, an endpoint where `/api/ps` returns 5xx
    # would appear with an empty loaded set but pass through to the
    # free-slot fallback (step 4), sending completion calls to an
    # unhealthy backend. See issue #83.
    async with _loaded_error_cache_lock:
        unhealthy = {
            ep for ep, ts in _loaded_error_cache.items()
            if _is_fresh(ts, 300)
        }
    if unhealthy:
        filtered = [
            (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
            if ep not in unhealthy
        ]
        if filtered:
            candidate_endpoints = [ep for ep, _ in filtered]
            loaded_sets = [models for _, models in filtered]
        # If *every* candidate is unhealthy we still fall through with the
        # original list: refusing to route is worse than retrying a
        # possibly-recovered backend.

    # Step 3.6: exclude (endpoint, model) pairs whose completion path has recently
    # failed with a backend connection error (e.g. llama-server in router mode
    # whose delegated worker for *this* model died). /v1/models keeps reporting
    # OK in that case, so the probe-level filter above cannot catch it.
    async with _completion_error_cache_lock:
        completion_broken = {
            ep for (ep, m), ts in _completion_error_cache.items()
            if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL)
        }
    if completion_broken:
        filtered = [
            (ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
            if ep not in completion_broken
        ]
        if filtered:
            candidate_endpoints = [ep for ep, _ in filtered]
            loaded_sets = [models for _, models in filtered]
        # Same fallback: if every candidate is broken for this model, fall
        # through and let the upstream retry; possibly the operator restarted
        # the dead worker.

    # Look up a possible affinity hint *before* taking usage_lock. The two
    # locks are never held together to avoid lock-ordering issues.
    affine_ep: Optional[str] = None
    if config.conversation_affinity and affinity_key:
        async with _affinity_lock:
            entry = _affinity_map.get(affinity_key)
            if entry is not None:
                ep, _stored_model, expires_at = entry
                if expires_at < time.monotonic():
                    _affinity_map.pop(affinity_key, None)
                else:
                    affine_ep = ep

    # Protect all reads/writes of usage_counts with the lock so that selection
    # and reservation are atomic: concurrent callers see each other's pending load.
    async with usage_lock:
        # Helper: current usage for (endpoint, model) using the same normalized key
        # that increment_usage/decrement_usage store; raw model names differ from
        # tracking names for llama-server (HF prefix / quant suffix stripped).
        def tracking_usage(ep: str) -> int:
            return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)

        def utilization_ratio(ep: str) -> float:
            limit = get_max_connections(ep)
            if limit <= 0:
                # A non-positive limit means "no capacity"; rank it last
                # instead of raising ZeroDivisionError.
                return math.inf
            return tracking_usage(ep) / limit

        # Priority map: position in all_endpoints list (lower = higher priority)
        ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
        selected: Optional[str] = None

        # Step 2.5: conversation affinity preference. Only honour the hint when
        # the affine endpoint still advertises the model loaded *and* has a
        # free slot. Otherwise fall back to the standard algorithm.
        if affine_ep:
            ep_loaded = {
                ep: set(models)
                for ep, models in zip(candidate_endpoints, loaded_sets)
            }
            if (affine_ep in candidate_endpoints
                    and model in ep_loaded.get(affine_ep, set())
                    and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
                selected = affine_ep

        if selected is None:
            # Step 3: endpoints that have the model loaded *and* a free slot
            loaded_and_free = [
                ep for ep, models in zip(candidate_endpoints, loaded_sets)
                if model in models and tracking_usage(ep) < get_max_connections(ep)
            ]
            if loaded_and_free:
                if config.priority_routing:
                    # WRR: sort by config order first (stable), then by utilization ratio.
                    # Stable sort preserves priority for equal-ratio endpoints.
                    loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
                    loaded_and_free.sort(key=utilization_ratio)
                    selected = loaded_and_free[0]
                else:
                    # Sort ascending for load balancing. All endpoints here already have the
                    # model loaded, so there is no model-switching cost to optimise for.
                    loaded_and_free.sort(key=tracking_usage)
                    # When all candidates are equally idle, randomise to avoid always picking
                    # the first entry in a stable sort.
                    if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
                        selected = random.choice(loaded_and_free)
                    else:
                        selected = loaded_and_free[0]
            else:
                # Step 4: endpoints among the candidates that simply have a free slot
                endpoints_with_free_slot = [
                    ep for ep in candidate_endpoints
                    if tracking_usage(ep) < get_max_connections(ep)
                ]
                if endpoints_with_free_slot:
                    if config.priority_routing:
                        endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
                        endpoints_with_free_slot.sort(key=utilization_ratio)
                        selected = endpoints_with_free_slot[0]
                    else:
                        # Sort by total endpoint load (ascending) to prefer idle endpoints.
                        endpoints_with_free_slot.sort(
                            key=lambda ep: sum(usage_counts.get(ep, {}).values())
                        )
                        if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
                            selected = random.choice(endpoints_with_free_slot)
                        else:
                            selected = endpoints_with_free_slot[0]
                else:
                    # Step 5: all candidate endpoints are saturated, pick the
                    # least-busy one (the request will queue)
                    if config.priority_routing:
                        selected = min(
                            candidate_endpoints,
                            key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
                        )
                    else:
                        selected = min(candidate_endpoints, key=tracking_usage)

        tracking_model = get_tracking_model(selected, model)
        snapshot = None
        if reserve:
            usage_counts[selected][tracking_model] += 1
            snapshot = _capture_snapshot()

    # Snapshot distribution must happen outside usage_lock.
    if snapshot is not None:
        await _distribute_snapshot(snapshot)

    # Record / refresh affinity *after* releasing usage_lock.
    if reserve and config.conversation_affinity and affinity_key:
        expires_at = time.monotonic() + config.conversation_affinity_ttl
        async with _affinity_lock:
            _affinity_map[affinity_key] = (selected, model, expires_at)
            if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
                # Opportunistic cleanup: only already-expired entries are dropped.
                now = time.monotonic()
                for k in [k for k, v in _affinity_map.items() if v[2] < now]:
                    _affinity_map.pop(k, None)

    return selected, tracking_model
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
# 6. API route Generate
# -------------------------------------------------------------
@app.post("/api/generate")
async def proxy(request: Request):
    """
    Proxy a generate request to Ollama and stream the response back to the client.

    Flow: parse/validate the JSON body, serve from the LLM cache when the
    request opts in via ``nomyo.cache``, otherwise select an endpoint
    (reserving a usage slot), forward the request to an OpenAI-compatible or
    native Ollama backend and relay the reply as JSON lines. The usage slot
    is released when the response generator finishes.

    Raises:
        HTTPException(400): invalid JSON, or missing ``model`` / ``prompt``.
    """
    try:
        body_bytes = await request.body()
        # orjson parses bytes directly; invalid UTF-8 is reported as
        # JSONDecodeError (-> 400) instead of an uncaught UnicodeDecodeError.
        payload = orjson.loads(body_bytes)

        model = payload.get("model")
        prompt = payload.get("prompt")
        suffix = payload.get("suffix")
        system = payload.get("system")
        template = payload.get("template")
        context = payload.get("context")
        stream = payload.get("stream")
        think = payload.get("think")
        raw = payload.get("raw")
        _format = payload.get("format")
        images = payload.get("images")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")
        _cache_enabled = payload.get("nomyo", {}).get("cache", False)

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except orjson.JSONDecodeError as e:
        error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
        raise HTTPException(status_code=400, detail=error_msg) from e

    def _json_bytes(obj) -> bytes:
        # pydantic models serialise to str; orjson.dumps already returns bytes
        # and must not be .encode()d again.
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json().encode("utf-8")
        return orjson.dumps(obj)

    # Cache lookup: before endpoint selection so no slot is wasted on a hit
    _cache = get_llm_cache()
    # Use one normalised model name for both get_generate and set_generate.
    # ``model`` is rewritten further down for OpenAI-compatible endpoints
    # (":latest" stripped); keying the store on the rewritten name would make
    # those entries unreachable by the lookup.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_generate(_cache_model, prompt, system or "")
        if _cached is not None:
            async def _serve_cached_generate():
                yield _cached
            return StreamingResponse(_serve_cached_generate(), media_type="application/json")

    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")[0]
        params = {
            "prompt": prompt,
            "model": model,
        }
        # Map the Ollama ``options`` onto their OpenAI completion equivalents;
        # entries that resolve to None are not sent.
        optional_params = {
            "stream": stream,
            "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
            "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
            "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
            "seed": options.get("seed") if options and "seed" in options else None,
            "stop": options.get("stop") if options and "stop" in options else None,
            "top_p": options.get("top_p") if options and "top_p" in options else None,
            "temperature": options.get("temperature") if options and "temperature" in options else None,
            "suffix": suffix,
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 4. Async generator that streams data and decrements the counter
    async def stream_generate_response():
        try:
            if use_openai:
                start_ts = time.perf_counter()
                async_gen = await oclient.completions.create(**params)
            else:
                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)

            if stream == True:
                content_parts: list[str] = []
                async for chunk in async_gen:
                    if use_openai:
                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                    line_bytes = _json_bytes(chunk)

                    # Accumulate and store cache on done chunk, before yield so it always runs
                    if _cache is not None and _cache_enabled:
                        if getattr(chunk, "response", None):
                            content_parts.append(chunk.response)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "response": "".join(content_parts),
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_generate(_cache_model, prompt, system or "", assembled)
                            except Exception as _ce:
                                print(f"[cache] set_generate (streaming) failed: {_ce}")

                    yield line_bytes + b"\n"
            else:
                # Normalise to the Ollama response shape first so that token
                # accounting reads the same attributes for both backend types
                # (the raw OpenAI object has no ``prompt_eval_count``).
                if use_openai:
                    result = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
                else:
                    result = async_gen
                prompt_tok = result.prompt_eval_count or 0
                comp_tok = result.eval_count or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                cache_bytes = _json_bytes(result) + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_generate(_cache_model, prompt, system or "", cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_generate (non-streaming) failed: {_ce}")

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_generate_response(),
        media_type="application/json",
    )
# -------------------------------------------------------------
# 7. API route Chat
# -------------------------------------------------------------
@app.post("/api/chat")
async def chat_proxy(request: Request):
    """
    Proxy a chat request to Ollama and stream the endpoint reply.

    Flow: parse/validate the JSON body, serve from the LLM cache when the
    request opts in via ``nomyo.cache`` (never for ``moe-`` models), select
    an endpoint (reserving a usage slot), then forward to either an
    OpenAI-compatible or a native Ollama backend. For OpenAI-compatible
    backends the upstream call is made in handler scope so that context
    overflow, unsupported image input and backend connection errors can be
    handled (trim / retry / mark unhealthy) before streaming starts. The
    usage slot is released when the response generator finishes, or
    immediately when the upstream call fails.

    Raises:
        HTTPException(400): invalid JSON, missing ``model``, ``messages`` not
            a list, or ``options`` not an object.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        # orjson parses bytes directly; invalid UTF-8 is reported as
        # JSONDecodeError (-> 400) instead of an uncaught UnicodeDecodeError.
        payload = orjson.loads(body_bytes)

        model = payload.get("model")
        messages = payload.get("messages")
        tools = payload.get("tools")
        stream = payload.get("stream")
        think = payload.get("think")
        _format = payload.get("format")
        keep_alive = payload.get("keep_alive")
        options = payload.get("options")
        logprobs = payload.get("logprobs")
        top_logprobs = payload.get("top_logprobs")
        _cache_enabled = payload.get("nomyo", {}).get("cache", False)

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(messages, list):
            raise HTTPException(
                status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    def _json_bytes(obj) -> bytes:
        # pydantic models serialise to str; orjson.dumps already returns bytes
        # and must not be .encode()d again.
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json().encode("utf-8")
        return orjson.dumps(obj)

    # Cache lookup: before endpoint selection, always bypassed for MOE
    _is_moe = model.startswith("moe-")
    _cache = get_llm_cache()
    # Normalise model name for cache key: strip ":latest" suffix here so that
    # get_chat and set_chat use the same model string regardless of the
    # ":latest" stripping done further down for OpenAI-compatible endpoints.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    # Snapshot original messages before any OpenAI-format transformation so that
    # get_chat and set_chat always use the same key regardless of backend type.
    _cache_messages = messages
    if _cache is not None and not _is_moe and _cache_enabled:
        _cached = await _cache.get_chat("ollama_chat", _cache_model, messages)
        if _cached is not None:
            async def _serve_cached_chat():
                yield _cached
            return StreamingResponse(
                _serve_cached_chat(),
                media_type="application/x-ndjson" if stream else "application/json",
            )

    # 2. Endpoint logic
    if _is_moe:
        # Strip only the leading prefix; splitting on "moe-" would truncate
        # model names that contain the marker a second time.
        model = model[len("moe-"):]
    opt = _is_moe
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")[0]
        if messages:
            if any("images" in m for m in messages):
                messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
            messages = transform_tool_calls_to_openai(messages)
            messages = _strip_assistant_prefill(messages)
        params = {
            "messages": messages,
            "model": model,
        }
        # Map Ollama request fields onto their OpenAI chat equivalents;
        # entries that resolve to None are not sent.
        optional_params = {
            "tools": tools,
            "stream": stream,
            "stream_options": {"include_usage": True} if stream else None,
            "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
            "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
            "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
            "seed": options.get("seed") if options and "seed" in options else None,
            "stop": options.get("stop") if options and "stop" in options else None,
            "top_p": options.get("top_p") if options and "top_p" in options else None,
            "temperature": options.get("temperature") if options and "temperature" in options else None,
            "logprobs": logprobs if logprobs is not None else (options.get("logprobs") if options and "logprobs" in options else None),
            "top_logprobs": top_logprobs if top_logprobs is not None else (options.get("top_logprobs") if options and "top_logprobs" in options else None),
            "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # For OpenAI endpoints: make the API call in handler scope
    # (try/except inside async generators is unreliable with Starlette's streaming)
    start_ts = None
    async_gen = None
    # Model name used as key into _endpoint_nctx. Reads and writes must use
    # the same key, otherwise a learned context size is never found again
    # for llama-server endpoints whose names get normalised.
    _nctx_model = model
    if use_openai:
        start_ts = time.perf_counter()
        # Proactive trim: only for small-ctx models we've already seen run out of space
        if endpoint in config.llama_server_endpoints:
            _nctx_model = _normalize_llama_model_name(model)
        _known_nctx = _endpoint_nctx.get((endpoint, _nctx_model))
        if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
            _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
            _pre_est = _count_message_tokens(params.get("messages", []))
            if _pre_est > _pre_target:
                _pre_msgs = params.get("messages", [])
                _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                _dropped = len(_pre_msgs) - len(_pre_trimmed)
                print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                params = {**params, "messages": _pre_trimmed}

        try:
            async_gen = await oclient.chat.completions.create(**params)
        except Exception as e:
            # Every path below that re-raises must release the reserved slot
            # first, because the streaming generator (and its ``finally``)
            # is never started in that case.
            _e_str = str(e)
            print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}")
            if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
                err_body = getattr(e, "body", {}) or {}
                err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
                n_ctx_limit = err_detail.get("n_ctx", 0)
                actual_tokens = err_detail.get("n_prompt_tokens", 0)
                if not n_ctx_limit:
                    # Fall back to scraping the numbers out of the message text
                    _m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
                    if _m:
                        n_ctx_limit = int(_m.group(1))
                    _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                    if _m:
                        actual_tokens = int(_m.group(1))
                if not n_ctx_limit:
                    await decrement_usage(endpoint, tracking_model)
                    raise
                if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                    _endpoint_nctx[(endpoint, _nctx_model)] = n_ctx_limit

                msgs_to_trim = params.get("messages", [])
                cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
                try:
                    async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed})
                except Exception as e2:
                    _e2_str = str(e2)
                    if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
                        print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools")
                        params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                        try:
                            async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
                        except Exception:
                            await decrement_usage(endpoint, tracking_model)
                            raise
                    else:
                        await decrement_usage(endpoint, tracking_model)
                        raise
            elif _is_backend_connection_error(e):
                print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, _e_str)
                await decrement_usage(endpoint, tracking_model)
                raise
            elif "image input is not supported" in _e_str:
                print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
                try:
                    params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
                    async_gen = await oclient.chat.completions.create(**params)
                except Exception:
                    await decrement_usage(endpoint, tracking_model)
                    raise
            else:
                await decrement_usage(endpoint, tracking_model)
                raise

    # 3. Async generator that streams chat data and decrements the counter
    async def stream_chat_response():
        try:
            # The chat method returns a generator of dicts (or GenerateResponse)
            if use_openai:
                _async_gen = async_gen  # established in handler scope above
            else:
                if opt:
                    # Use the dedicated MOE helper function
                    _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
                else:
                    _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)

            if stream == True:
                tc_acc = {}  # accumulate OpenAI tool-call deltas across chunks
                content_parts: list[str] = []
                async for chunk in _async_gen:
                    if use_openai:
                        _accumulate_openai_tc_delta(chunk, tc_acc)
                        chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                        # Inject fully-accumulated tool calls only into the final chunk
                        if chunk.done and tc_acc and chunk.message:
                            chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)

                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                    # `chunk` can be a dict or a pydantic model; serialise to JSON bytes safely
                    line_bytes = _json_bytes(chunk)

                    if getattr(chunk, "done", False):
                        # Detect context exhaustion mid-generation for small-ctx models
                        _dr = getattr(chunk, "done_reason", None)
                        # Only cache when no max_tokens limit was set; otherwise
                        # finish_reason=length might just mean max_tokens was hit,
                        # not that the context window was exhausted.
                        _req_max_tok = (
                            params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
                            if use_openai else
                            (options.get("num_predict") if options else None)
                        )
                        if _dr == "length" and not _req_max_tok:
                            _pt = getattr(chunk, "prompt_eval_count", 0) or 0
                            _ct = getattr(chunk, "eval_count", 0) or 0
                            _inferred_nctx = _pt + _ct
                            if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                                _endpoint_nctx[(endpoint, _nctx_model)] = _inferred_nctx
                                print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)

                    # Accumulate and store cache on done chunk, before yield so it always runs.
                    # Works for both Ollama-native and OpenAI-compatible backends; chunks are
                    # already converted to Ollama format by rechunk before this point.
                    if _cache is not None and not _is_moe and _cache_enabled:
                        if chunk.message and getattr(chunk.message, "content", None):
                            content_parts.append(chunk.message.content)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)),
                                    "message": {"role": "assistant", "content": "".join(content_parts)},
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
                            except Exception as _ce:
                                print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")

                    yield line_bytes + b"\n"
            else:
                # Normalise to the Ollama response shape first so that token
                # accounting reads the same attributes for both backend types
                # (the raw OpenAI object has no ``prompt_eval_count``).
                if use_openai:
                    result = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
                else:
                    result = _async_gen
                prompt_tok = result.prompt_eval_count or 0
                comp_tok = result.eval_count or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                cache_bytes = _json_bytes(result) + b"\n"
                yield cache_bytes
                # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends)
                if _cache is not None and not _is_moe and _cache_enabled:
                    try:
                        await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    media_type = "application/x-ndjson" if stream else "application/json"
    return StreamingResponse(
        stream_chat_response(),
        media_type=media_type,
    )
# -------------------------------------------------------------
# 8. API route Embedding - deprecated
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """
    Proxy a (deprecated) /api/embeddings request to a backend and reply with embeddings.

    The JSON body must contain 'model' and 'prompt'; 'options' and 'keep_alive'
    are forwarded to Ollama backends only. For OpenAI-compatible backends the
    embeddings API is called and the result is converted via
    rechunk.openai_embeddings2ollama.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model' / 'prompt' field.
    """
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so malformed UTF-8 is reported as a
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    prompt = payload.get("prompt")
    options = payload.get("options")
    keep_alive = payload.get("keep_alive")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI-compatible backends are addressed without the ':latest' tag
        if ":latest" in model:
            model = model.split(":latest")[0]
        client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that emits the embedding data and decrements the counter
    async def stream_embedding_response():
        try:
            if use_openai:
                result = await client.embeddings.create(input=prompt, model=model)
                result = rechunk.openai_embeddings2ollama(result)
            else:
                result = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)

            if hasattr(result, "model_dump_json"):
                body = result.model_dump_json().encode("utf-8")
            else:
                # orjson.dumps() already returns bytes; calling .encode() on it would fail
                body = orjson.dumps(result)
            yield body + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
# -------------------------------------------------------------
# 9. API route Embed
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.post("/api/embed")
async def embed_proxy(request: Request):
    """
    Proxy an /api/embed request to a backend and reply with embeddings.

    The JSON body must contain 'model' and 'input'; 'truncate', 'options' and
    'keep_alive' are forwarded to Ollama backends only. For OpenAI-compatible
    backends the embeddings API is called and the result is converted via
    rechunk.openai_embed2ollama.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model' / 'input' field.
    """
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so malformed UTF-8 is reported as a
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    _input = payload.get("input")
    truncate = payload.get("truncate")
    options = payload.get("options")
    keep_alive = payload.get("keep_alive")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not _input:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI-compatible backends are addressed without the ':latest' tag
        if ":latest" in model:
            model = model.split(":latest")[0]
        client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that emits the embed data and decrements the counter
    async def stream_embedding_response():
        try:
            if use_openai:
                result = await client.embeddings.create(input=_input, model=model)
                result = rechunk.openai_embed2ollama(result, model)
            else:
                result = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)

            if hasattr(result, "model_dump_json"):
                body = result.model_dump_json().encode("utf-8")
            else:
                # orjson.dumps() already returns bytes; calling .encode() on it would fail
                body = orjson.dumps(result)
            yield body + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
# -------------------------------------------------------------
# 10. API route Create
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with deduplicated status.

    The JSON body must contain 'model' and at least one of 'from' / 'files';
    the remaining create options are forwarded unchanged. Endpoints whose URL
    contains '/v1' (OpenAI-compatible) are skipped, matching the copy, delete
    and pull routes, because they do not offer the Ollama create API.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or missing
            required fields.
    """
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so malformed UTF-8 is reported as a
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    quantize = payload.get("quantize")
    from_ = payload.get("from")
    files = payload.get("files")
    adapters = payload.get("adapters")
    template = payload.get("template")
    license_ = payload.get("license")  # trailing underscore: avoid shadowing the builtin
    system = payload.get("system")
    parameters = payload.get("parameters")
    messages = payload.get("messages")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not from_ and not files:
        raise HTTPException(
            status_code=400, detail="You need to provide either from_ or files parameter!"
        )

    # 2. Create the model on every Ollama endpoint, one after another
    status_lists = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        create = await client.create(
            model=model,
            quantize=quantize,
            from_=from_,
            files=files,
            adapters=adapters,
            template=template,
            license=license_,
            system=system,
            parameters=parameters,
            messages=messages,
            stream=False,
        )
        status_lists.append(create)

    # 3. Merge the per-endpoint responses into one deduplicated status dict.
    # NOTE(review): relies on each response iterating as (field, value) pairs - confirm
    combined_status = []
    for status_list in status_lists:
        combined_status += status_list
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
# -------------------------------------------------------------
# 11. API route Show
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.post("/api/show")
async def show_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model show request to Ollama and reply with ShowResponse.

    The model name is taken from the 'model' query parameter when given,
    otherwise from the 'model' field of the JSON body.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model'.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            # orjson parses bytes directly, so malformed UTF-8 is reported as a
            # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
            payload = orjson.loads(body_bytes)
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            model = payload.get("model")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Endpoint logic (reserve=False: no usage slot is reserved for this lookup)
    endpoint, _ = await choose_endpoint(model, reserve=False)
    client = ollama.AsyncClient(host=endpoint)

    # 3. Proxy a simple show request and return the ShowResponse
    return await client.show(model=model)
# -------------------------------------------------------------
@app.get("/api/token_counts")
async def token_counts_proxy():
    """
    Return the stored token counters: one breakdown row per entry yielded by
    db.load_token_counts() plus the sum of all 'total_tokens' values.
    """
    columns = ("endpoint", "model", "input_tokens", "output_tokens", "total_tokens")
    # Copy only the reported columns out of every database entry
    rows = [
        {column: record[column] for column in columns}
        async for record in db.load_token_counts()
    ]
    grand_total = sum(row["total_tokens"] for row in rows)
    return {"total_tokens": grand_total, "breakdown": rows}
@app.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
    """
    Aggregate time_series entries older than `days` into daily aggregates by endpoint/model/date.

    The optional JSON body may carry 'days' (default 30) and 'trim_old'
    (default False). Any problem while reading or parsing the body falls back
    to both defaults.
    """
    default_days, default_trim = 30, False
    days, trim_old = default_days, default_trim
    try:
        raw_body = await request.body()
        if raw_body:
            options = orjson.loads(raw_body.decode("utf-8"))
            requested_days = int(options.get("days", default_days))
            requested_trim = bool(options.get("trim_old", default_trim))
            days, trim_old = requested_days, requested_trim
    except Exception:
        # Best effort: an unusable body means "run with the defaults"
        days, trim_old = default_days, default_trim

    group_count = await db.aggregate_time_series_older_than(days, trim_old=trim_old)
    return {
        "status": "ok",
        "days": days,
        "trim_old": trim_old,
        "aggregated_groups": group_count,
    }
# 12. API route Stats
# -------------------------------------------------------------
@app.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
    """
    Return token usage statistics for a specific model.

    The model name is taken from the 'model' query parameter when given,
    otherwise from the 'model' field of the JSON body. The reply combines the
    token totals, the time series and the endpoint distribution read from db.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model'; 404 when no token data exists for the model.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            # orjson parses bytes directly, so malformed UTF-8 is reported as a
            # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
            payload = orjson.loads(body_bytes)
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            model = payload.get("model")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Get token counts from database
    token_data = await db.get_token_counts_for_model(model)
    if not token_data:
        raise HTTPException(
            status_code=404, detail="No token data found for this model"
        )
    time_series = [
        entry async for entry in db.get_time_series_for_model(model)
    ]
    endpoint_distribution = await db.get_endpoint_distribution_for_model(model)

    # 3. Assemble the statistics reply
    return {
        'model': model,
        'input_tokens': token_data['input_tokens'],
        'output_tokens': token_data['output_tokens'],
        'total_tokens': token_data['total_tokens'],
        'time_series': time_series,
        'endpoint_distribution': endpoint_distribution,
    }
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
# 12. API route Copy
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """
    Proxy a model copy request to each Ollama endpoint and reply with Status Code.

    'source' and 'destination' come from the query parameters when either is
    given, otherwise from the JSON body. Endpoints whose URL contains '/v1'
    are skipped.

    Returns:
        An empty Response with status 404 if any endpoint reported 404,
        otherwise 200.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'source' / 'destination'.
        ollama.ResponseError: for backend errors other than 404.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not source and not destination:
            # orjson parses bytes directly, so malformed UTF-8 is reported as a
            # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
            payload = orjson.loads(body_bytes)
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            src = payload.get("source")
            dst = payload.get("destination")
        else:
            src = source
            dst = destination
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not src:
        raise HTTPException(
            status_code=400, detail="Missing required field 'source'"
        )
    if not dst:
        raise HTTPException(
            status_code=400, detail="Missing required field 'destination'"
        )

    # 2. Iterate over all endpoints to copy the model on each endpoint
    status_list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            # 3. Proxy a simple copy request
            result = await client.copy(source=src, destination=dst)
        except ollama.ResponseError as exc:
            # The ollama client raises on HTTP errors instead of returning the
            # code, so a 404 has to be recorded here to be reported at all.
            if exc.status_code != 404:
                raise
            status_list.append(404)
        else:
            status_list.append(result.status)

    # 4. Return with 200 OK if all went well, 404 if a single endpoint failed
    return Response(status_code=404 if 404 in status_list else 200)
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
# 13. API route Delete
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.delete("/api/delete")
async def delete_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model delete request to each Ollama endpoint and reply with Status Code.

    The model name is taken from the 'model' query parameter when given,
    otherwise from the 'model' field of the JSON body. Endpoints whose URL
    contains '/v1' are skipped.

    Returns:
        An empty Response with status 404 if any endpoint reported 404,
        otherwise 200.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model'.
        ollama.ResponseError: for backend errors other than 404.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            # orjson parses bytes directly, so malformed UTF-8 is reported as a
            # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
            payload = orjson.loads(body_bytes)
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            model = payload.get("model")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Iterate over all endpoints to delete the model on each endpoint
    status_list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            # 3. Proxy a simple delete request
            result = await client.delete(model=model)
        except ollama.ResponseError as exc:
            # The ollama client raises on HTTP errors instead of returning the
            # code, so a 404 has to be recorded here to be reported at all.
            if exc.status_code != 404:
                raise
            status_list.append(404)
        else:
            status_list.append(result.status)

    # 4. Return 200 OK; if a single endpoint fails, respond with 404
    return Response(status_code=404 if 404 in status_list else 200)
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
# 14. API route Pull
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.post("/api/pull")
async def pull_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a pull request to all Ollama endpoints and report status back.

    The model name is taken from the 'model' query parameter when given
    ('insecure' is then left unset), otherwise 'model' and 'insecure' are read
    from the JSON body. Endpoints whose URL contains '/v1' are skipped.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model'.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            # orjson parses bytes directly, so malformed UTF-8 is reported as a
            # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
            payload = orjson.loads(body_bytes)
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            model = payload.get("model")
            insecure = payload.get("insecure")
        else:
            insecure = None
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Iterate over all endpoints to pull the model
    status_list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple pull request
        pull = await client.pull(model=model, insecure=insecure, stream=False)
        status_list.append(pull)

    # 4. Report back a deduplicated status message.
    # NOTE(review): relies on each response iterating as (field, value) pairs - confirm
    combined_status = []
    for status in status_list:
        combined_status += status
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
# -------------------------------------------------------------
# 15. API route Push
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to Ollama and respond with the deduplicated Ollama endpoint replies.

    The JSON body must contain 'model'; 'insecure' is forwarded when present.
    Endpoints whose URL contains '/v1' (OpenAI-compatible) are skipped, matching
    the pull route, because they do not offer the Ollama push API.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model'.
    """
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so malformed UTF-8 is reported as a
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    insecure = payload.get("insecure")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Iterate over all Ollama endpoints
    status_list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple push request
        push = await client.push(model=model, insecure=insecure, stream=False)
        status_list.append(push)

    # 4. Report a deduplicated status.
    # NOTE(review): relies on each response iterating as (field, value) pairs - confirm
    combined_status = []
    for status in status_list:
        combined_status += status
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
# -------------------------------------------------------------
# 16. API route Version
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply with the lowest version of all endpoints.

    Only endpoints without '/v1' in their URL are queried.

    Raises:
        HTTPException: 503 when no endpoint returned a non-empty version string.
    """
    # 1. Query all endpoints for version
    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
    all_versions_raw = await asyncio.gather(*tasks)
    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]
    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v):
        """Sort key that tolerates pre-release suffixes such as '0.5.0-rc1'."""
        core, _, prerelease = v.partition("-")
        numbers = tuple(int(part) for part in re.findall(r"\d+", core))
        # A pre-release ranks below the final release with the same numbers
        return numbers, not prerelease

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
# -------------------------------------------------------------
# 17. API route tags
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to all endpoints and reply with a unique list of all models.

    Ollama endpoints are queried via /api/tags, '/v1' endpoints and llama-server
    endpoints via /models. Every entry is relabelled so that it carries the
    'model', 'name' and 'id' keys before the list is deduplicated.
    """
    # 1. Query all endpoints for models
    tasks = [
        fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints if "/v1" not in ep
    ]
    # .get(): an endpoint without a configured API key must not raise KeyError
    tasks += [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in config.endpoints if "/v1" in ep
    ]
    # Also query llama-server endpoints not already covered by config.endpoints
    llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    tasks += [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_eps_for_tags
    ]
    all_models = await asyncio.gather(*tasks)

    models = {'models': []}
    for modellist in all_models:
        for model in modellist:
            if "model" not in model:  # Relabel OpenAI models with Ollama Model.model from Model.id
                model['model'] = model['id'] + ":latest"
            else:
                model['id'] = model['model']
            if "name" not in model:  # Relabel OpenAI models with Ollama Model.name from Model.model
                model['name'] = model['model']
            else:
                model['id'] = model['model']
        models['models'] += modellist

    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(models['models'], ['digest', 'name', 'id'])},
        status_code=200,
    )
# -------------------------------------------------------------
# 18. API route ps
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply with a unique list of all running models.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models and keeps the entries that
    _is_llama_model_loaded accepts
    """
    # 1. Query Ollama endpoints for running models via /api/ps
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints if "/v1" not in ep
    ]
    # 2. Query llama-server endpoints for loaded models via /v1/models.
    # Deduplicate while keeping the configured order, so the entry that
    # survives the dedupe below does not depend on set iteration order.
    all_llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in all_llama_endpoints
    ]
    ollama_loaded = await asyncio.gather(*ollama_tasks) if ollama_tasks else []
    llama_loaded = await asyncio.gather(*llama_tasks) if llama_tasks else []

    models = {'models': []}
    # Add Ollama models (if any)
    for modellist in ollama_loaded:
        models['models'] += modellist
    # Add llama-server models (loaded only, if any)
    for modellist in llama_loaded:
        for item in modellist:
            if not _is_llama_model_loaded(item):
                continue
            # Convert llama-server format to Ollama-like format for consistency
            raw_id = item.get("id", "")
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            models['models'].append({
                "name": normalized,
                "id": normalized,
                "digest": "",
                "status": item.get("status"),
                "details": {"quantization_level": quant} if quant else {},
            })

    # 3. Return a JSONResponse with deduplicated currently deployed models
    # Deduplicate on 'name' rather than 'digest': llama-server models always
    # have digest="" so deduping on digest collapses all of them to one entry.
    return JSONResponse(
        content={"models": dedupe_on_keys(models['models'], ['name'])},
        status_code=200,
    )
# -------------------------------------------------------------
# 18b. API route ps details (backwards compatible)
# -------------------------------------------------------------
@app.get("/api/ps_details")
async def ps_details_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply with per-endpoint instances.

    This keeps /api/ps backward compatible while providing richer data.
    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models with status info, then
    /props per model to obtain the context length. Models that /props reports
    as sleeping are unloaded and left out of the reply.
    """
    async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]:
        """Return (n_ctx, is_sleeping, is_generation) from /props; (None, False, False) on failure."""
        client: aiohttp.ClientSession = get_session(endpoint)
        base_url = endpoint.rstrip("/").removesuffix("/v1")
        props_url = f"{base_url}/props"
        headers = None
        api_key = config.api_keys.get(endpoint)
        if api_key:
            headers = {"Authorization": f"Bearer {api_key}"}
        try:
            # Pass the model id via params so characters such as '&', '#' or '+'
            # are escaped instead of corrupting the query string.
            async with client.get(
                props_url,
                params={"model": model_id},
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    dgs = data.get("default_generation_settings", {})
                    n_ctx = dgs.get("n_ctx")
                    is_sleeping = data.get("is_sleeping", False)
                    # Embedding models have no sampling params in default_generation_settings
                    is_generation = "temperature" in dgs
                    if is_sleeping:
                        unload_url = f"{base_url}/models/unload"
                        try:
                            async with client.post(
                                unload_url,
                                json={"model": model_id},
                                headers=headers,
                            ) as unload_resp:
                                print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
                        except Exception as ue:
                            print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")
                    return n_ctx, is_sleeping, is_generation
        except Exception as e:
            print(f"[ps_details] Failed to fetch props from {props_url}?model={model_id}: {e}")
        # Non-200 answer or request failure
        return None, False, False

    # 1. Query Ollama endpoints via /api/ps
    ollama_endpoints = [ep for ep in config.endpoints if "/v1" not in ep]
    # 2. Query llama-server endpoints via /v1/models (deduplicated, configured order)
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))

    ollama_loaded = await asyncio.gather(*[
        fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
        for ep in ollama_endpoints
    ])
    llama_loaded = await asyncio.gather(*[
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_endpoints
    ])

    models: list[dict] = []
    # Add Ollama models with endpoint info (if any)
    for endpoint, modellist in zip(ollama_endpoints, ollama_loaded):
        for model in modellist:
            if isinstance(model, dict):
                model_with_endpoint = dict(model)
                model_with_endpoint["endpoint"] = endpoint
                models.append(model_with_endpoint)

    # Add llama-server models with endpoint info and full status metadata (if any)
    if llama_loaded:
        # Collect (endpoint, raw_id) pairs to fetch /props in parallel
        props_requests: list[tuple[str, str]] = []
        llama_models_pending: list[dict] = []
        for endpoint, modellist in zip(llama_endpoints, llama_loaded):
            # Include sleeping models too so _fetch_llama_props can unload them
            loaded_models = [item for item in modellist if _is_llama_model_loaded_or_sleeping(item)]
            for item in loaded_models:
                if not (isinstance(item, dict) and item.get("id")):
                    continue
                raw_id = item["id"]
                normalized = _normalize_llama_model_name(raw_id)
                quant = _extract_llama_quant(raw_id)
                model_with_endpoint = {
                    "name": normalized,
                    "id": normalized,
                    "original_name": raw_id,
                    "digest": "",
                    "details": {"quantization_level": quant} if quant else {},
                    "endpoint": endpoint,
                    "status": item.get("status"),
                    "created": item.get("created"),
                    "owned_by": item.get("owned_by"),
                }
                # Include full llama-server status details (args, preset)
                status_info = item.get("status", {})
                if isinstance(status_info, dict):
                    model_with_endpoint["llama_status_args"] = status_info.get("args")
                    model_with_endpoint["llama_status_preset"] = status_info.get("preset")
                llama_models_pending.append(model_with_endpoint)
                props_requests.append((endpoint, raw_id))

        # Fetch /props for each llama-server model to get context length (n_ctx)
        # and unload sleeping models automatically
        props_results = await asyncio.gather(
            *[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
        )
        for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in zip(props_requests, llama_models_pending, props_results):
            if n_ctx is not None:
                model_dict["context_length"] = n_ctx
                # Only small context windows of generation models are cached
                if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
                    normalized = _normalize_llama_model_name(raw_id)
                    _endpoint_nctx[(ep, normalized)] = n_ctx
                    print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
            if not is_sleeping:
                models.append(model_dict)

    return JSONResponse(content={"models": models}, status_code=200)
# -------------------------------------------------------------
# 18c. Conversation-affinity stats - feeds the PS-table dot matrix
# -------------------------------------------------------------
@app.get("/api/affinity_stats")
async def affinity_stats(request: Request):
    """
    Aggregate live conversation-affinity pins, one entry per pinned conversation.

    Each entry exposes only the endpoint, model, and remaining TTL in seconds -
    no fingerprints or content. Expired pins found along the way are removed
    from the map. When conversation_affinity is disabled the `entries` list is
    always empty.
    """
    if not config.conversation_affinity:
        return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []}

    started = time.monotonic()
    pins: list[dict] = []
    llama_endpoints = set(config.llama_server_endpoints)
    async with _affinity_lock:
        # Iterate over a snapshot because expired pins are removed in the loop
        for fingerprint, (endpoint, model_name, deadline) in tuple(_affinity_map.items()):
            seconds_left = deadline - started
            if seconds_left > 0:
                # Mirror the normalisation used by /api/ps_details so the dashboard
                # can join affinity entries to PS rows by (endpoint, model).
                if endpoint in llama_endpoints:
                    shown_model = _normalize_llama_model_name(model_name)
                else:
                    shown_model = model_name
                pins.append({
                    "endpoint": endpoint,
                    "model": shown_model,
                    "remaining": round(seconds_left, 2),
                })
            else:
                _affinity_map.pop(fingerprint, None)

    return {
        "enabled": True,
        "ttl": config.conversation_affinity_ttl,
        "entries": pins,
    }
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
# 19. Proxy usage route for monitoring
# -------------------------------------------------------------
@app.get("/api/usage")
async def usage_proxy(request: Request):
    """
    Report the current usage_counts and token_usage_counts objects.

    Intended for debugging and monitoring.
    """
    return dict(
        usage_counts=usage_counts,
        token_usage_counts=token_usage_counts,
    )
2026-05-19 12:05:51 +02:00
from backends.probe import _raw_probe, _endpoint_health
# -------------------------------------------------------------
# 20b. Proxy config route for monitoring and frontend usage
# -------------------------------------------------------------
@app.get("/api/config")
async def config_proxy(request: Request):
    """
    Return a JSON object listing the configured endpoints and
    llama_server_endpoints together with the health data that
    _endpoint_health reports for each of them, plus a flag telling whether a
    router API key is required. The frontend uses this to display which
    endpoints are being proxied and their health.

    Status is "error" when either liveness (/api/version) or routing
    health (/api/ps) fails - see issue #83.
    """
    async def _probe(url: str) -> dict:
        # Merge the URL with whatever health fields the probe returns
        health = await _endpoint_health(url, timeout=5)
        return {"url": url, **health}

    endpoint_health = await asyncio.gather(*map(_probe, config.endpoints))

    if config.llama_server_endpoints:
        llama_health = await asyncio.gather(*map(_probe, config.llama_server_endpoints))
    else:
        llama_health = []

    needs_key = bool(config.router_api_key)
    return {
        "endpoints": endpoint_health,
        "llama_server_endpoints": llama_health,
        "require_router_api_key": needs_key,
    }
# -------------------------------------------------------------
# 21. API route OpenAI compatible Embedding
# -------------------------------------------------------------
@app.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to a backend and reply with embeddings.

    The JSON body must contain 'model' and 'input'. Multimodal input is
    reduced to its text parts, and non-finite floats in the returned
    embeddings are replaced with 0.0 before the reply is serialised.

    Raises:
        HTTPException: 400 on invalid JSON, a non-object body, or a missing
            'model' / 'input' field.
    """
    # 1. Parse and validate request
    try:
        # orjson parses bytes directly, so malformed UTF-8 is reported as a
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    doc = payload.get("input")
    # Normalize multimodal input: extract only text parts for embedding models
    if isinstance(doc, list):
        normalized = []
        for item in doc:
            if isinstance(item, dict):
                # Multimodal content part - extract text only, skip images
                if item.get("type") == "text":
                    normalized.append(item.get("text", ""))
                # Skip image_url and other non-text types
            else:
                normalized.append(item)
        # A single remaining element is sent as a bare value, not a one-item list
        doc = normalized if len(normalized) != 1 else normalized[0]
    elif isinstance(doc, dict) and doc.get("type") == "text":
        doc = doc.get("text", "")
    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not doc:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        # Client construction sits inside the try so the usage counter reserved
        # by choose_endpoint is released even if building the client fails.
        if is_openai_compatible(endpoint):
            api_key = config.api_keys.get(endpoint, "no-key")
        else:
            api_key = "ollama"
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)

        # 3. Request the embeddings and sanitise them for JSON serialisation
        response = await oclient.embeddings.create(input=doc, model=model)
        result = response.model_dump()
        for item in result.get("data", []):
            emb = item.get("embedding")
            if emb:
                # NaN / inf are replaced with 0.0
                item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb]
        return JSONResponse(content=result)
    finally:
        await decrement_usage(endpoint, tracking_model)
# -------------------------------------------------------------
# 22. API route OpenAI compatible Chat Completions
# -------------------------------------------------------------
@app.post("/v1/chat/completions")
async def openai_chat_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible chat completions request to the selected
    backend and reply with an SSE stream or a single JSON document.

    Flow:
      1. Parse and validate the JSON body.
      2. Reject SVG images, then consult the response cache (opt-in through
         ``{"nomyo": {"cache": true}}`` in the request body).
      3. Pick an endpoint with ``choose_endpoint`` and perform the upstream
         call in handler scope, retrying on known recoverable failures
         (model without tool support, context overflow, model without image
         support).
      4. Return a ``StreamingResponse`` whose generator forwards the upstream
         result, queues token usage and calls ``decrement_usage`` when done.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing
            ``model``/``messages`` field or SVG image input.
        Exception: unrecoverable upstream errors are re-raised after
            ``decrement_usage`` has been called for the chosen endpoint.
    """
    # 1. Parse and validate request
    try:
        # orjson validates UTF-8 itself, so undecodable bodies surface as
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    messages = payload.get("messages")
    stream = payload.get("stream")
    stream_options = payload.get("stream_options")
    _nomyo_opts = payload.get("nomyo")
    _cache_enabled = _nomyo_opts.get("cache", False) if isinstance(_nomyo_opts, dict) else False

    if not model or not isinstance(model, str):
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing required field 'messages' (must be a list)"
        )
    if ":latest" in model:
        model = model.split(":latest")[0]

    messages = _strip_assistant_prefill(messages)

    params = {
        "messages": messages,
        "model": model,
    }
    optional_params = {
        "tools": payload.get("tools"),
        "response_format": payload.get("response_format"),
        # Only default to include_usage when streaming: OpenAI and vLLM reject
        # stream_options on non-streaming requests. An explicit client value
        # is still forwarded untouched.
        "stream_options": stream_options or ({"include_usage": True} if stream else None),
        "max_completion_tokens": payload.get("max_completion_tokens"),
        "max_tokens": payload.get("max_tokens"),
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "seed": payload.get("seed"),
        "presence_penalty": payload.get("presence_penalty"),
        "frequency_penalty": payload.get("frequency_penalty"),
        "stop": payload.get("stop"),
        "stream": stream,
        "logprobs": payload.get("logprobs"),
        "top_logprobs": payload.get("top_logprobs"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    def _image_url_of(part) -> str:
        """Return the URL of an ``image_url`` content part, or "" when the part
        is not an image part or is malformed."""
        if not isinstance(part, dict) or part.get("type") != "image_url":
            return ""
        ref = part.get("image_url")
        url = ref.get("url") if isinstance(ref, dict) else None
        return url if isinstance(url, str) else ""

    def _is_ctx_overflow(text: str) -> bool:
        """True when an upstream error message reports a context-size overflow."""
        return "exceed_context_size_error" in text or "exceeds the available context size" in text

    # Reject unsupported image formats (SVG) before doing any work
    for _msg in messages:
        _content = _msg.get("content") if isinstance(_msg, dict) else None
        if not isinstance(_content, list):
            continue
        for _item in _content:
            _url = _image_url_of(_item)
            if _url.startswith("data:image/svg") or _url.lower().endswith(".svg"):
                raise HTTPException(
                    status_code=400,
                    detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
                )

    # Cache lookup — before endpoint selection
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_chat", model, messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ochat_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")
            else:
                async def _serve_cached_ochat_json():
                    yield _cached

                return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")

    # 3. Helpers — defined in handler scope so try/except works reliably
    async def _normalize_images_in_messages(msgs: list) -> list:
        """Fetch remote image URLs and convert them to base64 data URLs so
        Ollama/llama-server can handle them without making outbound HTTP requests.

        NOTE(security): the URLs come straight from the client request, so this
        performs server-side fetches of caller-controlled addresses. Restrict
        reachable hosts at the network level if the router is exposed to
        untrusted clients.
        """
        resolved = []
        for msg in msgs:
            content = msg.get("content") if isinstance(msg, dict) else None
            if not isinstance(content, list):
                resolved.append(msg)
                continue
            new_content = []
            for item in content:
                url = _image_url_of(item)
                if not url or url.startswith("data:"):
                    new_content.append(item)
                    continue
                try:
                    http: aiohttp.ClientSession = app_state["session"]
                    async with http.get(url) as resp:
                        # An error page must not be forwarded as image data.
                        resp.raise_for_status()
                        ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
                        img_bytes = await resp.read()
                    b64 = base64.b64encode(img_bytes).decode("utf-8")
                    new_content.append({
                        "type": "image_url",
                        "image_url": {"url": f"data:{ctype};base64,{b64}"}
                    })
                except Exception as _ie:
                    # Best effort: keep the original part and let the backend decide.
                    print(f"[image] Failed to fetch image URL: {_ie}")
                    new_content.append(item)
            resolved.append({**msg, "content": new_content})
        return resolved

    async def _create_with_recovery(oclient, send_params: dict):
        """Call the upstream chat completions API, retrying once (or twice for
        context overflows) on the failure modes this router knows how to work
        around. Any other error, or a failed retry, propagates to the caller."""
        create = oclient.chat.completions.create
        try:
            return await create(**send_params)
        except Exception as e:
            _e_str = str(e)
            _is_ctx_err = _is_ctx_overflow(_e_str)
            print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True)
            if "does not support tools" in _e_str:
                # Model doesn't support tools — retry without them
                print("[ochat] retry: no tools", flush=True)
                params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
                return await create(**params_without_tools)
            if _is_ctx_err:
                # Backend context limit hit — apply sliding-window trim (context-shift at message level)
                err_body = getattr(e, "body", None)
                err_detail = err_body.get("error") if isinstance(err_body, dict) else None
                if not isinstance(err_detail, dict):
                    err_detail = {}
                n_ctx_limit = err_detail.get("n_ctx", 0)
                actual_tokens = err_detail.get("n_prompt_tokens", 0)
                # Fallback: parse from string if body parsing yielded nothing
                # (SDK may not parse llama-server errors)
                if not n_ctx_limit:
                    _m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
                    if _m:
                        n_ctx_limit = int(_m.group(1))
                    _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                    if _m:
                        actual_tokens = int(_m.group(1))
                print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True)
                if not n_ctx_limit:
                    raise
                if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                    # Store under the same key the proactive trim looks up.
                    _endpoint_nctx[(endpoint, _lookup_model)] = n_ctx_limit
                msgs_to_trim = send_params.get("messages", [])
                try:
                    cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                    trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                except Exception as _helper_exc:
                    print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True)
                    raise
                dropped = len(msgs_to_trim) - len(trimmed_messages)
                print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True)
                try:
                    result = await create(**{**send_params, "messages": trimmed_messages})
                    print("[ctx-trim] retry-1 ok", flush=True)
                    return result
                except Exception as e2:
                    if not _is_ctx_overflow(str(e2)):
                        raise
                    # Still too large — tool definitions likely consuming too many tokens, strip them too
                    print("[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True)
                    params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")}
                    result = await create(**{**params_no_tools, "messages": trimmed_messages})
                    print("[ctx-trim] retry-2 ok", flush=True)
                    return result
            if _is_backend_connection_error(e):
                # Upstream connection failed (e.g. llama-server in router mode
                # whose delegated worker died). Mark (endpoint, model) so the
                # next request reroutes; the client will retry this one.
                print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, _e_str)
                raise
            if "image input is not supported" in _e_str:
                # Model doesn't support images — strip and retry
                print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages")
                return await create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))})
            raise

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # Make the API call in handler scope — try/except inside async generators is unreliable
    # with Starlette's streaming machinery, so we resolve errors here before the generator starts.
    # Everything up to a successfully established upstream call is guarded so
    # the usage counter is released on every failure path, cancellation included.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model

        send_params = params
        if not is_ext_openai_endpoint(endpoint):
            resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
            send_params = {**params, "messages": resolved_msgs}

        # Proactive trim: only for small-ctx models we've already seen run out of space
        _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
        if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
            _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2)
            _pre_est = _count_message_tokens(send_params.get("messages", []))
            if _pre_est > _pre_target:
                _pre_msgs = send_params.get("messages", [])
                _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                _dropped = len(_pre_msgs) - len(_pre_trimmed)
                print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                send_params = {**send_params, "messages": _pre_trimmed}

        async_gen = await _create_with_recovery(oclient, send_params)
    except BaseException:
        await decrement_usage(endpoint, tracking_model)
        raise

    # 4. Async generator — only streams the already-established async_gen
    async def stream_ochat_response():
        try:
            if stream:
                content_parts: list[str] = []
                usage_snapshot: dict = {}
                # Guard for the context-exhaustion detection below: if the caller set a
                # token budget, finish_reason=length may just mean that budget ran out.
                _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens")
                async for chunk in async_gen:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else orjson.dumps(chunk)
                    )
                    # Usage-only chunks (e.g. from llama-server) are always forwarded.
                    should_forward = chunk.usage is not None
                    if chunk.choices:
                        choice = chunk.choices[0]
                        delta = choice.delta
                        has_content = delta.content is not None
                        has_reasoning = (
                            getattr(delta, "reasoning_content", None) is not None
                            or getattr(delta, "reasoning", None) is not None
                        )
                        has_tool_calls = getattr(delta, "tool_calls", None) is not None
                        # Also forward the terminating chunk that carries only
                        # finish_reason, as the completions handler does.
                        if (
                            has_content
                            or has_reasoning
                            or has_tool_calls
                            or choice.finish_reason is not None
                        ):
                            should_forward = True
                        if delta.content:
                            content_parts.append(delta.content)
                    if should_forward:
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                    # Detect context exhaustion mid-generation for small-ctx models.
                    if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok:
                        _inferred_nctx = (prompt_tok + comp_tok) or 0
                        if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                            _endpoint_nctx[(endpoint, _lookup_model)] = _inferred_nctx
                            print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)

                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and content_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_chat", model, messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else orjson.dumps(async_gen)
                )
                cache_bytes = json_line.encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_chat", model, messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ochat_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
# -------------------------------------------------------------
# 23. API route OpenAI compatible Completions
# -------------------------------------------------------------
@app.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible (legacy) text completions request to the
    selected backend and reply with an SSE stream or a single JSON document.

    The response cache is consulted first when the request opts in through
    ``{"nomyo": {"cache": true}}``; the prompt is mapped to a single-turn
    messages list to form the cache key.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body or a missing
            ``model``/``prompt`` field.
        Exception: upstream errors are re-raised after ``decrement_usage``
            has been called for the chosen endpoint.
    """
    # 1. Parse and validate request
    try:
        # orjson validates UTF-8 itself, so undecodable bodies surface as
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    prompt = payload.get("prompt")
    stream = payload.get("stream")
    stream_options = payload.get("stream_options")
    _nomyo_opts = payload.get("nomyo")
    _cache_enabled = _nomyo_opts.get("cache", False) if isinstance(_nomyo_opts, dict) else False

    # The SDK's completions.create() has no max_completion_tokens parameter
    # (passing it raises TypeError), so fold it into max_tokens instead.
    max_tokens = payload.get("max_tokens")
    if max_tokens is None:
        max_tokens = payload.get("max_completion_tokens")

    if not model or not isinstance(model, str):
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )
    if ":latest" in model:
        model = model.split(":latest")[0]

    params = {
        "prompt": prompt,
        "model": model,
    }
    optional_params = {
        "frequency_penalty": payload.get("frequency_penalty"),
        "presence_penalty": payload.get("presence_penalty"),
        "seed": payload.get("seed"),
        "stop": payload.get("stop"),
        "stream": stream,
        # Only default to include_usage when streaming: OpenAI and vLLM reject
        # stream_options on non-streaming requests. An explicit client value
        # is still forwarded untouched.
        "stream_options": stream_options or ({"include_usage": True} if stream else None),
        "temperature": payload.get("temperature"),
        "top_p": payload.get("top_p"),
        "max_tokens": max_tokens,
        "suffix": payload.get("suffix"),
    }
    params.update({k: v for k, v in optional_params.items() if v is not None})

    # Cache lookup — completions prompt mapped to a single-turn messages list
    _cache = get_llm_cache()
    _compl_messages = [{"role": "user", "content": prompt}]
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_completions", model, _compl_messages)
        if _cached is not None:
            if stream:
                _sse = openai_nonstream_to_sse(_cached, model)

                async def _serve_cached_ocompl_stream():
                    yield _sse

                return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")
            else:
                async def _serve_cached_ocompl_json():
                    yield _cached

                return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")

    # 2. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    # 3. Make the API call in handler scope (try/except inside async generators is unreliable).
    # BaseException is caught so the usage counter is also released on cancellation.
    try:
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        async_gen = await oclient.completions.create(**params)
    except BaseException as e:
        try:
            if isinstance(e, Exception) and _is_backend_connection_error(e):
                print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, str(e))
        finally:
            await decrement_usage(endpoint, tracking_model)
        raise

    # 4. Async generator that streams completions data and decrements the counter
    async def stream_ocompletions_response(model=model):
        try:
            if stream:
                text_parts: list[str] = []
                usage_snapshot: dict = {}
                async for chunk in async_gen:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else orjson.dumps(chunk)
                    )
                    if chunk.choices:
                        choice = chunk.choices[0]
                        has_text = getattr(choice, "text", None) is not None
                        has_reasoning = (
                            getattr(choice, "reasoning_content", None) is not None
                            or getattr(choice, "reasoning", None) is not None
                        )
                        if has_text or has_reasoning or choice.finish_reason is not None:
                            yield f"data: {data}\n\n".encode("utf-8")
                        if has_text and choice.text:
                            text_parts.append(choice.text)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")

                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                        usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                    else:
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and text_parts:
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else orjson.dumps(async_gen)
                )
                cache_bytes = json_line.encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
# -------------------------------------------------------------
# 24. OpenAI API compatible models endpoint
# -------------------------------------------------------------
@app.get("/v1/models")
async def openai_models_proxy(request: Request):
    """
    Aggregate the model lists of all configured backends into one
    OpenAI-style ``{"data": [...]}`` response, deduplicated on ``name``.

    Sources, in merge order:
      * endpoints in ``config.endpoints`` without ``/v1`` in their URL
        (treated as Ollama): ``/api/tags``, key ``models``
      * external OpenAI-compatible endpoints: ``/models``, key ``data``
      * every entry of ``config.llama_server_endpoints`` (whether or not it is
        also listed in ``config.endpoints``): ``/models``, key ``data``

    No "loaded" filtering is applied here; every model a backend reports is
    returned. Each entry is normalised so that both ``id`` and ``name`` are
    present where possible.
    """
    probe_opts = {"skip_error_cache": True, "timeout": 8}

    # 1. Ollama endpoints: all models via /api/tags
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/tags", "models", **probe_opts)
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    # 2. External OpenAI endpoints (Groq, OpenAI, etc.) via /models
    ext_openai_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), **probe_opts)
        for ep in config.endpoints
        if is_ext_openai_endpoint(ep)
    ]
    # 3. llama-server endpoints via /models. dict.fromkeys() deduplicates while
    #    keeping configuration order, so the merge order is deterministic.
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), **probe_opts)
        for ep in dict.fromkeys(config.llama_server_endpoints)
    ]

    # The probes are independent of each other: run them in one gather so the
    # worst-case latency is a single timeout rather than three in a row.
    # gather() preserves argument order, so the merge order stays
    # Ollama -> external -> llama-server.
    model_lists = await asyncio.gather(*ollama_tasks, *ext_openai_tasks, *llama_tasks)

    merged = []
    for modellist in model_lists:
        for model in modellist or []:
            if "id" not in model:
                # Relabel Ollama models with OpenAI Model.id from Model.name
                model["id"] = model.get("name", "")
            else:
                model["name"] = model["id"]
            merged.append(model)

    # 4. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"data": dedupe_on_keys(merged, ['name'])},
        status_code=200,
    )
# -------------------------------------------------------------
# 25. API route OpenAI/Jina/Cohere compatible Rerank
# -------------------------------------------------------------
@app.post("/v1/rerank")
@app.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model               (str, required)        reranker model name
        query               (str, required)        search query
        documents           (list[str], required)  candidate documents to rank
        top_n               (int, optional)        limit returned results (default: all)
        return_documents    (bool, optional)       include document text in results
        max_tokens_per_doc  (int, optional)        truncation limit per document

    Response (Jina/Cohere-compatible):
        {
            "id": "...",
            "model": "...",
            "usage": {"prompt_tokens": N, "total_tokens": N},
            "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body or missing
            fields; 404 when ``choose_endpoint`` raises ``RuntimeError``;
            501 for plain Ollama endpoints; 502 when the upstream body is not
            valid JSON; the upstream status code for upstream errors (>= 400).
    """
    try:
        # orjson validates UTF-8 itself, so undecodable bodies surface as
        # JSONDecodeError (-> 400) instead of an unhandled UnicodeDecodeError.
        payload = orjson.loads(await request.body())
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    query = payload.get("query")
    documents = payload.get("documents")

    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not query:
        raise HTTPException(status_code=400, detail="Missing required field 'query'")
    if not isinstance(documents, list) or not documents:
        raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")

    # Determine which endpoint serves this model
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    # From here on the usage counter must be released exactly once, whatever
    # happens, so everything runs under a single try/finally.
    try:
        # Ollama endpoints have no native rerank support
        if not is_openai_compatible(endpoint):
            raise HTTPException(
                status_code=501,
                detail=(
                    f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                    "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                    "dedicated reranker model."
                ),
            )

        if ":latest" in model:
            model = model.split(":latest")[0]

        # Build upstream rerank request body — forward only recognised fields
        upstream_payload: dict = {"model": model, "query": query, "documents": documents}
        for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
            if optional_key in payload:
                upstream_payload[optional_key] = payload[optional_key]

        # Determine upstream URL:
        #   llama-server: the configured endpoint may or may not already contain /v1
        #   external OpenAI-compatible: ep2base gives us the /v1 base
        if endpoint in config.llama_server_endpoints:
            if "/v1" in endpoint:
                rerank_url = f"{endpoint}/rerank"
            else:
                rerank_url = f"{endpoint}/v1/rerank"
        else:
            rerank_url = f"{ep2base(endpoint)}/rerank"

        api_key = config.api_keys.get(endpoint, "no-key")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        client: aiohttp.ClientSession = get_session(endpoint)
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )

        try:
            data = orjson.loads(response_bytes)
        except orjson.JSONDecodeError as e:
            # A 2xx reply that is not JSON is an upstream fault, not ours.
            raise HTTPException(
                status_code=502,
                detail="Upstream rerank endpoint returned invalid JSON",
            ) from e

        # Record token usage if the upstream returned a usage object
        usage = data.get("usage") if isinstance(data, dict) else None
        if not isinstance(usage, dict):
            usage = {}
        prompt_tok = usage.get("prompt_tokens") or 0
        total_tok = usage.get("total_tokens") or 0
        # For reranking there are no completion tokens; we record prompt tokens only
        if prompt_tok or total_tok:
            await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

        return JSONResponse(content=data)
    finally:
        await decrement_usage(endpoint, tracking_model)
2026-03-08 09:12:09 +01:00
# -------------------------------------------------------------
# 25b. Cache management endpoints
# -------------------------------------------------------------
@app.get("/api/cache/stats")
async def cache_stats():
    """Report whether the LLM response cache is enabled and, if so, merge in
    whatever its ``stats()`` method returns."""
    cache = get_llm_cache()
    if cache is not None:
        return {"enabled": True, **cache.stats()}
    # No cache configured.
    return {"enabled": False}
@app.post("/api/cache/invalidate")
async def cache_invalidate():
    """Call ``clear()`` on the LLM response cache, if one is configured, and
    report whether that happened."""
    cache = get_llm_cache()
    has_cache = cache is not None
    if has_cache:
        await cache.clear()
    # Both flags are true exactly when a cache exists and was cleared.
    return {"enabled": has_cache, "cleared": has_cache}
# -------------------------------------------------------------
# 26. Serve the static frontend
# -------------------------------------------------------------
# Serve static assets from the directory next to this script (STATIC_DIR), the
# same location the index route reads from, rather than a path relative to the
# process working directory.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.get("/favicon.ico")
async def redirect_favicon():
    """Redirect favicon requests to the copy served under /static."""
    favicon_location = "/static/favicon.ico"
    return RedirectResponse(url=favicon_location)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """
    Serve the NOMYO Router dashboard page (``index.html`` from the static
    directory).

    Raises:
        HTTPException: 404 when the file does not exist, 500 for any other
            failure while reading or building the response.
    """
    dashboard_file = STATIC_DIR.joinpath("index.html")
    try:
        html = dashboard_file.read_text(encoding="utf-8")
        return HTMLResponse(content=html, status_code=200)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Page not found")
    except Exception:
        raise HTTPException(status_code=500, detail="Internal server error")
# -------------------------------------------------------------
# 26. Health endpoint
# -------------------------------------------------------------
@app.get("/health")
async def health_proxy(request: Request):
    """
    Health check for the proxy and its backends.

    Every endpoint in ``config.endpoints``, plus any llama-server endpoint not
    already listed there, is probed concurrently with ``_endpoint_health``.
    That helper is defined elsewhere; per the original notes it checks Ollama
    endpoints at ``/api/version`` and ``/api/ps`` and OpenAI-compatible ones
    at ``/models`` (see issue #83) — confirm against the helper.

    Returns a JSON object with:
      - ``status``: "ok" when every probe result has ``status == "ok"``,
        otherwise "error"
      - ``endpoints``: mapping of endpoint URL to its probe result

    The HTTP status code is 200 when everything is healthy, 503 otherwise.
    """
    targets = list(config.endpoints)
    targets.extend(
        ep for ep in config.llama_server_endpoints if ep not in config.endpoints
    )

    # One probe per target, all in flight at once; results keep target order.
    results = await asyncio.gather(*map(_endpoint_health, targets))

    healthy = all(entry.get("status") == "ok" for entry in results)
    body = {
        "status": "ok" if healthy else "error",
        "endpoints": dict(zip(targets, results)),
    }
    return JSONResponse(content=body, status_code=200 if healthy else 503)
# -------------------------------------------------------------
2026-04-10 17:29:43 +02:00
# 27. Hostname endpoint
# -------------------------------------------------------------
@app.get("/api/hostname")
async def get_hostname():
    """Report the hostname of the machine the router process runs on."""
    hostname = socket.gethostname()
    return JSONResponse(content={"hostname": hostname})
# -------------------------------------------------------------
# 28. SSE route for usage broadcasts
2025-09-05 12:11:31 +02:00
# -------------------------------------------------------------
@app.get("/api/usage-stream")
async def usage_stream(request: Request):
    """
    Server-Sent Events endpoint: relays every item received from the
    subscription queue returned by ``subscribe()`` as one SSE ``data:``
    message.
    """
    async def event_generator():
        # Queue handed out by subscribe(); each item becomes one SSE message.
        updates = await subscribe()
        try:
            # The disconnect check runs before each wait for the next item.
            while not await request.is_disconnected():
                snapshot = await updates.get()
                if snapshot is None:
                    # None ends the stream.
                    break
                yield f"data: {snapshot}\n\n"
        finally:
            # Always hand the queue back, however the loop ended.
            await unsubscribe(updates)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
# -------------------------------------------------------------
# 28. FastAPI startup/shutdown events
2025-08-26 18:19:43 +02:00
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event() -> None:
    """
    Initialise the router's global state when the FastAPI app starts.

    Steps, in order:
      1. Load the YAML configuration (``Config.from_yaml``) and print what
         was loaded, or a notice that the file does not exist.
      2. Open the token database and seed ``token_usage_counts`` with the
         persisted ``total_tokens`` per endpoint/model.
      3. Create the shared aiohttp connector/session and store both in
         ``app_state``.
      4. Create one httpx client per external OpenAI endpoint.
      5. Create an aiohttp session and an httpx client per Unix-socket
         llama-server endpoint.
      6. Start the ``token_worker`` and ``flush_buffer`` background tasks.
      7. Initialise the LLM cache.

    Rebinds the module globals ``config``, ``db``, ``token_worker_task`` and
    ``flush_task``.
    """
    global config, db, token_worker_task, flush_task

    # Load YAML config (or use defaults if not present)
    config_path = _config_path_from_env()
    config = Config.from_yaml(config_path)
    if config_path.exists():
        print(
            f"Loaded configuration from {config_path}:\n"
            f" endpoints={config.endpoints},\n"
            f" llama_server_endpoints={config.llama_server_endpoints},\n"
            f" max_concurrent_connections={config.max_concurrent_connections},\n"
            f" endpoint_config={config.endpoint_config},\n"
            f" priority_routing={config.priority_routing}"
        )
    else:
        print(
            f"No configuration file found at {config_path}. "
            "Falling back to default settings."
        )

    # Initialize database
    db = TokenDatabase(config.db_path)
    await db.init_db()

    # Load existing token counts from database. Only the total is kept in
    # memory; the per-direction counts of each row are not needed here.
    async for count_entry in db.load_token_counts():
        endpoint = count_entry['endpoint']
        model = count_entry['model']
        total_tokens = count_entry['total_tokens']
        token_usage_counts[endpoint][model] = total_tokens

    # Shared aiohttp session used for regular (TCP) endpoints.
    ssl_context = ssl.create_default_context()
    connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context)
    # NOTE(review): ``total=60`` bounds the whole request, so ``sock_read=120``
    # looks unreachable -- confirm which of the two limits is intended.
    timeout = aiohttp.ClientTimeout(total=60, connect=15, sock_read=120, sock_connect=15)
    session = aiohttp.ClientSession(
        connector=connector,
        timeout=timeout,
        headers={"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")},
    )
    app_state["connector"] = connector
    app_state["session"] = session

    # Create httpx clients for external OpenAI endpoints (Google, etc.)
    # aiohttp strips Referer headers for cross-origin requests, so we use httpx
    for ep in config.endpoints:
        if is_ext_openai_endpoint(ep):
            app_state["httpx_clients"][ep] = httpx.AsyncClient(timeout=30.0)

    # Create per-endpoint Unix socket sessions for .sock endpoints
    for ep in config.llama_server_endpoints:
        if _is_unix_socket_endpoint(ep):
            sock_path = _get_socket_path(ep)
            sock_connector = aiohttp.UnixConnector(path=sock_path)
            sock_timeout = aiohttp.ClientTimeout(total=300, connect=5, sock_read=300)
            sock_session = aiohttp.ClientSession(connector=sock_connector, timeout=sock_timeout)
            app_state["socket_sessions"][ep] = sock_session
            # The same socket is also reachable through httpx.
            transport = httpx.AsyncHTTPTransport(uds=sock_path)
            app_state["httpx_clients"][ep] = httpx.AsyncClient(transport=transport, timeout=300.0)
            print(f"[startup] Unix socket session: {ep} -> {sock_path}")

    # Background workers; their handles are kept so shutdown can cancel them.
    token_worker_task = asyncio.create_task(token_worker())
    flush_task = asyncio.create_task(flush_buffer())

    await init_llm_cache(config)
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """
    Release the resources held by the router when the application stops.

    Order of operations:
      1. Close all SSE queues.
      2. Cancel and await the ``token_worker_task`` / ``flush_task``
         background tasks so they stop touching the DB.
      3. Flush whatever is still buffered.
      4. Close the shared aiohttp session, the Unix-socket sessions and the
         httpx clients.
      5. Close the token database connection.

    Every step is best-effort: a failure is printed and the remaining steps
    still run. This matters because the DB connection must be closed even if
    an earlier step fails (see the comment on the last step).
    """
    try:
        await close_all_sse_queues()
    except Exception as e:
        print(f"[shutdown] Error closing SSE queues: {e}")

    # Stop background tasks first so they stop touching the DB before we close it.
    for t in (token_worker_task, flush_task):
        if t is None:
            continue
        t.cancel()
        try:
            await t
        except asyncio.CancelledError:
            # Expected outcome of the cancel() above.
            pass
        except Exception as e:
            # The task died with its own error; report it instead of hiding it.
            print(f"[shutdown] Background task ended with error: {e}")

    try:
        await flush_remaining_buffers()
    except Exception as e:
        print(f"[shutdown] Error flushing remaining buffers: {e}")

    # The shared session may be missing if startup failed before creating it.
    session = app_state.get("session")
    if session is not None:
        try:
            await session.close()
        except Exception as e:
            print(f"[shutdown] Error closing shared session: {e}")

    # Close Unix socket sessions
    for ep, sess in list(app_state.get("socket_sessions", {}).items()):
        try:
            await sess.close()
            print(f"[shutdown] Closed Unix socket session: {ep}")
        except Exception as e:
            print(f"[shutdown] Error closing Unix socket session {ep}: {e}")

    # Close httpx clients (external OpenAI endpoints and Unix socket clients)
    for ep, client in list(app_state.get("httpx_clients", {}).items()):
        try:
            await client.aclose()
            print(f"[shutdown] Closed httpx client: {ep}")
        except Exception as e:
            print(f"[shutdown] Error closing httpx client {ep}: {e}")

    # Close the aiosqlite connection last — its worker thread is non-daemon
    # and would otherwise keep the interpreter alive after lifespan completes.
    if db is not None:
        try:
            await db.close()
            print("[shutdown] Closed token DB connection.")
        except Exception as e:
            print(f"[shutdown] Error closing DB: {e}")