nomyo-router/api/ollama.py

1248 lines
55 KiB
Python
Raw Permalink Normal View History

2026-05-19 14:57:39 +02:00
"""Ollama-native API routes (``/api/*``).
These are the ``/api/generate``, ``/api/chat``, ``/api/embed(dings)`` and the
model-management routes (``/api/create``, ``/api/show``, ``/api/copy``,
``/api/delete``, ``/api/pull``, ``/api/push``, ``/api/version``,
``/api/tags``, ``/api/ps``, ``/api/ps_details``) that the Ollama clients
expect. The chat/generate handlers also serve OpenAI-compatible endpoints
when ``is_openai_compatible(endpoint)`` is true; in that case they
translate the request to the OpenAI Chat Completions / Completions API and
``rechunk`` the response back into Ollama wire format.
"""
import asyncio
import re
import time
from typing import Optional
from urllib.parse import quote
2026-05-19 14:57:39 +02:00
import aiohttp
import ollama
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, Response, StreamingResponse
from cache import get_llm_cache
from config import get_config
from context_window import (
_count_message_tokens,
_trim_messages_for_context,
_calibrated_trim_target,
_endpoint_nctx,
_CTX_TRIM_SMALL_LIMIT,
)
from fingerprint import _conversation_fingerprint
from state import token_queue, default_headers
from backends.health import (
_is_backend_connection_error,
_is_llama_model_loaded,
_is_llama_model_loaded_or_sleeping,
_mark_backend_unhealthy,
)
from backends.normalize import (
dedupe_on_keys,
is_openai_compatible,
2026-06-14 16:34:31 +02:00
is_llama_server,
llama_endpoints,
2026-05-19 14:57:39 +02:00
_normalize_llama_model_name,
_extract_llama_quant,
)
2026-06-14 16:34:31 +02:00
from backends.control import unload_model
2026-05-19 14:57:39 +02:00
from backends.probe import fetch
from backends.sessions import _make_openai_client, get_ollama_client, get_probe_session
2026-05-19 14:57:39 +02:00
from requests.chat import _make_moe_requests
from requests.messages import (
transform_images_to_data_urls,
transform_tool_calls_to_openai,
_strip_assistant_prefill,
_strip_images_from_messages,
_accumulate_openai_tc_delta,
_build_ollama_tool_calls,
)
from requests.rechunk import rechunk
from routing import choose_endpoint, decrement_usage
# FastAPI router that every Ollama-native ``/api/*`` route in this module is
# registered on (presumably included into the application elsewhere).
router = APIRouter()
async def _handle_stream_error(
    exc: Exception, endpoint: str, model: str, *, context: str
) -> bytes:
    """Turn an exception raised mid-stream into a terminal Ollama error line.

    Exceptions raised while iterating a backend response (for example an ollama
    ``ResponseError`` carrying a 504 Gateway Time-out) would otherwise escape the
    ``StreamingResponse`` generator. Starlette would then log them as an opaque
    "Exception in ASGI application" traceback that names neither the endpoint
    nor the model. This helper instead:

    * prints the failure tagged with *context*, endpoint and model, which makes
      timeout errors greppable and analysable;
    * marks the backend unhealthy when ``_is_backend_connection_error`` says the
      failure is connection-class;
    * returns a final ``{"error": ...}`` NDJSON line (plus ``status_code`` when
      the exception carries one), so the client sees a meaningful error rather
      than a silently truncated stream.

    Args:
        exc: The exception raised while streaming.
        endpoint: Backend endpoint that produced the error.
        model: Model name the request was routed to.
        context: Short tag naming the calling route; used in the log line.

    Returns:
        The serialized error payload, newline-terminated.
    """
    code = getattr(exc, "status_code", None)
    # Prefer the backend's own error text (``exc.error``) over the generic str().
    text = str(getattr(exc, "error", None) or str(exc))

    print(
        f"[{context}] upstream error from ({endpoint}, {model}) "
        f"status={code} type={type(exc).__name__}: {text[:500]}",
        flush=True,
    )

    if _is_backend_connection_error(exc):
        await _mark_backend_unhealthy(endpoint, model, text)

    payload = {"error": text}
    if code is not None:
        payload["status_code"] = code
    return orjson.dumps(payload) + b"\n"
async def _guarded_stream(inner, *, endpoint: str, model: str, tracking_model: str, context: str):
    """Wrap a per-route body generator with the shared streaming contract.

    Every ``/api/*`` streaming handler needs the same guarantees around its body:

    * backend errors are surfaced to the client as a final NDJSON error line
      (via :func:`_handle_stream_error`);
    * client-disconnect cancellation (``asyncio.CancelledError``) propagates
      untouched;
    * the usage counter for ``(endpoint, tracking_model)`` is always decremented.

    Centralising these here keeps the route bodies free of duplicated
    ``try/except/finally`` scaffolding.

    Args:
        inner: The route's async generator producing NDJSON byte lines.
        endpoint: Backend endpoint the request was routed to.
        model: Model name, used for error reporting.
        tracking_model: Name under which usage is tracked for *endpoint*.
        context: Short route tag used in log lines.

    Yields:
        Every item from *inner*; on failure, one additional error line.
    """
    try:
        async for item in inner:
            yield item
    except asyncio.CancelledError:
        raise
    except Exception as e:
        try:
            yield await _handle_stream_error(e, endpoint, model, context=context)
        except Exception as report_err:
            # Best effort: the stream is already broken, so don't raise again,
            # but leave a trace instead of silently dropping the failure.
            print(
                f"[{context}] failed to report upstream error from ({endpoint}, {model}): "
                f"{type(report_err).__name__}: {report_err}",
                flush=True,
            )
    finally:
        # Ensure the counter is decremented even if an exception occurs or the
        # client disconnects (it is presumably incremented by choose_endpoint).
        await decrement_usage(endpoint, tracking_model)
@router.post("/api/generate")
async def proxy(request: Request):
    """Proxy an ``/api/generate`` request and stream the reply back to the client.

    The request is routed through ``choose_endpoint``.

    * Ollama-native endpoints are called through the ollama client.
    * OpenAI-compatible endpoints receive an equivalent Completions request. Its
      output is converted back to Ollama wire format with
      ``rechunk.openai_completion2ollama``.

    When the request opts in with ``{"nomyo": {"cache": true}}`` and an LLM
    cache is configured, a cached reply is served before any endpoint is
    chosen. Fresh replies are stored once complete.

    Raises:
        HTTPException: 400 for malformed JSON, a missing ``model``/``prompt``,
            or an ``options`` value that is not a JSON object.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
        model = payload.get("model")
        prompt = payload.get("prompt")
        suffix = payload.get("suffix")
        system = payload.get("system")
        template = payload.get("template")
        context = payload.get("context")
        stream = payload.get("stream")
        think = payload.get("think")
        raw = payload.get("raw")
        _format = payload.get("format")
        images = payload.get("images")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")
        # ``or {}`` also covers an explicit ``"nomyo": null``.
        _cache_enabled = (payload.get("nomyo") or {}).get("cache", False)
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
        raise HTTPException(status_code=400, detail=error_msg) from e

    # Normalise the model name for the cache key. ``model`` itself loses its
    # ":latest" suffix further down for OpenAI-compatible endpoints; without
    # this, get_generate and set_generate would use different keys.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model

    # Cache lookup — before endpoint selection so no slot is wasted on a hit
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_generate(_cache_model, prompt, system or "")
        if _cached is not None:
            async def _serve_cached_generate():
                yield _cached
            return StreamingResponse(_serve_cached_generate(), media_type="application/json")

    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI-compatible backends don't know Ollama's implicit ":latest" tag.
        model = model.split(":latest")[0]
        # Ollama option name -> OpenAI Completions parameter name.
        option_map = (
            ("num_predict", "max_tokens"),
            ("frequency_penalty", "frequency_penalty"),
            ("presence_penalty", "presence_penalty"),
            ("seed", "seed"),
            ("stop", "stop"),
            ("top_p", "top_p"),
            ("temperature", "temperature"),
        )
        opts = options or {}
        optional_params = {"stream": stream, "suffix": suffix}
        optional_params.update({openai_key: opts.get(ollama_key) for ollama_key, openai_key in option_map})
        params = {"prompt": prompt, "model": model}
        # Only forward parameters that were actually set.
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = get_ollama_client(endpoint)

    def _encode(obj) -> bytes:
        """Serialize one response object (pydantic model or plain dict) as an NDJSON line."""
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json().encode("utf-8") + b"\n"
        # orjson.dumps already returns bytes; it must not be .encode()d again.
        return orjson.dumps(obj) + b"\n"

    # Async generator body (error handling + cleanup handled by _guarded_stream)
    async def stream_generate_response():
        start_ts = None
        if use_openai:
            start_ts = time.perf_counter()
            async_gen = await oclient.completions.create(**params)
        else:
            async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)

        if stream == True:  # noqa: E712
            content_parts: list[str] = []
            async for chunk in async_gen:
                if use_openai:
                    chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                prompt_tok = chunk.prompt_eval_count or 0
                comp_tok = chunk.eval_count or 0
                if prompt_tok or comp_tok:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                line = _encode(chunk)
                # Accumulate and store cache on done chunk — before yield so it always runs
                if _cache is not None and _cache_enabled:
                    if getattr(chunk, "response", None):
                        content_parts.append(chunk.response)
                    if getattr(chunk, "done", False):
                        assembled = orjson.dumps({
                            k: v for k, v in {
                                "model": getattr(chunk, "model", model),
                                "response": "".join(content_parts),
                                "done": True,
                                "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                "eval_count": getattr(chunk, "eval_count", None),
                                "total_duration": getattr(chunk, "total_duration", None),
                                "eval_duration": getattr(chunk, "eval_duration", None),
                            }.items() if v is not None
                        }) + b"\n"
                        try:
                            await _cache.set_generate(_cache_model, prompt, system or "", assembled)
                        except Exception as _ce:
                            print(f"[cache] set_generate (streaming) failed: {_ce}")
                yield line
        else:
            # Convert OpenAI responses first. prompt_eval_count/eval_count are
            # read from the Ollama-format object; the raw OpenAI Completion
            # reports usage differently and has no such attributes.
            result = rechunk.openai_completion2ollama(async_gen, stream, start_ts) if use_openai else async_gen
            prompt_tok = result.prompt_eval_count or 0
            comp_tok = result.eval_count or 0
            if prompt_tok or comp_tok:
                await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
            cache_bytes = _encode(result)
            yield cache_bytes
            # Cache non-streaming response
            if _cache is not None and _cache_enabled:
                try:
                    await _cache.set_generate(_cache_model, prompt, system or "", cache_bytes)
                except Exception as _ce:
                    print(f"[cache] set_generate (non-streaming) failed: {_ce}")

    # Return a StreamingResponse backed by the guarded generator
    return StreamingResponse(
        _guarded_stream(
            stream_generate_response(),
            endpoint=endpoint, model=model,
            tracking_model=tracking_model, context="generate_proxy",
        ),
        media_type="application/json",
    )
def _ollama_format_to_response_format(_format):
    """Translate an Ollama ``format`` value into an OpenAI ``response_format``.

    Ollama's ``format`` is either the string ``"json"`` (generic JSON mode) or a
    JSON-schema object.

    * ``"json"`` maps to ``{"type": "json_object"}``. Wrapping the literal string
      as a ``json_schema`` payload is not a valid schema.
    * Any other value is forwarded unchanged under ``json_schema``.

    Returns:
        The ``response_format`` dict, or ``None`` when *_format* is ``None``.
    """
    if _format is None:
        return None
    if _format == "json":
        return {"type": "json_object"}
    return {"type": "json_schema", "json_schema": _format}


async def _create_openai_chat_completion(
    oclient, params: dict, *, endpoint: str, model: str, tracking_model: str, nctx_key: tuple
):
    """Call ``chat.completions.create``, recovering from known backend failures.

    Runs in handler scope rather than inside the streaming generator, because
    try/except inside async generators is unreliable with Starlette's streaming.
    Handled failures:

    * Context size exceeded. The backend-reported ``n_ctx`` comes from the error
      body, or is parsed out of the error text. Small windows are remembered in
      ``_endpoint_nctx[nctx_key]``. The oldest messages are trimmed to a
      calibrated target and the call is retried. If it still does not fit, it is
      retried once more without ``tools``/``tool_choice``.
    * Backend connection errors. The backend is marked unhealthy.
    * "image input is not supported". The call is retried once with images
      stripped.

    Args:
        oclient: OpenAI async client for *endpoint*.
        params: Keyword arguments for ``chat.completions.create``.
        endpoint: Backend endpoint the request was routed to.
        model: Model name sent to the backend (used in log lines).
        tracking_model: Name under which usage is tracked.
        nctx_key: Key under which a learned ``n_ctx`` is stored.

    Returns:
        Whatever ``chat.completions.create`` returns (stream or completion).

    Raises:
        Exception: The backend error, when no recovery applies or a retry also
            fails. ``decrement_usage`` runs before every such raise, because the
            caller never reaches the guarded stream that would otherwise do it.
    """
    def _exceeds_context(text: str) -> bool:
        return "exceed_context_size_error" in text or "exceeds the available context size" in text

    try:
        return await oclient.chat.completions.create(**params)
    except Exception as e:
        _e_str = str(e)
        print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}")

        if _exceeds_context(_e_str):
            err_body = getattr(e, "body", {}) or {}
            err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
            if not isinstance(err_detail, dict):
                # Some backends put a plain message string under "error".
                err_detail = {}
            n_ctx_limit = err_detail.get("n_ctx", 0)
            actual_tokens = err_detail.get("n_prompt_tokens", 0)
            # Fall back to scraping the numbers out of the stringified error.
            if not n_ctx_limit:
                _m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
                if _m:
                    n_ctx_limit = int(_m.group(1))
            if not actual_tokens:
                _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                if _m:
                    actual_tokens = int(_m.group(1))
            if not n_ctx_limit:
                await decrement_usage(endpoint, tracking_model)
                raise
            if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                # Remember small windows so later requests are trimmed proactively.
                _endpoint_nctx[nctx_key] = n_ctx_limit
            msgs_to_trim = params.get("messages", [])
            cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
            trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
            print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
            try:
                return await oclient.chat.completions.create(**{**params, "messages": trimmed})
            except Exception as e2:
                if not _exceeds_context(str(e2)):
                    await decrement_usage(endpoint, tracking_model)
                    raise
                print("[chat_proxy] Context still exceeded after trimming messages, also stripping tools")
                params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                try:
                    return await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
                except Exception:
                    await decrement_usage(endpoint, tracking_model)
                    raise

        if _is_backend_connection_error(e):
            print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
            await _mark_backend_unhealthy(endpoint, model, _e_str)
            await decrement_usage(endpoint, tracking_model)
            raise

        if "image input is not supported" in _e_str:
            print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
            try:
                text_only = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
                return await oclient.chat.completions.create(**text_only)
            except Exception:
                await decrement_usage(endpoint, tracking_model)
                raise

        await decrement_usage(endpoint, tracking_model)
        raise


@router.post("/api/chat")
async def chat_proxy(request: Request):
    """Proxy an ``/api/chat`` request and stream the endpoint's reply.

    Flow:

    1. Validate the body.
    2. Serve from the LLM cache when the request opts in with
       ``{"nomyo": {"cache": true}}`` (never for ``moe-`` models).
    3. Pick an endpoint via ``choose_endpoint``, using a conversation-fingerprint
       affinity key.
    4. Call the backend:
       * ``moe-`` models go through ``_make_moe_requests``;
       * other Ollama-native endpoints use the ollama client;
       * OpenAI-compatible endpoints get a translated Chat Completions request.
         Its output is converted back with
         ``rechunk.openai_chat_completion2ollama``.

    Raises:
        HTTPException: 400 for malformed JSON, a missing ``model``, a non-list
            ``messages`` or a non-object ``options``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
        model = payload.get("model")
        messages = payload.get("messages")
        tools = payload.get("tools")
        stream = payload.get("stream")
        think = payload.get("think")
        _format = payload.get("format")
        keep_alive = payload.get("keep_alive")
        options = payload.get("options")
        logprobs = payload.get("logprobs")
        top_logprobs = payload.get("top_logprobs")
        # ``or {}`` also covers an explicit ``"nomyo": null``.
        _cache_enabled = (payload.get("nomyo") or {}).get("cache", False)
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(messages, list):
            raise HTTPException(
                status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Cache lookup — before endpoint selection, always bypassed for MOE
    _is_moe = model.startswith("moe-")
    _cache = get_llm_cache()
    # Normalise the model name for the cache key so get_chat and set_chat agree,
    # even though ":latest" is stripped from ``model`` later for OpenAI endpoints.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    # Snapshot the original messages before any OpenAI-format transformation so
    # get_chat and set_chat use the same key regardless of backend type.
    _cache_messages = messages
    if _cache is not None and not _is_moe and _cache_enabled:
        _cached = await _cache.get_chat("ollama_chat", _cache_model, messages)
        if _cached is not None:
            async def _serve_cached_chat():
                yield _cached
            return StreamingResponse(
                _serve_cached_chat(),
                media_type="application/x-ndjson" if stream else "application/json",
            )

    # 2. Endpoint logic
    if _is_moe:
        # Drop only the leading "moe-" marker; split("moe-")[1] would also cut
        # at any later "moe-" inside the model name.
        model = model[len("moe-"):]
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI-compatible backends don't know Ollama's implicit ":latest" tag.
        model = model.split(":latest")[0]
        if messages:
            if any("images" in m for m in messages):
                messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
            messages = transform_tool_calls_to_openai(messages)
            messages = _strip_assistant_prefill(messages)
        opts = options or {}
        params = {
            "messages": messages,
            "model": model,
        }
        optional_params = {
            "tools": tools,
            "stream": stream,
            "stream_options": {"include_usage": True} if stream else None,
            "max_tokens": opts.get("num_predict"),
            "frequency_penalty": opts.get("frequency_penalty"),
            "presence_penalty": opts.get("presence_penalty"),
            "seed": opts.get("seed"),
            "stop": opts.get("stop"),
            "top_p": opts.get("top_p"),
            "temperature": opts.get("temperature"),
            # Top-level request fields win over the same keys inside ``options``.
            "logprobs": logprobs if logprobs is not None else opts.get("logprobs"),
            "top_logprobs": top_logprobs if top_logprobs is not None else opts.get("top_logprobs"),
            "response_format": _ollama_format_to_response_format(_format),
        }
        # Only forward parameters that were actually set.
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = get_ollama_client(endpoint)

    # Key for the learned-n_ctx cache. llama-server names are normalised, so
    # the proactive lookup below and every write (context-error retry,
    # done_reason=length) use the same key.
    _nctx_key = (endpoint, _normalize_llama_model_name(model) if is_llama_server(endpoint) else model)

    # For OpenAI endpoints: make the API call in handler scope
    # (try/except inside async generators is unreliable with Starlette's streaming)
    start_ts = None
    async_gen = None
    if use_openai:
        start_ts = time.perf_counter()
        # Proactive trim: only for small-ctx models we've already seen run out of space
        _known_nctx = _endpoint_nctx.get(_nctx_key)
        if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
            # ~3/4 of the window, divided by 1.2 (presumably headroom for the
            # token estimate's error).
            _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
            _pre_msgs = params.get("messages", [])
            _pre_est = _count_message_tokens(_pre_msgs)
            if _pre_est > _pre_target:
                _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                _dropped = len(_pre_msgs) - len(_pre_trimmed)
                print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                params = {**params, "messages": _pre_trimmed}
        async_gen = await _create_openai_chat_completion(
            oclient, params,
            endpoint=endpoint, model=model,
            tracking_model=tracking_model, nctx_key=_nctx_key,
        )

    def _encode(obj) -> bytes:
        """Serialize one response object (pydantic model or plain dict) as an NDJSON line."""
        if hasattr(obj, "model_dump_json"):
            return obj.model_dump_json().encode("utf-8") + b"\n"
        # orjson.dumps already returns bytes; it must not be .encode()d again.
        return orjson.dumps(obj) + b"\n"

    # 3. Async generator body (error handling + cleanup handled by _guarded_stream)
    async def stream_chat_response():
        if use_openai:
            _async_gen = async_gen  # established in handler scope above
        elif _is_moe:
            # Use the dedicated MOE helper function
            _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
        else:
            _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)

        _cache_active = _cache is not None and not _is_moe and _cache_enabled

        if stream == True:  # noqa: E712
            tc_acc = {}  # accumulate OpenAI tool-call deltas across chunks
            content_parts: list[str] = []
            has_tool_calls = False
            async for chunk in _async_gen:
                if use_openai:
                    _accumulate_openai_tc_delta(chunk, tc_acc)
                    chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                    # Inject fully-accumulated tool calls only into the final chunk
                    if chunk.done and tc_acc and chunk.message:
                        chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
                prompt_tok = chunk.prompt_eval_count or 0
                comp_tok = chunk.eval_count or 0
                if prompt_tok or comp_tok:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                line = _encode(chunk)
                is_done = getattr(chunk, "done", False)

                if is_done:
                    # Detect context exhaustion mid-generation for small-ctx models.
                    # Only trust done_reason=length when no max_tokens limit was
                    # set; otherwise it may just mean that limit was hit.
                    _req_max_tok = (
                        params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
                        if use_openai else
                        (options.get("num_predict") if options else None)
                    )
                    if getattr(chunk, "done_reason", None) == "length" and not _req_max_tok:
                        _pt = getattr(chunk, "prompt_eval_count", 0) or 0
                        _ct = getattr(chunk, "eval_count", 0) or 0
                        _inferred_nctx = _pt + _ct
                        if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                            _endpoint_nctx[_nctx_key] = _inferred_nctx
                            print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)

                # Accumulate and store cache on done chunk — before yield so it always runs.
                # Chunks are already in Ollama format here (rechunked for OpenAI backends).
                if _cache_active:
                    msg = getattr(chunk, "message", None)
                    if msg:
                        if getattr(msg, "content", None):
                            content_parts.append(msg.content)
                        if getattr(msg, "tool_calls", None):
                            has_tool_calls = True
                    # The assembled entry carries only text content, so a turn
                    # that produced tool calls would be replayed without them:
                    # don't cache it.
                    if is_done and not has_tool_calls:
                        created_at = getattr(chunk, "created_at", None)
                        if hasattr(created_at, "isoformat"):
                            created_at = created_at.isoformat()
                        assembled = orjson.dumps({
                            k: v for k, v in {
                                "model": getattr(chunk, "model", model),
                                "created_at": created_at,
                                "message": {"role": "assistant", "content": "".join(content_parts)},
                                "done": True,
                                "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                "eval_count": getattr(chunk, "eval_count", None),
                                "total_duration": getattr(chunk, "total_duration", None),
                                "eval_duration": getattr(chunk, "eval_duration", None),
                            }.items() if v is not None
                        }) + b"\n"
                        try:
                            await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
                        except Exception as _ce:
                            print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
                yield line
        else:
            # Convert OpenAI completions first. prompt_eval_count/eval_count are
            # read from the Ollama-format object; the raw OpenAI ChatCompletion
            # reports usage differently and has no such attributes.
            result = (
                rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
                if use_openai else _async_gen
            )
            prompt_tok = result.prompt_eval_count or 0
            comp_tok = result.eval_count or 0
            if prompt_tok or comp_tok:
                await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
            cache_bytes = _encode(result)
            yield cache_bytes
            # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends)
            if _cache_active:
                try:
                    await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
                except Exception as _ce:
                    print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")

    # 4. Return a StreamingResponse backed by the guarded generator
    media_type = "application/x-ndjson" if stream else "application/json"
    return StreamingResponse(
        _guarded_stream(
            stream_chat_response(),
            endpoint=endpoint, model=model,
            tracking_model=tracking_model, context="chat_proxy",
        ),
        media_type=media_type,
    )
async def _handle_embedding_request(
    request: Request,
    *,
    input_field: str,
    context: str,
    make_native,
    make_openai,
):
    """Shared implementation for ``/api/embeddings`` and ``/api/embed``.

    The two routes differ only in three things:

    * the request field they read (``prompt`` vs ``input``);
    * the ollama SDK method they call;
    * the OpenAI rechunk helper.

    Those are passed in via *input_field* and the *make_native* /
    *make_openai* callables. Everything else is shared: parsing, endpoint
    selection, serialization, and the streaming error/cleanup contract.

    Args:
        request: Incoming FastAPI request.
        input_field: Body field holding the text(s) to embed.
        context: Route tag used in error logs.
        make_native: ``async (client, model, value, options, keep_alive, truncate)``
            calling the ollama client.
        make_openai: ``async (client, model, value)`` calling an OpenAI client and
            converting the result to Ollama format.

    Raises:
        HTTPException: 400 on malformed JSON or a missing ``model``/input field.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
        model = payload.get("model")
        value = payload.get(input_field)
        truncate = payload.get("truncate")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not value:
            raise HTTPException(
                status_code=400, detail=f"Missing required field '{input_field}'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI-compatible backends don't know Ollama's implicit ":latest" tag.
        model = model.split(":latest")[0]
        client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = get_ollama_client(endpoint)

    # 3. Async generator body (error handling + cleanup handled by _guarded_stream)
    async def stream_embedding_response():
        if use_openai:
            response = await make_openai(client, model, value)
        else:
            response = await make_native(client, model, value, options, keep_alive, truncate)
        if hasattr(response, "model_dump_json"):
            yield response.model_dump_json().encode("utf-8") + b"\n"
        else:
            # orjson.dumps already returns bytes; it must not be .encode()d again.
            yield orjson.dumps(response) + b"\n"

    # 4. Return a StreamingResponse backed by the guarded generator
    return StreamingResponse(
        _guarded_stream(
            stream_embedding_response(),
            endpoint=endpoint, model=model,
            tracking_model=tracking_model, context=context,
        ),
        media_type="application/json",
    )
@router.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """Serve the legacy ``/api/embeddings`` route, which reads the ``prompt`` field.

    Parsing, backend dispatch, error reporting and usage accounting are all
    handled by ``_handle_embedding_request``; this route only supplies the two
    backend-specific calls.
    """

    async def _via_ollama(client, model, value, options, keep_alive, truncate):
        # The legacy ollama ``embeddings`` call is not passed ``truncate``.
        return await client.embeddings(model=model, prompt=value, options=options, keep_alive=keep_alive)

    async def _via_openai(client, model, value):
        raw_response = await client.embeddings.create(input=value, model=model)
        return rechunk.openai_embeddings2ollama(raw_response)

    return await _handle_embedding_request(
        request,
        input_field="prompt",
        context="embeddings_proxy",
        make_native=_via_ollama,
        make_openai=_via_openai,
    )
@router.post("/api/embed")
async def embed_proxy(request: Request):
    """Serve the ``/api/embed`` route, which reads the ``input`` field.

    Parsing, backend dispatch, error reporting and usage accounting are all
    handled by ``_handle_embedding_request``; this route only supplies the two
    backend-specific calls.
    """

    async def _via_ollama(client, model, value, options, keep_alive, truncate):
        return await client.embed(
            model=model, input=value, truncate=truncate, options=options, keep_alive=keep_alive
        )

    async def _via_openai(client, model, value):
        raw_response = await client.embeddings.create(input=value, model=model)
        # Unlike the legacy route's converter, this one also takes the model name.
        return rechunk.openai_embed2ollama(raw_response, model)

    return await _handle_embedding_request(
        request,
        input_field="input",
        context="embed_proxy",
        make_native=_via_ollama,
        make_openai=_via_openai,
    )
@router.post("/api/create")
async def create_proxy(request: Request):
    """Proxy a model-create request to every Ollama-native endpoint.

    Endpoints whose URL contains ``/v1`` (OpenAI-compatible) are skipped, the
    same way the copy/delete/pull routes skip them. Each endpoint is called in
    turn with ``stream=False``. The field/value pairs of all replies are
    de-duplicated in first-seen order and folded into one dict.

    Raises:
        HTTPException: 400 on malformed JSON, a missing ``model``, or when
            neither ``from`` nor ``files`` is given.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
        model = payload.get("model")
        quantize = payload.get("quantize")
        from_ = payload.get("from")
        files = payload.get("files")
        adapters = payload.get("adapters")
        template = payload.get("template")
        # Trailing underscore avoids shadowing the ``license`` builtin.
        license_ = payload.get("license")
        system = payload.get("system")
        parameters = payload.get("parameters")
        messages = payload.get("messages")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not from_ and not files:
            # Name the JSON field ("from"), not the Python variable (from_).
            raise HTTPException(
                status_code=400, detail="You need to provide either the 'from' or 'files' parameter!"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    status_lists = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license_, system=system, parameters=parameters, messages=messages, stream=False)
        status_lists.append(create)

    # Each reply iterates as (field, value) pairs; dedupe them, then build a dict.
    combined_status = [pair for status_list in status_lists for pair in status_list]
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
@router.post("/api/show")
async def show_proxy(request: Request, model: Optional[str] = None):
    """Return the ollama ``show`` result for a model from one selected endpoint.

    The model name is taken from the ``model`` query parameter when present;
    otherwise it is read from the JSON body. The endpoint is chosen with
    ``reserve=False`` (presumably so this metadata call takes no usage slot;
    nothing is decremented afterwards).

    Raises:
        HTTPException: 400 on malformed JSON or a missing model name.
    """
    try:
        raw_body = await request.body()
        if not model:
            model = orjson.loads(raw_body.decode("utf-8")).get("model")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    chosen_endpoint, _ = await choose_endpoint(model, reserve=False)
    ollama_client = get_ollama_client(chosen_endpoint)
    return await ollama_client.show(model=model)
@router.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """Copy a model on every Ollama-native endpoint.

    ``source``/``destination`` come from the query parameters. The JSON body is
    read only when neither is given. Endpoints whose URL contains ``/v1``
    (OpenAI-compatible) are skipped.

    Returns:
        200 when every endpoint copied the model; 404 when at least one failed.
        A failure is either an error reply that the ollama client raised as
        ``ResponseError``, or an ``"error"`` status.

    Raises:
        HTTPException: 400 on malformed JSON or a missing source/destination.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not source and not destination:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            src = payload.get("source")
            dst = payload.get("destination")
        else:
            src = source
            dst = destination
        if not src:
            raise HTTPException(
                status_code=400, detail="Missing required field 'source'"
            )
        if not dst:
            raise HTTPException(
                status_code=400, detail="Missing required field 'destination'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Copy the model on each Ollama-native endpoint
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        try:
            result = await client.copy(source=src, destination=dst)
        except ollama.ResponseError as err:
            # The ollama client raises on non-2xx replies (e.g. 404 when the
            # source model is missing); the status is a string, never 404.
            print(f"[copy_proxy] copy {src!r} -> {dst!r} failed on {endpoint}: status={err.status_code} {err.error}", flush=True)
            any_failed = True
            continue
        if getattr(result, "status", None) == "error":
            any_failed = True

    # 3. Return 200 OK if all went well, 404 if a single endpoint failed
    return Response(status_code=404 if any_failed else 200)
@router.delete("/api/delete")
async def delete_proxy(request: Request, model: Optional[str] = None):
    """Delete a model on every Ollama-native endpoint.

    The model name comes from the ``model`` query parameter, or from the JSON
    body when absent. Endpoints whose URL contains ``/v1`` (OpenAI-compatible)
    are skipped.

    Returns:
        200 when every endpoint deleted the model; 404 when at least one failed.
        A failure is either an error reply that the ollama client raised as
        ``ResponseError`` (e.g. model not found there), or an ``"error"`` status.

    Raises:
        HTTPException: 400 on malformed JSON or a missing model name.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            model = payload.get("model")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Delete the model on each Ollama-native endpoint
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        try:
            result = await client.delete(model=model)
        except ollama.ResponseError as err:
            # The ollama client raises on non-2xx replies; the status is a
            # string, never the integer 404.
            print(f"[delete_proxy] delete of {model!r} failed on {endpoint}: status={err.status_code} {err.error}", flush=True)
            any_failed = True
            continue
        if getattr(result, "status", None) == "error":
            any_failed = True

    # 3. Return 200 OK; if a single endpoint fails, respond with 404
    return Response(status_code=404 if any_failed else 200)
@router.post("/api/pull")
async def pull_proxy(request: Request, model: Optional[str] = None):
    """Pull a model onto every Ollama-native endpoint and merge the replies.

    The model may be given as the ``model`` query parameter. Otherwise the JSON
    body's ``model`` (and optional ``insecure`` flag) is used. Endpoints whose
    URL contains ``/v1`` are skipped. Each endpoint is pulled in turn with
    ``stream=False``. The field/value pairs of all replies are de-duplicated in
    first-seen order and folded into one dict.

    Raises:
        HTTPException: 400 on malformed JSON or a missing model name.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        insecure = None
        if not model:
            body = orjson.loads(body_bytes.decode("utf-8"))
            model = body.get("model")
            insecure = body.get("insecure")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    replies = []
    for ep in config.endpoints:
        if "/v1" in ep:
            continue
        ollama_client = get_ollama_client(ep)
        replies.append(await ollama_client.pull(model=model, insecure=insecure, stream=False))

    # Each reply iterates as (field, value) pairs.
    pairs = [pair for reply in replies for pair in reply]
    unique_pairs = list(dict.fromkeys(pairs))
    return dict(unique_pairs)
@router.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to every Ollama endpoint and reply with the merged, deduplicated status.

    Endpoints whose URL contains ``/v1`` (OpenAI-compatible) are skipped, as in
    the pull and delete proxies. They presumably have no Ollama push API, and
    since the loop has no error handling, a failure there would abort the whole
    request.

    Raises:
        HTTPException(400): the body is not valid UTF-8 JSON, is not a JSON
            object, or lacks ``model``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        # orjson accepts bytes directly and reports invalid UTF-8 as a
        # JSONDecodeError (-> 400) instead of a UnicodeDecodeError (-> 500).
        payload = orjson.loads(body_bytes)
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )
        model = payload.get("model")
        insecure = payload.get("insecure")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    # 2. Iterate over all Ollama endpoints
    status_list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        # 3. Proxy a simple (non-streaming) push request
        push = await client.push(model=model, insecure=insecure, stream=False)
        status_list.append(push)
    # 4. Merge the per-endpoint replies. Iterating each reply must yield
    #    (field, value) pairs (dict() below requires it); identical pairs are
    #    dropped first, then the remaining pairs are folded into one dict.
    combined_status = []
    for status in status_list:
        combined_status += status
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
@router.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply with the lowest version of all endpoints.

    The minimum is reported to maintain compatibility across all endpoints.

    Raises:
        HTTPException(503): no Ollama endpoint returned a usable version string.
    """
    config = get_config()
    # 1. Query all Ollama endpoints for their version
    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
    all_versions_raw = await asyncio.gather(*tasks)
    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]
    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v: str) -> tuple[tuple[int, ...], int]:
        """Sort key that tolerates pre-release/build suffixes such as ``0.6.8-rc0``.

        Calling ``int()`` on each dot-separated part would raise ValueError for
        such strings. Instead, each part contributes its leading digits (0 when
        it has none). A version carrying a ``-``/``+`` suffix sorts just below
        the matching plain release.
        """
        text = v.strip().lstrip("vV")
        core = re.split(r"[-+]", text, maxsplit=1)[0]
        numbers = []
        for part in core.split("."):
            digits = re.match(r"\d+", part)
            numbers.append(int(digits.group()) if digits else 0)
        return tuple(numbers), (0 if core != text else 1)

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
@router.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to all endpoints and reply with a unique list of all models.

    Ollama endpoints are asked for ``/api/tags``. OpenAI-compatible endpoints
    (``/v1`` in the URL) and llama-server / llama-swap endpoints not already in
    ``config.endpoints`` are asked for ``/models``. OpenAI-style entries are
    relabelled so every entry carries Ollama's ``model`` and ``name`` keys.
    """
    config = get_config()
    # 1. Query all endpoints for models
    tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
    # .get(): a /v1 endpoint without a configured API key must not raise KeyError
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" in ep]
    # Also query llama-server / llama-swap endpoints not already covered by config.endpoints
    llama_eps_for_tags = [ep for ep in llama_endpoints(config) if ep not in config.endpoints]
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps_for_tags]
    all_models = await asyncio.gather(*tasks)
    models: list[dict] = []
    for modellist in all_models:
        for model in modellist:
            # Skip malformed entries instead of failing the whole listing
            if not isinstance(model, dict):
                continue
            if "model" not in model:
                # Relabel OpenAI models with Ollama Model.model from Model.id
                model["model"] = model["id"] + ":latest"
            else:
                model["id"] = model["model"]
            if "name" not in model:
                # Relabel OpenAI models with Ollama Model.name from Model.model
                model["name"] = model["model"]
            else:
                model["id"] = model["model"]
            models.append(model)
    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(models, ['digest', 'name', 'id'])},
        status_code=200,
    )
2026-06-14 16:34:31 +02:00
async def _fetch_llama_swap_running(endpoint: str) -> list[dict]:
    """Fetch the worker list from a llama-swap endpoint's ``/running`` route.

    llama-swap omits the per-model ``status`` field on ``/v1/models``, so
    running workers have to be read from ``/running`` instead. A trailing
    ``/v1`` is stripped from *endpoint* because ``/running`` lives at the
    server root. This function applies no ``state`` filtering itself; callers
    keep only the entries whose ``state`` is ``"ready"``.
    """
    api_key = get_config().api_keys.get(endpoint)
    server_root = endpoint.rstrip("/").removesuffix("/v1")
    return await fetch.endpoint_details(
        server_root,
        "/running",
        "running",
        api_key,
        skip_error_cache=True,
        timeout=8,
    )
# Match the context size in a llama-swap worker's `cmd` string, e.g.
# "llama-server --port 5818 -hf ... --ctx-size 131072 ...". llama.cpp accepts
# both --ctx-size and the short -c alias.
_CTX_SIZE_CMD_RE = re.compile(r"(?:--ctx-size|-c)[=\s]+(\d+)")
def _ctx_size_from_cmd(cmd: str) -> int | None:
"""Extract n_ctx from a llama-swap worker `cmd` string, or None if absent."""
if not cmd:
return None
m = _CTX_SIZE_CMD_RE.search(cmd)
return int(m.group(1)) if m else None
async def _fetch_llama_swap_nctx(endpoint: str, model_id: str) -> int | None:
    """Ask llama-swap's upstream llama-server for a worker's n_ctx.

    Fallback used when a worker's `cmd` lacks --ctx-size. The request goes
    through llama-swap's ``/upstream/<model>/props`` route, because a plain
    ``/props?model=`` is not routed by llama-swap and 404s. The model id is
    percent-encoded into the path.

    Returns ``default_generation_settings.n_ctx`` from the props reply, or None
    on a non-200 status or any exception (which is logged).
    """
    cfg = get_config()
    server_root = endpoint.rstrip("/").removesuffix("/v1")
    props_url = f"{server_root}/upstream/{quote(model_id, safe='')}/props"
    token = cfg.api_keys.get(endpoint)
    auth_headers = {"Authorization": f"Bearer {token}"} if token else None
    probe_timeout = aiohttp.ClientTimeout(total=5)
    try:
        session: aiohttp.ClientSession = get_probe_session(endpoint)
        async with session.get(props_url, headers=auth_headers, timeout=probe_timeout) as resp:
            if resp.status != 200:
                return None
            props = await resp.json()
            return props.get("default_generation_settings", {}).get("n_ctx")
    except Exception as e:
        print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
    return None
2026-05-19 14:57:39 +02:00
@router.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama, llama-server and llama-swap endpoints and reply a unique list of all running models.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models with status.value == "loaded"
    For llama-swap endpoints: queries /running (state == "ready")

    The three backend groups are queried concurrently, so one slow group no
    longer delays the start of the others.
    """
    config = get_config()
    # 1. Query Ollama endpoints for running models via /api/ps
    ollama_tasks = [fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
    # 2. Query llama-server endpoints for loaded models via /v1/models
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in config.llama_server_endpoints
    ]
    # 3. Query llama-swap endpoints for running workers via /running
    swap_tasks = [_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints]
    # 4. Run all three groups concurrently. gather() with no awaitables resolves
    #    to [], so an empty group needs no special case.
    ollama_loaded, llama_loaded, swap_running = await asyncio.gather(
        asyncio.gather(*ollama_tasks),
        asyncio.gather(*llama_tasks),
        asyncio.gather(*swap_tasks),
    )
    models: list[dict] = []
    # Add Ollama models as returned
    for modellist in ollama_loaded:
        models += modellist
    # Add llama-server models (loaded only), converted to an Ollama-like shape
    for modellist in llama_loaded:
        for item in modellist:
            if not isinstance(item, dict) or not _is_llama_model_loaded(item):
                continue
            raw_id = item.get("id", "")
            if not raw_id:
                continue
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            models.append({
                "name": normalized,
                "id": normalized,
                "digest": "",
                "status": item.get("status"),
                "details": {"quantization_level": quant} if quant else {},
            })
    # Add llama-swap running workers that are in state "ready"
    for runlist in swap_running:
        for item in runlist:
            if not isinstance(item, dict) or item.get("state") != "ready":
                continue
            raw_id = item.get("model", "")
            if not raw_id:
                continue
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            models.append({
                "name": normalized,
                "id": normalized,
                "digest": "",
                "details": {"quantization_level": quant} if quant else {},
            })
    # 5. Return a JSONResponse with deduplicated currently deployed models.
    # Deduplicate on 'name' rather than 'digest': llama-server models always
    # have digest="" so deduping on digest collapses all of them to one entry.
    return JSONResponse(
        content={"models": dedupe_on_keys(models, ['name'])},
        status_code=200,
    )
async def _ps_details_ollama(config) -> list[dict]:
    """Return every Ollama endpoint's /api/ps models, each copied and tagged with its ``endpoint``."""
    endpoints = [ep for ep in config.endpoints if "/v1" not in ep]
    if not endpoints:
        return []
    listings = await asyncio.gather(
        *[fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8) for ep in endpoints]
    )
    return [
        {**model, "endpoint": endpoint}
        for endpoint, modellist in zip(endpoints, listings)
        for model in modellist
        if isinstance(model, dict)
    ]


async def _ps_details_llama_props(config, endpoint: str, model_id: str) -> tuple[int | None, bool, bool]:
    """Probe llama-server ``/props`` for *model_id* and unload the model if it is sleeping.

    Returns ``(n_ctx, is_sleeping, is_generation)``. Returns ``(None, False, False)``
    on a non-200 reply or when fetching/parsing the props fails (logged).
    ``is_generation`` is False when ``default_generation_settings`` has no
    ``temperature`` (embedding models carry no sampling params).
    """
    client: aiohttp.ClientSession = get_probe_session(endpoint)
    base_url = endpoint.rstrip("/").removesuffix("/v1")
    # Percent-encode the id: characters such as '+', '&', '#' or spaces in a model
    # name would otherwise corrupt the query string.
    props_url = f"{base_url}/props?model={quote(model_id, safe='')}"
    api_key = config.api_keys.get(endpoint)
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
    try:
        async with client.get(props_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp:
            if resp.status != 200:
                return None, False, False
            data = await resp.json()
        dgs = data.get("default_generation_settings", {})
        n_ctx = dgs.get("n_ctx")
        is_sleeping = data.get("is_sleeping", False)
        is_generation = "temperature" in dgs
    except Exception as e:
        print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
        return None, False, False
    if is_sleeping:
        # Unload in its own try: a failed unload must not discard the props already
        # read. Otherwise the sleeping model would be reported as running.
        try:
            await unload_model(endpoint, model_id)
        except Exception as e:
            print(f"[ps_details] Failed to unload sleeping model {model_id} on {endpoint}: {e}")
    return n_ctx, is_sleeping, is_generation


async def _ps_details_llama_server(config) -> list[dict]:
    """Return loaded llama-server models with endpoint, status and context metadata.

    Sleeping models are probed too, so the props probe can unload them, but they
    are left out of the result. Generation models whose n_ctx is within
    ``_CTX_TRIM_SMALL_LIMIT`` are recorded in ``_endpoint_nctx``.
    """
    # dict.fromkeys de-duplicates while keeping config order. A set would make the
    # response order vary between runs.
    endpoints = list(dict.fromkeys(config.llama_server_endpoints))
    if not endpoints:
        return []
    listings = await asyncio.gather(
        *[
            fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
            for ep in endpoints
        ]
    )
    pending: list[tuple[str, str, dict]] = []
    for endpoint, modellist in zip(endpoints, listings):
        for item in modellist:
            if not isinstance(item, dict) or not item.get("id"):
                continue
            # Include sleeping models too so their props probe can unload them
            if not _is_llama_model_loaded_or_sleeping(item):
                continue
            raw_id = item["id"]
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            model_with_endpoint = {
                "name": normalized,
                "id": normalized,
                "original_name": raw_id,
                "digest": "",
                "details": {"quantization_level": quant} if quant else {},
                "endpoint": endpoint,
                "status": item.get("status"),
                "created": item.get("created"),
                "owned_by": item.get("owned_by"),
            }
            # Include full llama-server status details (args, preset)
            status_info = item.get("status", {})
            if isinstance(status_info, dict):
                model_with_endpoint["llama_status_args"] = status_info.get("args")
                model_with_endpoint["llama_status_preset"] = status_info.get("preset")
            pending.append((endpoint, raw_id, model_with_endpoint))
    # Fetch /props for each model (in parallel) to get n_ctx and unload sleepers
    props_results = await asyncio.gather(
        *[_ps_details_llama_props(config, ep, raw_id) for ep, raw_id, _ in pending]
    )
    models: list[dict] = []
    for (ep, _raw_id, model_dict), (n_ctx, is_sleeping, is_generation) in zip(pending, props_results):
        if n_ctx is not None:
            model_dict["context_length"] = n_ctx
            if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
                normalized = model_dict["id"]
                _endpoint_nctx[(ep, normalized)] = n_ctx
                print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
        if not is_sleeping:
            models.append(model_dict)
    return models


async def _ps_details_llama_swap(config) -> list[dict]:
    """Return ready llama-swap workers with endpoint, TTL and context metadata.

    Workers are read from /running. llama-swap omits the status field on
    /v1/models and manages its own TTL eviction, so no auto-unload happens here.
    n_ctx is parsed from each worker's launch ``cmd``. Workers whose cmd yields
    none fall back to an /upstream/<model>/props probe.
    """
    endpoints = list(config.llama_swap_endpoints)
    if not endpoints:
        return []
    swap_running = await asyncio.gather(*[_fetch_llama_swap_running(ep) for ep in endpoints])
    models: list[dict] = []
    swap_nctx_fallbacks: list[tuple[str, str, dict]] = []
    for endpoint, runlist in zip(endpoints, swap_running):
        for item in runlist:
            if not isinstance(item, dict) or item.get("state") != "ready":
                continue
            raw_id = item.get("model", "")
            if not raw_id:
                continue
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            swap_model = {
                "name": normalized,
                "id": normalized,
                "original_name": raw_id,
                "digest": "",
                "details": {"quantization_level": quant} if quant else {},
                "endpoint": endpoint,
                "state": item.get("state"),
                "ttl": item.get("ttl"),
                "proxy": item.get("proxy"),
            }
            # llama-swap omits n_ctx from /running, but the worker's launch command
            # usually carries --ctx-size, so parse it from there (no extra request).
            n_ctx = _ctx_size_from_cmd(item.get("cmd", ""))
            if n_ctx is not None:
                swap_model["context_length"] = n_ctx
                if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
                    _endpoint_nctx[(endpoint, normalized)] = n_ctx
            else:
                swap_nctx_fallbacks.append((endpoint, raw_id, swap_model))
            models.append(swap_model)
    # Resolve ctx for workers whose cmd gave none, via /upstream props
    if swap_nctx_fallbacks:
        fallback_results = await asyncio.gather(
            *[_fetch_llama_swap_nctx(ep, rid) for ep, rid, _ in swap_nctx_fallbacks]
        )
        for (ep, _rid, swap_model), n_ctx in zip(swap_nctx_fallbacks, fallback_results):
            if n_ctx is not None:
                swap_model["context_length"] = n_ctx
                if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
                    _endpoint_nctx[(ep, swap_model["id"])] = n_ctx
    return models


@router.get("/api/ps_details")
async def ps_details_proxy(request: Request):
    """
    Proxy a ps request to all Ollama, llama-server and llama-swap endpoints and reply with per-endpoint instances.
    This keeps /api/ps backward compatible while providing richer data.
    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models with status info, then /props per model
    For llama-swap endpoints: queries /running
    The three backend groups are probed concurrently. The reply lists Ollama,
    then llama-server, then llama-swap instances.
    """
    config = get_config()
    ollama_models, llama_models, swap_models = await asyncio.gather(
        _ps_details_ollama(config),
        _ps_details_llama_server(config),
        _ps_details_llama_swap(config),
    )
    return JSONResponse(content={"models": ollama_models + llama_models + swap_models}, status_code=200)