nomyo-router/api/ollama.py
alpha nerd 3cd530586c
All checks were successful
Build and Publish Docker Image (Semantic Cache) / build (amd64, linux/amd64, docker-amd64) (push) Successful in 3m59s
Build and Publish Docker Image / build (amd64, linux/amd64, docker-amd64) (push) Successful in 1m25s
Build and Publish Docker Image / build (arm64, linux/arm64, docker-arm64) (push) Successful in 12m46s
Build and Publish Docker Image / merge (push) Successful in 33s
Build and Publish Docker Image (Semantic Cache) / build (arm64, linux/arm64, docker-arm64) (push) Successful in 19m56s
Build and Publish Docker Image (Semantic Cache) / merge (push) Successful in 33s
feat: cache backend clients per endpoint instead of building one (with a fresh SSL context) per request
2026-06-07 09:55:54 +02:00

1134 lines
50 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Ollama-native API routes (``/api/*``).
These are the ``/api/generate``, ``/api/chat``, ``/api/embed(dings)`` and the
model-management routes (``/api/create``, ``/api/show``, ``/api/copy``,
``/api/delete``, ``/api/pull``, ``/api/push``, ``/api/version``,
``/api/tags``, ``/api/ps``, ``/api/ps_details``) that the Ollama clients
expect. The chat/generate handlers also serve OpenAI-compatible endpoints
when ``is_openai_compatible(endpoint)`` is true — in that case they
translate the request to the OpenAI Chat Completions / Completions API and
``rechunk`` the response back into Ollama wire format.
"""
import asyncio
import re
import time
from typing import Optional
import aiohttp
import ollama
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, Response, StreamingResponse
from cache import get_llm_cache
from config import get_config
from context_window import (
_count_message_tokens,
_trim_messages_for_context,
_calibrated_trim_target,
_endpoint_nctx,
_CTX_TRIM_SMALL_LIMIT,
)
from fingerprint import _conversation_fingerprint
from state import token_queue, default_headers
from backends.health import (
_is_backend_connection_error,
_is_llama_model_loaded,
_is_llama_model_loaded_or_sleeping,
_mark_backend_unhealthy,
)
from backends.normalize import (
dedupe_on_keys,
is_openai_compatible,
_normalize_llama_model_name,
_extract_llama_quant,
)
from backends.probe import fetch
from backends.sessions import _make_openai_client, get_ollama_client, get_probe_session
from requests.chat import _make_moe_requests
from requests.messages import (
transform_images_to_data_urls,
transform_tool_calls_to_openai,
_strip_assistant_prefill,
_strip_images_from_messages,
_accumulate_openai_tc_delta,
_build_ollama_tool_calls,
)
from requests.rechunk import rechunk
from routing import choose_endpoint, decrement_usage
# Router collecting the Ollama-native ``/api/*`` route handlers declared in this
# module. NOTE(review): it is presumably included into the FastAPI app elsewhere
# (not visible in this module) — confirm.
router = APIRouter()
async def _handle_stream_error(
    exc: Exception, endpoint: str, model: str, *, context: str
) -> bytes:
    """Convert an exception raised while streaming into a logged, terminal NDJSON line.

    Exceptions that escape a StreamingResponse generator (for example an ollama
    ``ResponseError`` carrying a 504) would otherwise reach Starlette as an opaque
    "Exception in ASGI application" traceback with no endpoint/model attached.
    This helper:

    * prints one greppable line naming the route ``context``, endpoint, model,
      upstream status code and exception type;
    * marks the backend unhealthy when ``_is_backend_connection_error`` says the
      failure is connection-class;
    * returns ``{"error": ..., "status_code": ...}`` (status only when known) as a
      newline-terminated JSON line for the client.

    Args:
        exc: The exception raised by the backend iteration.
        endpoint: Backend endpoint that produced the error.
        model: Model name the request was routed with.
        context: Short route label used as the log prefix.

    Returns:
        The serialized error line, ending in ``b"\\n"``.
    """
    code = getattr(exc, "status_code", None)
    # Prefer the SDK's ``error`` attribute; fall back to the exception text.
    message = str(getattr(exc, "error", None) or str(exc))

    print(
        f"[{context}] upstream error from ({endpoint}, {model}) "
        f"status={code} type={type(exc).__name__}: {message[:500]}",
        flush=True,
    )

    if _is_backend_connection_error(exc):
        await _mark_backend_unhealthy(endpoint, model, message)

    body = {"error": message, **({"status_code": code} if code is not None else {})}
    return orjson.dumps(body) + b"\n"
async def _guarded_stream(inner, *, endpoint: str, model: str, tracking_model: str, context: str):
    """Wrap a per-route body generator with the shared streaming contract.

    Every ``/api/*`` streaming handler needs the same guarantees around its body:

    * backend errors are surfaced transitively as a terminal Ollama-format
      ``{"error": ...}`` line (via :func:`_handle_stream_error`);
    * ``asyncio.CancelledError`` (client disconnect) propagates untouched;
    * the wrapped generator is explicitly closed when this wrapper finishes, so an
      abandoned stream does not keep the body generator suspended until GC;
    * the usage counter for ``(endpoint, tracking_model)`` is always decremented.

    Centralising these keeps the per-route bodies free of duplicated
    ``try/except/finally`` scaffolding.

    Args:
        inner: The route's async body generator yielding ``bytes`` lines.
        endpoint: Backend endpoint the request was routed to.
        model: Model name used in error logs / health marking.
        tracking_model: Model key the usage counter was incremented under.
        context: Short route label used as the log prefix.

    Yields:
        Every item from ``inner``, followed by one error line if ``inner`` raised.
    """
    try:
        async for item in inner:
            yield item
    except asyncio.CancelledError:
        raise
    except Exception as e:
        try:
            yield await _handle_stream_error(e, endpoint, model, context=context)
        except Exception as report_exc:
            # Best effort: the client ends up with a truncated stream, but the
            # reason the error line could not be delivered is no longer lost.
            print(
                f"[{context}] failed to report upstream error from ({endpoint}, {model}): "
                f"{type(report_exc).__name__}: {report_exc}",
                flush=True,
            )
    finally:
        # When this wrapper is closed early (e.g. aclose() on disconnect), the
        # body generator would otherwise stay suspended mid-iteration until it
        # is garbage-collected. Closing an already-finished generator is a no-op.
        aclose = getattr(inner, "aclose", None)
        if aclose is not None:
            try:
                await aclose()
            except Exception as close_exc:
                print(
                    f"[{context}] error while closing stream for ({endpoint}, {model}): "
                    f"{type(close_exc).__name__}: {close_exc}",
                    flush=True,
                )
        # Ensure counter is decremented even if an exception occurs
        await decrement_usage(endpoint, tracking_model)
@router.post("/api/generate")
async def proxy(request: Request):
    """
    Proxy a generate request to Ollama and stream the response back to the client.

    Ollama-native endpoints are called through ``client.generate``;
    OpenAI-compatible endpoints receive an equivalent Completions request whose
    reply is converted with ``rechunk.openai_completion2ollama``.

    With ``{"nomyo": {"cache": true}}`` and a configured LLM cache, a cached
    reply is served before any endpoint is selected, and a fresh reply is
    stored once complete. The cache key uses the model name without a trailing
    ``:latest``, matching ``/api/chat``.

    Raises:
        HTTPException: 400 when the body is not valid JSON or not a JSON object,
            when ``model``/``prompt`` is missing, or when ``options`` is not an object.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except (orjson.JSONDecodeError, UnicodeDecodeError) as e:
        error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
        raise HTTPException(status_code=400, detail=error_msg) from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    prompt = payload.get("prompt")
    suffix = payload.get("suffix")
    system = payload.get("system")
    template = payload.get("template")
    context = payload.get("context")
    stream = payload.get("stream")
    think = payload.get("think")
    raw = payload.get("raw")
    _format = payload.get("format")
    images = payload.get("images")
    options = payload.get("options")
    keep_alive = payload.get("keep_alive")
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False
    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not prompt:
        raise HTTPException(status_code=400, detail="Missing required field 'prompt'")
    if options is not None and not isinstance(options, dict):
        raise HTTPException(status_code=400, detail="`options` must be a JSON object")

    # 2. Cache lookup — before endpoint selection so no slot is wasted on a hit.
    # OpenAI-compatible endpoints strip ":latest" further down; keying on the raw
    # name would make get_generate and set_generate disagree for those backends.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_generate(_cache_model, prompt, system or "")
        if _cached is not None:
            async def _serve_cached_generate():
                yield _cached
            return StreamingResponse(_serve_cached_generate(), media_type="application/json")

    # 3. Endpoint selection — from here on a usage slot is reserved and must be
    # released if anything fails before the guarded stream takes ownership of it.
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    handed_off = False
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            opts = options or {}
            params = {
                "prompt": prompt,
                "model": model,
            }
            optional_params = {
                "stream": stream,
                "max_tokens": opts.get("num_predict"),
                "frequency_penalty": opts.get("frequency_penalty"),
                "presence_penalty": opts.get("presence_penalty"),
                "seed": opts.get("seed"),
                "stop": opts.get("stop"),
                "top_p": opts.get("top_p"),
                "temperature": opts.get("temperature"),
                "suffix": suffix,
            }
            params.update({k: v for k, v in optional_params.items() if v is not None})
            oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = get_ollama_client(endpoint)

        def _to_json_bytes(obj) -> bytes:
            # Pydantic responses serialise to str via model_dump_json; plain
            # dicts go through orjson, which already returns bytes.
            if hasattr(obj, "model_dump_json"):
                return obj.model_dump_json().encode("utf-8")
            return orjson.dumps(obj)

        # 4. Async generator body (error handling + cleanup handled by _guarded_stream)
        async def stream_generate_response():
            start_ts = None
            if use_openai:
                start_ts = time.perf_counter()
                async_gen = await oclient.completions.create(**params)
            else:
                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
            if stream == True:
                content_parts: list[str] = []
                async for chunk in async_gen:
                    if use_openai:
                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok or comp_tok:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    line = _to_json_bytes(chunk) + b"\n"
                    # Accumulate and store cache on done chunk — before yield so it always runs
                    if _cache is not None and _cache_enabled:
                        if getattr(chunk, "response", None):
                            content_parts.append(chunk.response)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "response": "".join(content_parts),
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_generate(_cache_model, prompt, system or "", assembled)
                            except Exception as _ce:
                                print(f"[cache] set_generate (streaming) failed: {_ce}")
                    yield line
            else:
                # Token counts must come from the Ollama-format object: a raw
                # OpenAI Completion has no prompt_eval_count / eval_count.
                result = (
                    rechunk.openai_completion2ollama(async_gen, stream, start_ts)
                    if use_openai
                    else async_gen
                )
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok or comp_tok:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _to_json_bytes(result) + b"\n"
                # Cache non-streaming response — before yield, like the streaming path,
                # so a disconnect after the last chunk cannot skip the store.
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_generate(_cache_model, prompt, system or "", cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_generate (non-streaming) failed: {_ce}")
                yield cache_bytes

        # 5. Return a StreamingResponse backed by the guarded generator
        response = StreamingResponse(
            _guarded_stream(
                stream_generate_response(),
                endpoint=endpoint, model=model,
                tracking_model=tracking_model, context="generate_proxy",
            ),
            media_type="application/json",
        )
        handed_off = True
        return response
    finally:
        if not handed_off:
            # Setup failed before _guarded_stream owned the reservation.
            await decrement_usage(endpoint, tracking_model)
def _is_ctx_exceeded_error(err_str: str) -> bool:
    """Return True if a backend error message reports that the prompt did not fit."""
    return "exceed_context_size_error" in err_str or "exceeds the available context size" in err_str


def _ctx_error_limits(exc: Exception, err_str: str) -> tuple[int, int]:
    """Extract ``(n_ctx, n_prompt_tokens)`` from a context-size error; 0 when unknown.

    The structured ``exc.body`` is tried first, accepting both a ``{"error": {...}}``
    wrapper and an already-unwrapped error object. If no ``n_ctx`` is found there,
    the stringified error is scraped with a regex.
    """
    body = getattr(exc, "body", None)
    detail = body.get("error", body) if isinstance(body, dict) else None
    if not isinstance(detail, dict):
        # e.g. {"error": "plain message"} or a non-dict body
        detail = {}
    n_ctx_limit = detail.get("n_ctx", 0) or 0
    actual_tokens = detail.get("n_prompt_tokens", 0) or 0
    if not n_ctx_limit:
        m = re.search(r"'n_ctx':\s*(\d+)", err_str)
        if m:
            n_ctx_limit = int(m.group(1))
        m = re.search(r"'n_prompt_tokens':\s*(\d+)", err_str)
        if m:
            actual_tokens = int(m.group(1))
    return n_ctx_limit, actual_tokens


async def _create_chat_completion_with_recovery(oclient, params: dict, *, endpoint: str, model: str, nctx_key: tuple):
    """Issue an OpenAI chat completion, retrying on the known, fixable backend errors.

    * Context exceeded: remember small windows in ``_endpoint_nctx[nctx_key]``, trim
      the oldest messages and retry; if still exceeded, also drop ``tools`` /
      ``tool_choice`` and retry once more.
    * Connection-class failure: mark ``(endpoint, model)`` unhealthy and re-raise.
    * "image input is not supported": retry with images stripped from messages.

    Anything else is re-raised. The caller owns usage-counter accounting.
    """
    try:
        return await oclient.chat.completions.create(**params)
    except Exception as e:
        _e_str = str(e)
        print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}")
        if _is_ctx_exceeded_error(_e_str):
            n_ctx_limit, actual_tokens = _ctx_error_limits(e, _e_str)
            if not n_ctx_limit:
                raise
            if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                # Remember the small window so later requests are trimmed proactively.
                _endpoint_nctx[nctx_key] = n_ctx_limit
            msgs_to_trim = params.get("messages", [])
            cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
            trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
            print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
            try:
                return await oclient.chat.completions.create(**{**params, "messages": trimmed})
            except Exception as e2:
                if not _is_ctx_exceeded_error(str(e2)):
                    raise
                print("[chat_proxy] Context still exceeded after trimming messages, also stripping tools")
                params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                return await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
        if _is_backend_connection_error(e):
            print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
            await _mark_backend_unhealthy(endpoint, model, _e_str)
            raise
        if "image input is not supported" in _e_str:
            print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
            return await oclient.chat.completions.create(
                **{**params, "messages": _strip_images_from_messages(params.get("messages", []))}
            )
        raise


@router.post("/api/chat")
async def chat_proxy(request: Request):
    """
    Proxy a chat request to Ollama and stream the endpoint reply.

    Ollama-native endpoints are called via ``client.chat``, or via
    ``_make_moe_requests`` for ``moe-`` models. OpenAI-compatible endpoints get a
    Chat Completions request whose reply is converted with
    ``rechunk.openai_chat_completion2ollama``. For those endpoints, the first call
    is made in handler scope so that context-size and unsupported-image errors can
    be retried before the response starts.

    With ``{"nomyo": {"cache": true}}`` and a configured LLM cache, non-MoE replies
    are served from, and stored into, the cache.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing ``model``,
            a non-list ``messages`` or a non-object ``options``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except (orjson.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    messages = payload.get("messages")
    tools = payload.get("tools")
    stream = payload.get("stream")
    think = payload.get("think")
    _format = payload.get("format")
    keep_alive = payload.get("keep_alive")
    options = payload.get("options")
    logprobs = payload.get("logprobs")
    top_logprobs = payload.get("top_logprobs")
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False
    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="Missing or invalid 'messages' field (must be a list)")
    if options is not None and not isinstance(options, dict):
        raise HTTPException(status_code=400, detail="`options` must be a JSON object")

    # Cache lookup — before endpoint selection, always bypassed for MOE
    _is_moe = model.startswith("moe-")
    _cache = get_llm_cache()
    # Normalise model name for cache key: strip ":latest" so that get_chat and
    # set_chat use the same model string even though OpenAI endpoints strip it later.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    # Snapshot original messages before any OpenAI-format transformation so that
    # get_chat and set_chat always use the same key regardless of backend type.
    _cache_messages = messages
    if _cache is not None and not _is_moe and _cache_enabled:
        _cached = await _cache.get_chat("ollama_chat", _cache_model, messages)
        if _cached is not None:
            async def _serve_cached_chat():
                yield _cached
            return StreamingResponse(
                _serve_cached_chat(),
                media_type="application/x-ndjson" if stream else "application/json",
            )

    # 2. Endpoint logic
    if _is_moe:
        # Drop only the leading marker; split("moe-")[1] would also cut at any
        # later "moe-" inside the real model name.
        model = model[len("moe-"):]
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    # From here a usage slot is reserved; release it if anything fails before
    # _guarded_stream takes ownership (including cancellation during the call).
    handed_off = False
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            if messages:
                if any("images" in m for m in messages):
                    messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
                messages = transform_tool_calls_to_openai(messages)
                messages = _strip_assistant_prefill(messages)
            opts = options or {}
            params = {
                "messages": messages,
                "model": model,
            }
            optional_params = {
                "tools": tools,
                "stream": stream,
                "stream_options": {"include_usage": True} if stream else None,
                "max_tokens": opts.get("num_predict"),
                "frequency_penalty": opts.get("frequency_penalty"),
                "presence_penalty": opts.get("presence_penalty"),
                "seed": opts.get("seed"),
                "stop": opts.get("stop"),
                "top_p": opts.get("top_p"),
                "temperature": opts.get("temperature"),
                "logprobs": logprobs if logprobs is not None else opts.get("logprobs"),
                "top_logprobs": top_logprobs if top_logprobs is not None else opts.get("top_logprobs"),
                "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None,
            }
            params.update({k: v for k, v in optional_params.items() if v is not None})
            oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = get_ollama_client(endpoint)

        # Single key for reading AND writing the learned n_ctx. The proactive-trim
        # lookup uses the normalised llama name, so the writes must use it too.
        _nctx_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
        _nctx_key = (endpoint, _nctx_model)

        # For OpenAI endpoints: make the API call in handler scope
        # (try/except inside async generators is unreliable with Starlette's streaming)
        start_ts = None
        async_gen = None
        if use_openai:
            start_ts = time.perf_counter()
            # Proactive trim: only for small-ctx models we've already seen run out of space
            _known_nctx = _endpoint_nctx.get(_nctx_key)
            if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
                _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
                _pre_msgs = params.get("messages", [])
                _pre_est = _count_message_tokens(_pre_msgs)
                if _pre_est > _pre_target:
                    _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                    _dropped = len(_pre_msgs) - len(_pre_trimmed)
                    print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                    params = {**params, "messages": _pre_trimmed}
            async_gen = await _create_chat_completion_with_recovery(
                oclient, params, endpoint=endpoint, model=model, nctx_key=_nctx_key,
            )

        def _to_json_bytes(obj) -> bytes:
            # Pydantic responses serialise to str via model_dump_json; plain
            # dicts go through orjson, which already returns bytes.
            if hasattr(obj, "model_dump_json"):
                return obj.model_dump_json().encode("utf-8")
            return orjson.dumps(obj)

        # 3. Async generator body (error handling + cleanup handled by _guarded_stream)
        async def stream_chat_response():
            if use_openai:
                _async_gen = async_gen  # established in handler scope above
            elif _is_moe:
                # Use the dedicated MOE helper function
                _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
            else:
                _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)
            if stream == True:
                tc_acc = {}  # accumulate OpenAI tool-call deltas across chunks
                content_parts: list[str] = []
                async for chunk in _async_gen:
                    if use_openai:
                        _accumulate_openai_tc_delta(chunk, tc_acc)
                        chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                        # Inject fully-accumulated tool calls only into the final chunk
                        if chunk.done and tc_acc and chunk.message:
                            chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok or comp_tok:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    line = _to_json_bytes(chunk) + b"\n"
                    if getattr(chunk, "done", False):
                        # Detect context exhaustion mid-generation for small-ctx models.
                        # Only infer n_ctx when no max_tokens limit was set — otherwise
                        # done_reason=length may just mean that limit was hit.
                        _dr = getattr(chunk, "done_reason", None)
                        _req_max_tok = (
                            params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
                            if use_openai else
                            (options.get("num_predict") if options else None)
                        )
                        if _dr == "length" and not _req_max_tok:
                            _pt = getattr(chunk, "prompt_eval_count", 0) or 0
                            _ct = getattr(chunk, "eval_count", 0) or 0
                            _inferred_nctx = _pt + _ct
                            if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                                _endpoint_nctx[_nctx_key] = _inferred_nctx
                                print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
                    # Accumulate and store cache on done chunk — before yield so it always runs.
                    # Chunks are already in Ollama format here for both backend kinds.
                    if _cache is not None and not _is_moe and _cache_enabled:
                        if chunk.message and getattr(chunk.message, "content", None):
                            content_parts.append(chunk.message.content)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)),
                                    "message": {"role": "assistant", "content": "".join(content_parts)},
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
                            except Exception as _ce:
                                print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
                    yield line
            else:
                # Token counts must come from the Ollama-format object: a raw
                # OpenAI ChatCompletion has no prompt_eval_count / eval_count.
                result = (
                    rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
                    if use_openai
                    else _async_gen
                )
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok or comp_tok:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                cache_bytes = _to_json_bytes(result) + b"\n"
                # Cache non-streaming response (non-MOE) before yield, like the streaming path
                if _cache is not None and not _is_moe and _cache_enabled:
                    try:
                        await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")
                yield cache_bytes

        # 4. Return a StreamingResponse backed by the guarded generator
        media_type = "application/x-ndjson" if stream else "application/json"
        response = StreamingResponse(
            _guarded_stream(
                stream_chat_response(),
                endpoint=endpoint, model=model,
                tracking_model=tracking_model, context="chat_proxy",
            ),
            media_type=media_type,
        )
        handed_off = True
        return response
    finally:
        if not handed_off:
            await decrement_usage(endpoint, tracking_model)
async def _handle_embedding_request(
    request: Request,
    *,
    input_field: str,
    context: str,
    make_native,
    make_openai,
):
    """Shared implementation for ``/api/embeddings`` and ``/api/embed``.

    The two routes differ only in the request field they read (``prompt`` vs
    ``input``), the ollama SDK method they call, and the OpenAI rechunk helper.
    Those are passed in via ``input_field`` and the ``make_native`` /
    ``make_openai`` callables; everything else — parsing, endpoint selection,
    serialization, and the streaming error/cleanup contract — is shared.

    Args:
        request: Incoming FastAPI request whose JSON body is parsed.
        input_field: Name of the body field carrying the text to embed.
        context: Route label used in stream error logs.
        make_native: ``async (client, model, value, options, keep_alive, truncate)``
            called for Ollama-native endpoints.
        make_openai: ``async (client, model, value)`` called for OpenAI-compatible
            endpoints.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``model`` / ``input_field``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except (orjson.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    value = payload.get(input_field)
    truncate = payload.get("truncate")
    options = payload.get("options")
    keep_alive = payload.get("keep_alive")
    if not model or not isinstance(model, str):
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not value:
        raise HTTPException(status_code=400, detail=f"Missing required field '{input_field}'")

    # 2. Endpoint logic — a usage slot is reserved from here on and must be
    # released if setup fails before _guarded_stream takes ownership of it.
    endpoint, tracking_model = await choose_endpoint(model)
    handed_off = False
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            # Same client construction as the generate/chat routes (incl. headers).
            client = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = get_ollama_client(endpoint)

        # 3. Async generator body (error handling + cleanup handled by _guarded_stream)
        async def stream_embedding_response():
            if use_openai:
                response = await make_openai(client, model, value)
            else:
                response = await make_native(client, model, value, options, keep_alive, truncate)
            if hasattr(response, "model_dump_json"):
                yield response.model_dump_json().encode("utf-8") + b"\n"
            else:
                # orjson.dumps already returns bytes; encoding it again would fail.
                yield orjson.dumps(response) + b"\n"

        # 4. Return a StreamingResponse backed by the guarded generator
        streaming = StreamingResponse(
            _guarded_stream(
                stream_embedding_response(),
                endpoint=endpoint, model=model,
                tracking_model=tracking_model, context=context,
            ),
            media_type="application/json",
        )
        handed_off = True
        return streaming
    finally:
        if not handed_off:
            await decrement_usage(endpoint, tracking_model)
@router.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """Serve ``/api/embeddings``: embed the body's ``prompt`` field.

    Delegates to :func:`_handle_embedding_request`, supplying the ollama
    ``client.embeddings`` call for native endpoints and
    ``rechunk.openai_embeddings2ollama`` for OpenAI-compatible ones.
    """

    async def _call_ollama(client, model, value, options, keep_alive, truncate):
        # client.embeddings is called without ``truncate``; the argument is
        # accepted only to match the shared callback signature.
        return await client.embeddings(
            model=model,
            prompt=value,
            options=options,
            keep_alive=keep_alive,
        )

    async def _call_openai(client, model, value):
        raw_reply = await client.embeddings.create(input=value, model=model)
        return rechunk.openai_embeddings2ollama(raw_reply)

    return await _handle_embedding_request(
        request,
        input_field="prompt",
        context="embeddings_proxy",
        make_native=_call_ollama,
        make_openai=_call_openai,
    )
@router.post("/api/embed")
async def embed_proxy(request: Request):
    """Serve ``/api/embed``: embed the body's ``input`` field.

    Delegates to :func:`_handle_embedding_request`, supplying the ollama
    ``client.embed`` call (which forwards ``truncate``) for native endpoints and
    ``rechunk.openai_embed2ollama`` for OpenAI-compatible ones.
    """

    async def _call_ollama(client, model, value, options, keep_alive, truncate):
        return await client.embed(
            model=model,
            input=value,
            truncate=truncate,
            options=options,
            keep_alive=keep_alive,
        )

    async def _call_openai(client, model, value):
        raw_reply = await client.embeddings.create(input=value, model=model)
        return rechunk.openai_embed2ollama(raw_reply, model)

    return await _handle_embedding_request(
        request,
        input_field="input",
        context="embed_proxy",
        make_native=_call_ollama,
        make_openai=_call_openai,
    )
@router.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with deduplicated status.

    Endpoints whose URL contains ``/v1`` are skipped, as in the copy/delete/pull
    routes, because they are not addressed through the ollama client. Each
    endpoint's non-streaming reply is flattened into ``(field, value)`` pairs,
    exact duplicates are dropped, and the result is returned as a dict.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing
            ``model``, or when neither ``from`` nor ``files`` is given.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except (orjson.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    from_ = payload.get("from")
    files = payload.get("files")
    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not from_ and not files:
        raise HTTPException(
            status_code=400, detail="You need to provide either the 'from' or the 'files' field!"
        )
    # Local names avoid shadowing the ``license`` builtin.
    create_kwargs = {
        "model": model,
        "quantize": payload.get("quantize"),
        "from_": from_,
        "files": files,
        "adapters": payload.get("adapters"),
        "template": payload.get("template"),
        "license": payload.get("license"),
        "system": payload.get("system"),
        "parameters": payload.get("parameters"),
        "messages": payload.get("messages"),
    }
    combined_status = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        create = await client.create(**create_kwargs, stream=False)
        # Iterating the reply yields (field, value) pairs.
        combined_status += create
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
@router.post("/api/show")
async def show_proxy(request: Request, model: Optional[str] = None):
    """
    Forward a show request for one model and return the backend's reply as-is.

    A ``model`` query parameter takes precedence; otherwise the model name is
    read from the JSON body. The endpoint is chosen with ``reserve=False``, so
    no usage slot is taken for this metadata call.

    Raises:
        HTTPException: 400 for invalid JSON or when no model name is supplied.
    """
    # 1. Resolve the model name (query parameter first, then body)
    try:
        body_bytes = await request.body()
        if not model:
            model = orjson.loads(body_bytes.decode("utf-8")).get("model")
        if not model:
            raise HTTPException(status_code=400, detail="Missing required field 'model'")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Pick an endpoint without reserving capacity
    endpoint, _tracking = await choose_endpoint(model, reserve=False)

    # 3. Relay the show reply
    return await get_ollama_client(endpoint).show(model=model)
@router.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """
    Proxy a model copy request to each Ollama endpoint and reply with Status Code.

    ``source``/``destination`` query parameters are used when either is given;
    otherwise both are read from the JSON body. Endpoints whose URL contains
    ``/v1`` are skipped.

    Returns:
        200 when every endpoint reported success; 404 when at least one endpoint
        answered HTTP 404 (raised by the ollama client as ``ResponseError``) or
        returned a status other than ``"success"``. Other backend errors propagate.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            source/destination.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not source and not destination:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            src = payload.get("source")
            dst = payload.get("destination")
        else:
            src = source
            dst = destination
        if not src:
            raise HTTPException(status_code=400, detail="Missing required field 'source'")
        if not dst:
            raise HTTPException(status_code=400, detail="Missing required field 'destination'")
    except (orjson.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    # 2. Copy the model on every Ollama-native endpoint. The ollama client reports
    # the outcome as a status string and raises ResponseError on HTTP errors, so a
    # 404 never shows up in the returned status — it has to be caught here.
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        try:
            result = await client.copy(source=src, destination=dst)
        except ollama.ResponseError as e:
            if e.status_code != 404:
                raise
            any_failed = True
            continue
        if result.status != "success":
            any_failed = True
    # 3. Return 200 OK if all went well, 404 if a single endpoint failed
    return Response(status_code=404 if any_failed else 200)
@router.delete("/api/delete")
async def delete_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model delete request to each Ollama endpoint and reply with Status Code.

    A ``model`` query parameter takes precedence over the JSON body. Endpoints
    whose URL contains ``/v1`` are skipped.

    Returns:
        200 when every endpoint reported success; 404 when at least one endpoint
        answered HTTP 404 (raised by the ollama client as ``ResponseError``) or
        returned a status other than ``"success"``. Other backend errors propagate.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing model.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            model = payload.get("model")
        if not model:
            raise HTTPException(status_code=400, detail="Missing required field 'model'")
    except (orjson.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    # 2. Delete the model on every Ollama-native endpoint. HTTP errors arrive as
    # ResponseError rather than in the returned status, so 404s are caught here.
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        try:
            result = await client.delete(model=model)
        except ollama.ResponseError as e:
            if e.status_code != 404:
                raise
            any_failed = True
            continue
        if result.status != "success":
            any_failed = True
    # 3. Return 200 OK; if a single endpoint fails, respond with 404
    return Response(status_code=404 if any_failed else 200)
@router.post("/api/pull")
async def pull_proxy(request: Request, model: Optional[str] = None):
    """
    Pull a model onto every Ollama-native endpoint and report a merged status.

    When ``model`` comes from the query string the body is not parsed and
    ``insecure`` stays ``None``; otherwise both are read from the JSON body.
    Endpoints whose URL contains ``/v1`` are skipped. Each non-streaming pull
    reply is flattened into ``(field, value)`` pairs, exact duplicates are
    dropped, and the remainder is returned as a dict.

    Raises:
        HTTPException: 400 for invalid JSON or when no model name is supplied.
    """
    config = get_config()
    # 1. Resolve model / insecure flag
    try:
        body_bytes = await request.body()
        insecure = None
        if not model:
            body = orjson.loads(body_bytes.decode("utf-8"))
            model = body.get("model")
            insecure = body.get("insecure")
        if not model:
            raise HTTPException(status_code=400, detail="Missing required field 'model'")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Pull sequentially on each Ollama-native endpoint
    merged_pairs = []
    for ep in (candidate for candidate in config.endpoints if "/v1" not in candidate):
        progress = await get_ollama_client(ep).pull(model=model, insecure=insecure, stream=False)
        merged_pairs.extend(progress)

    # 3. Deduplicate identical pairs (first occurrence kept), then build the dict
    unique_pairs = list(dict.fromkeys(merged_pairs))
    return dict(unique_pairs)
@router.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to Ollama and respond the deduplicated Ollama endpoint replies.

    Endpoints whose URL contains ``/v1`` are skipped, matching the pull, copy and
    delete routes; those endpoints are not addressed through the ollama client.
    Each non-streaming push reply is flattened into ``(field, value)`` pairs,
    exact duplicates are dropped, and the result is returned as a dict.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing model.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except (orjson.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    model = payload.get("model")
    insecure = payload.get("insecure")
    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    # 2. Push from every Ollama-native endpoint
    combined_status = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = get_ollama_client(endpoint)
        push = await client.push(model=model, insecure=insecure, stream=False)
        combined_status += push
    # 3. Report a deduplicated status
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
@router.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply with the lowest version of all endpoints.

    Only native Ollama endpoints (URL without ``/v1``) are queried. Replies
    that are not a non-empty string (e.g. the empty value produced for a
    failed or timed-out endpoint) are ignored. Reporting the minimum keeps
    clients compatible with the oldest backend behind the router.

    Raises:
        HTTPException: 503 when no endpoint returned a usable version string.
    """
    config = get_config()
    # 1. Query all native endpoints for their version
    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
    all_versions_raw = await asyncio.gather(*tasks)
    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]
    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v: str) -> tuple[tuple[int, ...], int]:
        """Sort key for a version string such as ``0.6.6`` or ``0.6.6-rc0``.

        Plain ``int()`` on each dotted part raised ValueError for pre-release /
        build suffixes or a leading ``v``, turning the whole request into a 500.
        Non-numeric parts count as 0; a pre-release ranks below the matching
        final release (semver precedence).
        """
        core, _, prerelease = v.strip().lstrip("vV").partition("-")
        core = core.split("+", 1)[0]  # drop build metadata ("1.2.3+abc")
        numbers = []
        for part in core.split("."):
            digits = re.match(r"\d+", part)
            numbers.append(int(digits.group()) if digits else 0)
        return tuple(numbers), 0 if prerelease else 1

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
@router.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to all endpoints and reply with a unique list of all models.

    Native Ollama endpoints are asked for ``/api/tags``. OpenAI-compatible
    (``/v1``) endpoints and llama-server endpoints not already listed in
    ``config.endpoints`` are asked for ``/models``. Entries are relabelled in
    place so each carries ``model``, ``name`` and ``id`` keys, then
    de-duplicated on ``digest``/``name``/``id``.
    """
    config = get_config()
    # 1. Query all endpoints for models.
    #    api_keys.get(): a /v1 endpoint without a configured key must not raise
    #    KeyError and fail the whole listing (the other fan-outs already use .get).
    tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" in ep]
    # Also query llama-server endpoints not already covered by config.endpoints
    llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps_for_tags]
    all_models = await asyncio.gather(*tasks)
    merged: list[dict] = []
    for modellist in all_models:
        for model in modellist:
            if "model" not in model:
                # OpenAI-style entry: derive Ollama's Model.model from Model.id
                model['model'] = model['id'] + ":latest"
            else:
                model['id'] = model['model']
            if "name" not in model:
                # Ensure an Ollama-style Model.name alongside Model.model
                model['name'] = model['model']
            else:
                model['id'] = model['model']
        merged += modellist
    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(merged, ['digest','name','id'])},
        status_code=200,
    )
@router.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models and keeps only entries that
    ``_is_llama_model_loaded`` accepts, converted to an Ollama-like shape.
    The combined list is de-duplicated on ``name``.
    """
    config = get_config()
    ollama_endpoints = [ep for ep in config.endpoints if "/v1" not in ep]
    # Every configured llama-server endpoint, deduplicated in config order.
    # A set was used before; its iteration order for strings changes between
    # process runs, which made the order of the deduplicated reply unstable.
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
        for ep in ollama_endpoints
    ]
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_endpoints
    ]
    # One gather runs both fan-outs concurrently, so slow/timing-out Ollama
    # endpoints no longer delay the llama-server probes (timeouts stacked before).
    results = await asyncio.gather(*ollama_tasks, *llama_tasks)
    ollama_loaded = results[:len(ollama_tasks)]
    llama_loaded = results[len(ollama_tasks):]

    models: list[dict] = []
    for modellist in ollama_loaded:
        models += modellist
    for modellist in llama_loaded:
        for item in modellist:
            # Non-dict entries would crash on .get(); skip them like /api/ps_details does.
            if not isinstance(item, dict) or not _is_llama_model_loaded(item):
                continue
            # Convert llama-server format to Ollama-like format for consistency
            raw_id = item.get("id", "")
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            models.append({
                "name": normalized,
                "id": normalized,
                "digest": "",
                "status": item.get("status"),
                "details": {"quantization_level": quant} if quant else {}
            })
    # Deduplicate on 'name' rather than 'digest': llama-server models always
    # have digest="" so deduping on digest collapses all of them to one entry.
    return JSONResponse(
        content={"models": dedupe_on_keys(models, ['name'])},
        status_code=200,
    )
@router.get("/api/ps_details")
async def ps_details_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply with per-endpoint instances.

    This keeps /api/ps backward compatible while providing richer data: every
    entry carries the ``endpoint`` it came from.
    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models with status info, then
    ``/props`` per model to fill ``context_length``. Models that ``/props``
    reports as sleeping are sent an unload request and left out of the reply.
    """
    config = get_config()
    ollama_endpoints = [ep for ep in config.endpoints if "/v1" not in ep]
    # Every configured llama-server endpoint, deduplicated in config order
    # (set iteration order for strings changes between process runs).
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))
    # 1. Query Ollama endpoints via /api/ps
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
        for ep in ollama_endpoints
    ]
    # 2. Query llama-server endpoints via /v1/models
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_endpoints
    ]
    # One gather runs both fan-outs concurrently instead of back to back.
    results = await asyncio.gather(*ollama_tasks, *llama_tasks)
    ollama_loaded = results[:len(ollama_tasks)]
    llama_loaded = results[len(ollama_tasks):]
    models: list[dict] = []
    # Add Ollama models with endpoint info
    for endpoint, modellist in zip(ollama_endpoints, ollama_loaded):
        for model in modellist:
            if isinstance(model, dict):
                model_with_endpoint = dict(model)
                model_with_endpoint["endpoint"] = endpoint
                models.append(model_with_endpoint)
    # Add llama-server models with endpoint info and full status metadata (if any)
    if llama_loaded:
        # Collect (endpoint, raw_id) pairs to fetch /props in parallel
        props_requests: list[tuple[str, str]] = []
        llama_models_pending: list[dict] = []
        for endpoint, modellist in zip(llama_endpoints, llama_loaded):
            # Include sleeping models too so _fetch_llama_props can unload them
            loaded_models = [item for item in modellist if _is_llama_model_loaded_or_sleeping(item)]
            for item in loaded_models:
                if isinstance(item, dict) and item.get("id"):
                    raw_id = item["id"]
                    normalized = _normalize_llama_model_name(raw_id)
                    quant = _extract_llama_quant(raw_id)
                    model_with_endpoint = {
                        "name": normalized,
                        "id": normalized,
                        "original_name": raw_id,
                        "digest": "",
                        "details": {"quantization_level": quant} if quant else {},
                        "endpoint": endpoint,
                        "status": item.get("status"),
                        "created": item.get("created"),
                        "owned_by": item.get("owned_by")
                    }
                    # Include full llama-server status details (args, preset)
                    status_info = item.get("status", {})
                    if isinstance(status_info, dict):
                        model_with_endpoint["llama_status_args"] = status_info.get("args")
                        model_with_endpoint["llama_status_preset"] = status_info.get("preset")
                    llama_models_pending.append(model_with_endpoint)
                    props_requests.append((endpoint, raw_id))

        async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]:
            """Return ``(n_ctx, is_sleeping, is_generation)`` from llama-server ``/props``.

            A model reported as sleeping is also sent a best-effort unload
            request. A non-200 reply or any exception yields ``(None, False, False)``.
            """
            client: aiohttp.ClientSession = get_probe_session(endpoint)
            base_url = endpoint.rstrip("/").removesuffix("/v1")
            props_endpoint = f"{base_url}/props"
            # Human-readable form, used only in the failure log line below.
            props_url = f"{props_endpoint}?model={model_id}"
            headers = None
            api_key = config.api_keys.get(endpoint)
            if api_key:
                headers = {"Authorization": f"Bearer {api_key}"}
            try:
                # params= lets aiohttp percent-encode the id; a hand-built query
                # string broke on ids containing '&', '#' or '+'.
                async with client.get(
                    props_endpoint,
                    params={"model": model_id},
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=5),
                ) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        dgs = data.get("default_generation_settings", {})
                        n_ctx = dgs.get("n_ctx")
                        is_sleeping = data.get("is_sleeping", False)
                        # Embedding models have no sampling params in default_generation_settings
                        is_generation = "temperature" in dgs
                        if is_sleeping:
                            unload_url = f"{base_url}/models/unload"
                            try:
                                async with client.post(
                                    unload_url,
                                    json={"model": model_id},
                                    headers=headers,
                                ) as unload_resp:
                                    print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
                            except Exception as ue:
                                print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")
                        return n_ctx, is_sleeping, is_generation
            except Exception as e:
                print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
            return None, False, False

        props_results = await asyncio.gather(
            *[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
        )
        for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in zip(props_requests, llama_models_pending, props_results):
            if n_ctx is not None:
                model_dict["context_length"] = n_ctx
                # Remember small context windows of generation models per (endpoint, model)
                if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
                    normalized = _normalize_llama_model_name(raw_id)
                    _endpoint_nctx[(ep, normalized)] = n_ctx
                    print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
            if not is_sleeping:
                models.append(model_dict)
    return JSONResponse(content={"models": models}, status_code=200)