2026-05-19 14:57:39 +02:00
|
|
|
|
"""Ollama-native API routes (``/api/*``).
|
|
|
|
|
|
|
|
|
|
|
|
These are the ``/api/generate``, ``/api/chat``, ``/api/embed(dings)`` and the
|
|
|
|
|
|
model-management routes (``/api/create``, ``/api/show``, ``/api/copy``,
|
|
|
|
|
|
``/api/delete``, ``/api/pull``, ``/api/push``, ``/api/version``,
|
|
|
|
|
|
``/api/tags``, ``/api/ps``, ``/api/ps_details``) that the Ollama clients
|
|
|
|
|
|
expect. The chat/generate handlers also serve OpenAI-compatible endpoints
|
|
|
|
|
|
when ``is_openai_compatible(endpoint)`` is true — in that case they
|
|
|
|
|
|
translate the request to the OpenAI Chat Completions / Completions API and
|
|
|
|
|
|
``rechunk`` the response back into Ollama wire format.
|
|
|
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import re
|
|
|
|
|
|
import time
|
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
import aiohttp
|
|
|
|
|
|
import ollama
|
|
|
|
|
|
import orjson
|
|
|
|
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
|
|
|
|
from starlette.responses import JSONResponse, Response, StreamingResponse
|
|
|
|
|
|
|
|
|
|
|
|
from cache import get_llm_cache
|
|
|
|
|
|
from config import get_config
|
|
|
|
|
|
from context_window import (
|
|
|
|
|
|
_count_message_tokens,
|
|
|
|
|
|
_trim_messages_for_context,
|
|
|
|
|
|
_calibrated_trim_target,
|
|
|
|
|
|
_endpoint_nctx,
|
|
|
|
|
|
_CTX_TRIM_SMALL_LIMIT,
|
|
|
|
|
|
)
|
|
|
|
|
|
from fingerprint import _conversation_fingerprint
|
|
|
|
|
|
from state import token_queue, default_headers
|
|
|
|
|
|
from backends.health import (
|
|
|
|
|
|
_is_backend_connection_error,
|
|
|
|
|
|
_is_llama_model_loaded,
|
|
|
|
|
|
_is_llama_model_loaded_or_sleeping,
|
|
|
|
|
|
_mark_backend_unhealthy,
|
|
|
|
|
|
)
|
|
|
|
|
|
from backends.normalize import (
|
|
|
|
|
|
dedupe_on_keys,
|
|
|
|
|
|
is_openai_compatible,
|
|
|
|
|
|
_normalize_llama_model_name,
|
|
|
|
|
|
_extract_llama_quant,
|
|
|
|
|
|
)
|
|
|
|
|
|
from backends.probe import fetch
|
2026-05-28 09:54:53 +02:00
|
|
|
|
from backends.sessions import _make_openai_client, get_probe_session
|
2026-05-19 14:57:39 +02:00
|
|
|
|
from requests.chat import _make_moe_requests
|
|
|
|
|
|
from requests.messages import (
|
|
|
|
|
|
transform_images_to_data_urls,
|
|
|
|
|
|
transform_tool_calls_to_openai,
|
|
|
|
|
|
_strip_assistant_prefill,
|
|
|
|
|
|
_strip_images_from_messages,
|
|
|
|
|
|
_accumulate_openai_tc_delta,
|
|
|
|
|
|
_build_ollama_tool_calls,
|
|
|
|
|
|
)
|
|
|
|
|
|
from requests.rechunk import rechunk
|
|
|
|
|
|
from routing import choose_endpoint, decrement_usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Router on which every handler in this module registers its /api/* route.
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-06-04 10:33:47 +02:00
|
|
|
|
async def _handle_stream_error(
    exc: Exception, endpoint: str, model: str, *, context: str
) -> bytes:
    """Log an upstream failure raised inside a streaming generator and build a terminal error line.

    Errors raised while iterating a backend response (e.g. an ollama
    ``ResponseError`` for a 504 Gateway Time-out) would otherwise escape the
    StreamingResponse generator and be dumped by Starlette as an opaque
    "Exception in ASGI application" traceback with no indication of which
    endpoint/model failed. This helper:

    * logs the failure tagged with ``context``, ``endpoint`` and ``model`` so
      upstream timeouts are greppable and analyzable;
    * marks the backend unhealthy when ``_is_backend_connection_error(exc)``
      classifies the failure as connection-related. This is best effort: a
      failure while marking is logged and never prevents the error line below
      from being returned;
    * returns a terminal Ollama-format ``{"error": ...}`` NDJSON line so the
      client receives a meaningful error instead of a silently truncated stream.

    Args:
        exc: The exception raised by the backend call or stream iteration.
        endpoint: Backend endpoint the request was routed to.
        model: Model name used for the request.
        context: Short tag naming the calling handler (used in log lines).

    Returns:
        The orjson-encoded error object followed by ``b"\\n"``. A
        ``status_code`` key is included when the exception carries one.
    """
    status_code = getattr(exc, "status_code", None)
    # ollama.ResponseError carries the backend's message in ``.error``;
    # fall back to the exception text for everything else.
    err_msg = str(getattr(exc, "error", None) or exc)
    print(
        f"[{context}] upstream error from ({endpoint}, {model}) "
        f"status={status_code} type={type(exc).__name__}: {err_msg[:500]}",
        flush=True,
    )
    if _is_backend_connection_error(exc):
        try:
            await _mark_backend_unhealthy(endpoint, model, err_msg)
        except Exception as mark_exc:
            # Health bookkeeping must not swallow the client-facing error:
            # callers discard exceptions from this helper, which would leave
            # the client with a truncated stream.
            print(
                f"[{context}] failed to mark ({endpoint}, {model}) unhealthy: {mark_exc}",
                flush=True,
            )
    err_payload = {"error": err_msg}
    if status_code is not None:
        err_payload["status_code"] = status_code
    return orjson.dumps(err_payload) + b"\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-05-19 14:57:39 +02:00
|
|
|
|
@router.post("/api/generate")
async def proxy(request: Request):
    """Proxy an Ollama ``/api/generate`` request and stream the reply back.

    Ollama backends are called with ``ollama.AsyncClient.generate``. Backends
    for which ``is_openai_compatible(endpoint)`` is true are called through the
    OpenAI Completions API, and every response object is converted back to
    Ollama wire format with ``rechunk.openai_completion2ollama``.

    When the body carries ``"nomyo": {"cache": true}`` and an LLM cache is
    configured, a cached reply is served before any endpoint is selected, and
    a fresh reply is stored once it is complete. The cache key uses the model
    name without a trailing ``":latest"``, so lookups and stores agree even
    though the OpenAI path strips that suffix before calling the backend.

    Raises:
        HTTPException: 400 when the body is not a valid JSON object, when
            ``model`` or ``prompt`` is missing, or when ``options`` is not a
            JSON object.
    """
    config = get_config()
    try:
        # orjson accepts bytes directly and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 rather than a 500.
        payload = orjson.loads(await request.body())
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        model = payload.get("model")
        prompt = payload.get("prompt")
        suffix = payload.get("suffix")
        system = payload.get("system")
        template = payload.get("template")
        context = payload.get("context")
        stream = payload.get("stream")
        think = payload.get("think")
        raw = payload.get("raw")
        _format = payload.get("format")
        images = payload.get("images")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")
        # Tolerate "nomyo": null (or any non-object) instead of raising AttributeError.
        nomyo = payload.get("nomyo")
        _cache_enabled = nomyo.get("cache", False) if isinstance(nomyo, dict) else False

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
        raise HTTPException(status_code=400, detail=error_msg) from e

    # Cache lookup — before endpoint selection so no slot is wasted on a hit.
    # The key strips ":latest" so it matches the store below, which runs after
    # the OpenAI branch may already have stripped the suffix from ``model``.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_generate(_cache_model, prompt, system or "")
        if _cached is not None:
            async def _serve_cached_generate():
                yield _cached

            return StreamingResponse(_serve_cached_generate(), media_type="application/json")

    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            params = {
                "prompt": prompt,
                "model": model,
            }
            opts = options or {}
            optional_params = {
                "stream": stream,
                "max_tokens": opts.get("num_predict"),
                "frequency_penalty": opts.get("frequency_penalty"),
                "presence_penalty": opts.get("presence_penalty"),
                "seed": opts.get("seed"),
                "stop": opts.get("stop"),
                "top_p": opts.get("top_p"),
                "temperature": opts.get("temperature"),
                "suffix": suffix,
            }
            # Omit unset values so the backend applies its own defaults.
            params.update({k: v for k, v in optional_params.items() if v is not None})
            oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = ollama.AsyncClient(host=endpoint)
    except BaseException:
        # The generator below has not started, so its ``finally`` cannot
        # release the slot reserved by choose_endpoint; release it here.
        await decrement_usage(endpoint, tracking_model)
        raise

    # Async generator that streams data and decrements the usage counter.
    async def stream_generate_response():
        try:
            if use_openai:
                start_ts = time.perf_counter()
                async_gen = await oclient.completions.create(**params)
            else:
                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)

            if stream == True:
                content_parts: list[str] = []
                async for chunk in async_gen:
                    if use_openai:
                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    if hasattr(chunk, "model_dump_json"):
                        line = chunk.model_dump_json().encode("utf-8")
                    else:
                        # orjson.dumps already returns bytes; do not .encode() it.
                        line = orjson.dumps(chunk)
                    # Accumulate, then store the cache entry on the done chunk.
                    # This happens before the yield, so it runs even if the
                    # client disconnects right after receiving the last line.
                    if _cache is not None and _cache_enabled:
                        if getattr(chunk, "response", None):
                            content_parts.append(chunk.response)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "response": "".join(content_parts),
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_generate(_cache_model, prompt, system or "", assembled)
                            except Exception as _ce:
                                print(f"[cache] set_generate (streaming) failed: {_ce}")
                    yield line + b"\n"
            else:
                if use_openai:
                    result = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
                else:
                    result = async_gen
                # Read token counts from the Ollama-format object; the raw
                # OpenAI completion reports usage under ``usage`` instead of
                # prompt_eval_count/eval_count.
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if hasattr(result, "model_dump_json"):
                    cache_bytes = result.model_dump_json().encode("utf-8") + b"\n"
                else:
                    cache_bytes = orjson.dumps(result) + b"\n"
                # Cache the non-streaming response before yielding, matching
                # the streaming path.
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_generate(_cache_model, prompt, system or "", cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_generate (non-streaming) failed: {_ce}")
                yield cache_bytes
        except asyncio.CancelledError:
            raise
        except Exception as e:
            try:
                yield await _handle_stream_error(e, endpoint, model, context="generate_proxy")
            except Exception:
                # Best effort: the client may already be gone.
                pass
        finally:
            # Ensure the counter is decremented even if an exception occurs.
            await decrement_usage(endpoint, tracking_model)

    # Return a StreamingResponse backed by the generator.
    return StreamingResponse(
        stream_generate_response(),
        media_type="application/json",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_context_exceeded_error(text):
    """Return True when *text* looks like a backend "prompt exceeds n_ctx" error."""
    return "exceed_context_size_error" in text or "exceeds the available context size" in text


def _context_error_limits(exc, exc_text):
    """Extract ``(n_ctx, n_prompt_tokens)`` from a context-overflow error.

    The numbers are read from the exception's ``body`` when it is a dict,
    either nested under an ``"error"`` key or at the top level. Otherwise they
    are parsed out of *exc_text* (the stringified exception), which may render
    the body with single or double quotes. Values that cannot be found are 0.
    """
    body = getattr(exc, "body", None)
    details = []
    if isinstance(body, dict):
        nested = body.get("error")
        if isinstance(nested, dict):
            details.append(nested)
        details.append(body)
    for detail in details:
        try:
            n_ctx = int(detail.get("n_ctx") or 0)
            n_prompt = int(detail.get("n_prompt_tokens") or 0)
        except (TypeError, ValueError):
            continue
        if n_ctx:
            return n_ctx, n_prompt
    n_ctx_match = re.search(r"""['"]n_ctx['"]:\s*(\d+)""", exc_text)
    prompt_match = re.search(r"""['"]n_prompt_tokens['"]:\s*(\d+)""", exc_text)
    return (
        int(n_ctx_match.group(1)) if n_ctx_match else 0,
        int(prompt_match.group(1)) if prompt_match else 0,
    )


def _ollama_format_to_response_format(fmt):
    """Translate an Ollama ``format`` value into an OpenAI ``response_format``.

    ``"json"`` maps to JSON mode (``{"type": "json_object"}``). A JSON-schema
    object is wrapped as ``{"type": "json_schema", "json_schema": {"name": ...,
    "schema": ...}}`` as the Chat Completions API expects. Any other non-empty
    value is passed through in the previous ``json_schema`` shape. Returns None
    when no format was requested.
    """
    if fmt is None or fmt == "":
        return None
    if fmt == "json":
        return {"type": "json_object"}
    if isinstance(fmt, dict):
        return {"type": "json_schema", "json_schema": {"name": "response", "schema": fmt}}
    return {"type": "json_schema", "json_schema": fmt}


async def _openai_chat_create_with_recovery(oclient, params, endpoint, model, nctx_keys):
    """Start an OpenAI-compatible chat completion, retrying recoverable failures.

    On a failure of the first attempt:

    * context overflow: remember ``n_ctx`` in ``_endpoint_nctx`` under every key
      in *nctx_keys* when it is at most ``_CTX_TRIM_SMALL_LIMIT``, drop the
      oldest messages down to a calibrated target and retry. If that retry
      still overflows, retry once more without ``tools``/``tool_choice``. An
      overflow whose ``n_ctx`` cannot be determined is re-raised.
    * connection-class errors: mark the backend unhealthy and re-raise.
    * "image input is not supported": retry with images stripped from messages.

    Any other error is re-raised. This helper never touches the usage counter;
    the caller owns that bookkeeping.

    Returns:
        The value returned by ``oclient.chat.completions.create`` for the
        attempt that succeeded.
    """
    try:
        return await oclient.chat.completions.create(**params)
    except Exception as e:
        e_str = str(e)
        print(f"[chat_proxy] caught {type(e).__name__}: {e_str[:200]}")
        if _is_context_exceeded_error(e_str):
            n_ctx_limit, actual_tokens = _context_error_limits(e, e_str)
            if not n_ctx_limit:
                raise
            if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                for key in nctx_keys:
                    _endpoint_nctx[key] = n_ctx_limit
            msgs_to_trim = params.get("messages", [])
            cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
            trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
            print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
            try:
                return await oclient.chat.completions.create(**{**params, "messages": trimmed})
            except Exception as e2:
                if not _is_context_exceeded_error(str(e2)):
                    raise
            print("[chat_proxy] Context still exceeded after trimming messages, also stripping tools")
            params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
            return await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
        if _is_backend_connection_error(e):
            print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
            await _mark_backend_unhealthy(endpoint, model, e_str)
            raise
        if "image input is not supported" in e_str:
            print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
            text_only = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
            return await oclient.chat.completions.create(**text_only)
        raise


@router.post("/api/chat")
async def chat_proxy(request: Request):
    """Proxy an Ollama ``/api/chat`` request and stream the reply back.

    Model names prefixed with ``moe-`` are served by ``_make_moe_requests`` on
    Ollama backends and are never cached. For OpenAI-compatible endpoints the
    request goes through these steps:

    * messages are converted to the OpenAI format (images to data URLs, tool
      calls, assistant prefill stripped);
    * messages are trimmed up front when the endpoint's context window is known
      to be small;
    * the completion is started in handler scope via
      ``_openai_chat_create_with_recovery``, so recoverable errors are retried
      before the response begins;
    * replies are converted back with ``rechunk.openai_chat_completion2ollama``.

    With ``"nomyo": {"cache": true}`` and a configured LLM cache, a cached reply
    is served before any endpoint is chosen, and fresh replies are stored.

    Raises:
        HTTPException: 400 for a body that is not a valid JSON object, a
            missing ``model``, a non-list ``messages`` or a non-object
            ``options``.
        Exception: an unrecoverable error from the initial OpenAI-compatible
            call is re-raised after the usage slot has been released.
    """
    config = get_config()
    # 1. Parse and validate the request.
    try:
        # orjson accepts bytes directly and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 rather than a 500.
        payload = orjson.loads(await request.body())
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        model = payload.get("model")
        messages = payload.get("messages")
        tools = payload.get("tools")
        stream = payload.get("stream")
        think = payload.get("think")
        _format = payload.get("format")
        keep_alive = payload.get("keep_alive")
        options = payload.get("options")
        logprobs = payload.get("logprobs")
        top_logprobs = payload.get("top_logprobs")
        # Tolerate "nomyo": null (or any non-object) instead of raising AttributeError.
        nomyo = payload.get("nomyo")
        _cache_enabled = nomyo.get("cache", False) if isinstance(nomyo, dict) else False

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(messages, list):
            raise HTTPException(
                status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Cache lookup happens before endpoint selection and is always bypassed for MOE.
    _is_moe = model.startswith("moe-")
    _cache = get_llm_cache()
    # Strip ":latest" for the cache key so get_chat and set_chat agree even
    # though the OpenAI branch below strips it from ``model`` later.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    # Snapshot the original messages before any OpenAI-format transformation so
    # get_chat and set_chat always use the same key regardless of backend type.
    _cache_messages = messages
    if _cache is not None and not _is_moe and _cache_enabled:
        _cached = await _cache.get_chat("ollama_chat", _cache_model, messages)
        if _cached is not None:
            async def _serve_cached_chat():
                yield _cached

            return StreamingResponse(
                _serve_cached_chat(),
                media_type="application/x-ndjson" if stream else "application/json",
            )

    # 2. Endpoint logic.
    opt = _is_moe
    if opt:
        # Slice rather than split so names containing "moe-" twice survive intact.
        model = model[len("moe-"):]
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)

    start_ts = None
    async_gen = None
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            if messages:
                if any("images" in m for m in messages):
                    messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
                messages = transform_tool_calls_to_openai(messages)
                messages = _strip_assistant_prefill(messages)
            params = {
                "messages": messages,
                "model": model,
            }
            opts = options or {}
            optional_params = {
                "tools": tools,
                "stream": stream,
                "stream_options": {"include_usage": True} if stream else None,
                "max_tokens": opts.get("num_predict"),
                "frequency_penalty": opts.get("frequency_penalty"),
                "presence_penalty": opts.get("presence_penalty"),
                "seed": opts.get("seed"),
                "stop": opts.get("stop"),
                "top_p": opts.get("top_p"),
                "temperature": opts.get("temperature"),
                "logprobs": logprobs if logprobs is not None else opts.get("logprobs"),
                "top_logprobs": top_logprobs if top_logprobs is not None else opts.get("top_logprobs"),
                "response_format": _ollama_format_to_response_format(_format),
            }
            # Omit unset values so the backend applies its own defaults.
            params.update({k: v for k, v in optional_params.items() if v is not None})
            oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = ollama.AsyncClient(host=endpoint)

        # Keys under which a learned n_ctx is stored. The proactive-trim lookup
        # below uses the normalized llama-server name, so write under that key
        # as well as under the raw model name.
        _nctx_keys = [(endpoint, model)]

        # For OpenAI endpoints, make the API call in handler scope. The original
        # rationale: try/except inside async generators is unreliable with
        # Starlette's streaming.
        if use_openai:
            start_ts = time.perf_counter()
            _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
            if _lookup_model != model:
                _nctx_keys.append((endpoint, _lookup_model))
            # Proactive trim: only for small-ctx models we've already seen run out of space.
            _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
            if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
                _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
                _pre_msgs = params.get("messages", [])
                _pre_est = _count_message_tokens(_pre_msgs)
                if _pre_est > _pre_target:
                    _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                    _dropped = len(_pre_msgs) - len(_pre_trimmed)
                    print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                    params = {**params, "messages": _pre_trimmed}
            async_gen = await _openai_chat_create_with_recovery(oclient, params, endpoint, model, _nctx_keys)
    except BaseException:
        # Covers ordinary errors and cancellation alike. The generator below
        # has not started, so its ``finally`` cannot release the slot reserved
        # by choose_endpoint; release it exactly once here.
        await decrement_usage(endpoint, tracking_model)
        raise

    # 3. Async generator that streams chat data and decrements the counter.
    async def stream_chat_response():
        try:
            if use_openai:
                _async_gen = async_gen  # established in handler scope above
            elif opt:
                # Use the dedicated MOE helper function.
                _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
            else:
                _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)

            if stream == True:
                tc_acc = {}  # accumulate OpenAI tool-call deltas across chunks
                content_parts: list[str] = []
                saw_tool_calls = False
                async for chunk in _async_gen:
                    if use_openai:
                        _accumulate_openai_tc_delta(chunk, tc_acc)
                        chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                        # Inject fully-accumulated tool calls only into the final chunk.
                        if chunk.done and tc_acc and chunk.message:
                            chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    if hasattr(chunk, "model_dump_json"):
                        line = chunk.model_dump_json().encode("utf-8")
                    else:
                        # orjson.dumps already returns bytes; do not .encode() it.
                        line = orjson.dumps(chunk)

                    done = getattr(chunk, "done", False)
                    if done:
                        # Detect context exhaustion mid-generation for small-ctx
                        # models. Only trust finish_reason=length when no
                        # max-tokens limit was requested; otherwise it may just
                        # mean that limit was hit.
                        _dr = getattr(chunk, "done_reason", None)
                        if use_openai:
                            _req_max_tok = params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
                        else:
                            _req_max_tok = options.get("num_predict") if options else None
                        if _dr == "length" and not _req_max_tok:
                            _pt = getattr(chunk, "prompt_eval_count", 0) or 0
                            _ct = getattr(chunk, "eval_count", 0) or 0
                            _inferred_nctx = _pt + _ct
                            if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                                for _key in _nctx_keys:
                                    _endpoint_nctx[_key] = _inferred_nctx
                                print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)

                    # Accumulate and store the cache entry on the done chunk,
                    # before the yield, so it always runs. Chunks are already
                    # in Ollama format here for both backend types. Replies
                    # that contain tool calls are not cached: the assembled
                    # entry only carries text content and would replay them as
                    # an empty message.
                    if _cache is not None and not _is_moe and _cache_enabled:
                        _msg = getattr(chunk, "message", None)
                        if _msg is not None:
                            if getattr(_msg, "content", None):
                                content_parts.append(_msg.content)
                            if getattr(_msg, "tool_calls", None):
                                saw_tool_calls = True
                        if done and not saw_tool_calls:
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)),
                                    "message": {"role": "assistant", "content": "".join(content_parts)},
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
                            except Exception as _ce:
                                print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
                    yield line + b"\n"
            else:
                if use_openai:
                    result = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
                else:
                    result = _async_gen
                # Read token counts from the Ollama-format object; the raw
                # OpenAI ChatCompletion reports usage under ``usage`` instead of
                # prompt_eval_count/eval_count.
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if hasattr(result, "model_dump_json"):
                    cache_bytes = result.model_dump_json().encode("utf-8") + b"\n"
                else:
                    cache_bytes = orjson.dumps(result) + b"\n"
                # Cache the non-streaming response (non-MOE; both backend
                # types) before yielding, matching the streaming path.
                if _cache is not None and not _is_moe and _cache_enabled:
                    try:
                        await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")
                yield cache_bytes
        except asyncio.CancelledError:
            raise
        except Exception as e:
            try:
                yield await _handle_stream_error(e, endpoint, model, context="chat_proxy")
            except Exception:
                # Best effort: the client may already be gone.
                pass
        finally:
            # Ensure the counter is decremented even if an exception occurs.
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator.
    media_type = "application/x-ndjson" if stream else "application/json"
    return StreamingResponse(
        stream_chat_response(),
        media_type=media_type,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """Proxy an Ollama ``/api/embeddings`` request (single ``prompt``) and reply with the embedding.

    Ollama backends are called with ``AsyncClient.embeddings`` (forwarding
    ``options`` and ``keep_alive``). OpenAI-compatible backends are called
    through the Embeddings API with only ``input``/``model``, and the result is
    converted with ``rechunk.openai_embeddings2ollama``. The single JSON reply
    is delivered through a StreamingResponse whose generator releases the usage
    slot when it finishes.

    Raises:
        HTTPException: 400 when the body is not a valid JSON object or
            ``model``/``prompt`` is missing.
    """
    config = get_config()
    # 1. Parse and validate the request.
    try:
        # orjson accepts bytes directly and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 rather than a 500.
        payload = orjson.loads(await request.body())
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        model = payload.get("model")
        prompt = payload.get("prompt")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic.
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = ollama.AsyncClient(host=endpoint)
    except BaseException:
        # The generator below has not started, so its ``finally`` cannot
        # release the slot reserved by choose_endpoint; release it here.
        await decrement_usage(endpoint, tracking_model)
        raise

    # 3. Async generator that fetches the embedding and decrements the counter.
    async def stream_embedding_response():
        try:
            if use_openai:
                openai_result = await client.embeddings.create(input=prompt, model=model)
                result = rechunk.openai_embeddings2ollama(openai_result)
            else:
                result = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
            if hasattr(result, "model_dump_json"):
                body = result.model_dump_json().encode("utf-8")
            else:
                # orjson.dumps already returns bytes; do not .encode() it.
                body = orjson.dumps(result)
            yield body + b"\n"
        except asyncio.CancelledError:
            raise
        except Exception as e:
            try:
                yield await _handle_stream_error(e, endpoint, model, context="embeddings_proxy")
            except Exception:
                # Best effort: the client may already be gone.
                pass
        finally:
            # Ensure the counter is decremented even if an exception occurs.
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator.
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/api/embed")
async def embed_proxy(request: Request):
    """Proxy an Ollama ``/api/embed`` request and reply with the embeddings.

    Ollama backends are called with ``AsyncClient.embed`` (forwarding
    ``truncate``, ``options`` and ``keep_alive``). OpenAI-compatible backends
    are called through the Embeddings API with only ``input``/``model``, and
    the result is converted with ``rechunk.openai_embed2ollama``. The single
    JSON reply is delivered through a StreamingResponse whose generator releases
    the usage slot when it finishes.

    Raises:
        HTTPException: 400 when the body is not a valid JSON object or
            ``model``/``input`` is missing.
    """
    config = get_config()
    # 1. Parse and validate the request.
    try:
        # orjson accepts bytes directly and reports invalid UTF-8 as a
        # JSONDecodeError, so a malformed body becomes a 400 rather than a 500.
        payload = orjson.loads(await request.body())
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        model = payload.get("model")
        _input = payload.get("input")
        truncate = payload.get("truncate")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not _input:
            raise HTTPException(
                status_code=400, detail="Missing required field 'input'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic.
    endpoint, tracking_model = await choose_endpoint(model)
    try:
        use_openai = is_openai_compatible(endpoint)
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = ollama.AsyncClient(host=endpoint)
    except BaseException:
        # The generator below has not started, so its ``finally`` cannot
        # release the slot reserved by choose_endpoint; release it here.
        await decrement_usage(endpoint, tracking_model)
        raise

    # 3. Async generator that fetches the embeddings and decrements the counter.
    async def stream_embedding_response():
        try:
            if use_openai:
                openai_result = await client.embeddings.create(input=_input, model=model)
                result = rechunk.openai_embed2ollama(openai_result, model)
            else:
                result = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
            if hasattr(result, "model_dump_json"):
                body = result.model_dump_json().encode("utf-8")
            else:
                # orjson.dumps already returns bytes; do not .encode() it.
                body = orjson.dumps(result)
            yield body + b"\n"
        except asyncio.CancelledError:
            raise
        except Exception as e:
            try:
                yield await _handle_stream_error(e, endpoint, model, context="embed_proxy")
            except Exception:
                # Best effort: the client may already be gone.
                pass
        finally:
            # Ensure the counter is decremented even if an exception occurs.
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator.
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with deduplicated status.

    The JSON body mirrors Ollama's ``/api/create``: ``model`` is required and
    at least one of ``from`` / ``files`` must be present; ``quantize``,
    ``adapters``, ``template``, ``license``, ``system``, ``parameters`` and
    ``messages`` are forwarded as given. Endpoints whose URL contains ``/v1``
    (OpenAI-compatible) are skipped, as the copy/delete/pull handlers do.
    Endpoints are processed one after another.

    Returns:
        A dict built from the (field, value) pairs of every endpoint's
        non-streaming reply; if endpoints report different values for the
        same field, the last endpoint's value wins.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing
            ``model``, or neither ``from`` nor ``files``.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    quantize = payload.get("quantize")
    from_ = payload.get("from")
    files = payload.get("files")
    adapters = payload.get("adapters")
    template = payload.get("template")
    license_ = payload.get("license")  # trailing underscore: avoid shadowing the builtin
    system = payload.get("system")
    parameters = payload.get("parameters")
    messages = payload.get("messages")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")
    if not from_ and not files:
        # Name the JSON field ("from"), not the Python keyword-safe "from_".
        raise HTTPException(
            status_code=400, detail="You need to provide either 'from' or 'files' parameter!"
        )

    combined_status: list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue  # OpenAI-compatible endpoints are not native Ollama servers
        client = ollama.AsyncClient(host=endpoint)
        create = await client.create(
            model=model,
            quantize=quantize,
            from_=from_,
            files=files,
            adapters=adapters,
            template=template,
            license=license_,
            system=system,
            parameters=parameters,
            messages=messages,
            stream=False,
        )
        # The reply is iterated as (field, value) pairs (presumably a pydantic
        # model from the ollama client).
        combined_status += create

    # dict() keeps first-seen key order and the last value per key -- the same
    # result the former list(dict.fromkeys(...)) round-trip produced, without
    # requiring the values to be hashable.
    return dict(combined_status)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/api/show")
async def show_proxy(request: Request, model: Optional[str] = None):
    """Return the ``ShowResponse`` for a model from the endpoint chosen for it.

    The model name is the ``model`` query parameter when given; otherwise it
    is the ``model`` field of the JSON request body.

    Raises:
        HTTPException: 400 when the body has to be decoded and is not valid
            JSON, or when no model name could be determined.
    """
    raw_body = await request.body()

    # The query parameter takes precedence; decode the body only without it.
    model_name = model
    if not model_name:
        try:
            decoded = orjson.loads(raw_body.decode("utf-8"))
        except orjson.JSONDecodeError as exc:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {exc}") from exc
        model_name = decoded.get("model")

    if not model_name:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")

    # reserve=False: this lookup never calls decrement_usage, so it must not
    # take a usage slot.
    target, _ = await choose_endpoint(model_name, reserve=False)
    ollama_client = ollama.AsyncClient(host=target)
    return await ollama_client.show(model=model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """
    Proxy a model copy request to each Ollama endpoint and reply with Status Code.

    ``source`` and ``destination`` may be given as query parameters; whichever
    of them is missing is read from the JSON body instead. Endpoints whose URL
    contains ``/v1`` (OpenAI-compatible) are skipped.

    Returns:
        ``Response`` with status 200 when every native Ollama endpoint
        accepted the copy, 404 when at least one of them rejected it.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``source`` / ``destination``.
    """
    config = get_config()

    # 1. Parse and validate request
    src, dst = source, destination
    if not (src and dst):
        try:
            body_bytes = await request.body()
            payload = orjson.loads(body_bytes.decode("utf-8"))
        except orjson.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
        if not isinstance(payload, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        src = src or payload.get("source")
        dst = dst or payload.get("destination")

    if not src:
        raise HTTPException(status_code=400, detail="Missing required field 'source'")
    if not dst:
        raise HTTPException(status_code=400, detail="Missing required field 'destination'")

    # 2. Copy the model on every native Ollama endpoint
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            result = await client.copy(source=src, destination=dst)
        except ollama.ResponseError:
            # The ollama client raises on error replies (e.g. an unknown
            # source model) rather than returning a status code, so failures
            # have to be caught here to be reported below.
            any_failed = True
            continue
        if result.status != "success":
            any_failed = True

    # 3. Return 200 OK if all went well, 404 if a single endpoint failed
    return Response(status_code=404 if any_failed else 200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/api/delete")
async def delete_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model delete request to each Ollama endpoint and reply with Status Code.

    The model name is the ``model`` query parameter when given, otherwise the
    ``model`` field of the JSON body. Endpoints whose URL contains ``/v1``
    (OpenAI-compatible) are skipped.

    Returns:
        ``Response`` with status 200 when every native Ollama endpoint
        deleted the model, 404 when at least one of them rejected the delete.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``model``.
    """
    config = get_config()

    # 1. Parse and validate request (query parameter first, then JSON body)
    if not model:
        try:
            body_bytes = await request.body()
            payload = orjson.loads(body_bytes.decode("utf-8"))
        except orjson.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
        if not isinstance(payload, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        model = payload.get("model")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")

    # 2. Delete the model on every native Ollama endpoint
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            result = await client.delete(model=model)
        except ollama.ResponseError:
            # The ollama client raises on error replies (e.g. 404 when the
            # model is absent) rather than returning a status code, so
            # failures have to be caught here to be reported below.
            any_failed = True
            continue
        if result.status != "success":
            any_failed = True

    # 3. Return 200 OK; if a single endpoint fails, respond with 404
    return Response(status_code=404 if any_failed else 200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/api/pull")
async def pull_proxy(request: Request, model: Optional[str] = None):
    """Pull a model onto every native Ollama endpoint and merge the replies.

    The model name is the ``model`` query parameter when given. Only when it
    is absent is the JSON body decoded; ``insecure`` is then read from the
    body too, otherwise it stays ``None``. Endpoints whose URL contains
    ``/v1`` are skipped, and the remaining ones are pulled one after another.

    Returns:
        A dict made from the de-duplicated (field, value) pairs of all
        non-streaming replies; for a field reported with different values
        the last one wins.

    Raises:
        HTTPException: 400 when the body is not valid JSON or no model name
            is available.
    """
    config = get_config()
    raw_body = await request.body()

    insecure = None  # only settable through the JSON body
    if not model:
        try:
            decoded = orjson.loads(raw_body.decode("utf-8"))
        except orjson.JSONDecodeError as exc:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {exc}") from exc
        model = decoded.get("model")
        insecure = decoded.get("insecure")

    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")

    # Gather every reply's (field, value) pairs in endpoint order.
    pairs = []
    for host in config.endpoints:
        if "/v1" in host:
            continue
        reply = await ollama.AsyncClient(host=host).pull(model=model, insecure=insecure, stream=False)
        pairs.extend(reply)

    # Drop exact duplicate pairs, then fold them into a single dict.
    unique_pairs = list(dict.fromkeys(pairs))
    return dict(unique_pairs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to Ollama and respond the deduplicated Ollama endpoint replies.

    The JSON body must contain ``model`` and may contain ``insecure``.
    Endpoints whose URL contains ``/v1`` (OpenAI-compatible) are skipped, as
    pull_proxy does; the remaining ones are pushed one after another.

    Returns:
        A dict built from the (field, value) pairs of every endpoint's
        non-streaming reply; if endpoints report different values for the
        same field, the last endpoint's value wins.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``model``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    insecure = payload.get("insecure")
    if not model:
        raise HTTPException(status_code=400, detail="Missing required field 'model'")

    # 2. Push from every native Ollama endpoint, collecting (field, value) pairs
    combined_status: list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue  # OpenAI-compatible endpoints are not native Ollama servers
        client = ollama.AsyncClient(host=endpoint)
        push = await client.push(model=model, insecure=insecure, stream=False)
        combined_status += push

    # 3. Report a deduplicated status: dict() keeps first-seen key order and
    # the last value per key, matching the former dict.fromkeys round-trip.
    return dict(combined_status)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply lowest version of all endpoints.

    Only native Ollama endpoints (URL without ``/v1``) are queried. The
    minimum is reported to stay compatible with every endpoint.

    Raises:
        HTTPException: 503 when no endpoint returned a usable version string.
    """
    config = get_config()

    # 1. Query all native endpoints for their version
    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
    all_versions_raw = await asyncio.gather(*tasks)

    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]

    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v: str) -> tuple[tuple[int, ...], int]:
        """Sort key that tolerates versions like ``v0.1.0`` or ``0.6.5-rc0``.

        Each dot-separated part contributes its leading digits (0 if it has
        none), so non-numeric suffixes no longer raise ``ValueError``. A
        pre-release (``-...``) sorts before the release with the same
        numbers; build metadata (``+...``) is ignored.
        """
        text = v.strip().lstrip("vV").split("+", 1)[0]
        core, _, prerelease = text.partition("-")
        numbers = []
        for part in core.split("."):
            digits = re.match(r"\d+", part)
            numbers.append(int(digits.group()) if digits else 0)
        return tuple(numbers), 0 if prerelease else 1

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to Ollama endpoints and reply with a unique list of all models.

    Native Ollama endpoints are queried via ``/api/tags``. OpenAI-compatible
    (``/v1``) endpoints, and llama-server endpoints not already listed in
    ``config.endpoints``, are queried via ``/models``. Entries without a
    ``model`` key get ``model = id + ":latest"``; entries without a ``name``
    key get ``name = model``; otherwise ``id`` is set to ``model``. The
    combined list is deduplicated on ``digest`` / ``name`` / ``id``.

    Note: the entries returned by ``fetch.endpoint_details`` are updated in
    place.
    """
    config = get_config()

    # 1. Query all endpoints for models
    tasks = [
        fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    # api_keys.get(): an OpenAI-compatible endpoint without a configured key
    # must not abort the whole listing with a KeyError.
    tasks += [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" in ep
    ]
    # Also query llama-server endpoints not already covered by config.endpoints
    llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    tasks += [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_eps_for_tags
    ]
    all_models = await asyncio.gather(*tasks)

    models: list = []
    for modellist in all_models:
        for model in modellist:
            if "model" not in model:
                # Relabel OpenAI models with Ollama Model.model from Model.id
                model["model"] = model["id"] + ":latest"
            else:
                model["id"] = model["model"]
            if "name" not in model:
                # Relabel OpenAI models with Ollama Model.name from Model.model
                model["name"] = model["model"]
            else:
                model["id"] = model["model"]
        models += modellist

    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(models, ["digest", "name", "id"])},
        status_code=200,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models and keeps entries for which
    ``_is_llama_model_loaded`` is true, converted to an Ollama-like shape
    (``name``, ``id``, empty ``digest``, ``status``, ``details``).

    The combined list is deduplicated on ``name``.
    """
    config = get_config()

    # 1. Query Ollama endpoints for running models via /api/ps
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    # 2. Query llama-server endpoints for loaded models via /v1/models.
    # dict.fromkeys de-duplicates while keeping configuration order, so the
    # reply order (and the entry kept by the name dedupe below) is the same in
    # every process -- a set of strings iterates in hash-randomized order.
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_endpoints
    ]

    # Probe both groups concurrently rather than one group after the other.
    ollama_loaded, llama_loaded = await asyncio.gather(
        asyncio.gather(*ollama_tasks),
        asyncio.gather(*llama_tasks),
    )

    models: list = []
    # Add Ollama models as reported
    for modellist in ollama_loaded:
        models += modellist
    # Add llama-server models (loaded only), converted to an Ollama-like format
    for modellist in llama_loaded:
        loaded_models = [item for item in modellist if _is_llama_model_loaded(item)]
        for item in loaded_models:
            raw_id = item.get("id", "")
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            models.append({
                "name": normalized,
                "id": normalized,
                "digest": "",
                "status": item.get("status"),
                "details": {"quantization_level": quant} if quant else {},
            })

    # 3. Return a JSONResponse with deduplicated currently deployed models.
    # Deduplicate on 'name' rather than 'digest': llama-server models always
    # have digest="" so deduping on digest collapses all of them to one entry.
    return JSONResponse(
        content={"models": dedupe_on_keys(models, ["name"])},
        status_code=200,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/api/ps_details")
async def ps_details_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply with per-endpoint instances.
    This keeps /api/ps backward compatible while providing richer data.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models with status info, then
    ``/props`` per model to read its context size (``n_ctx``). Sleeping models
    are sent an unload request and left out of the reply; for generation
    models with ``0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT`` the value is cached in
    ``_endpoint_nctx``.

    Returns:
        ``JSONResponse`` of the form ``{"models": [...]}``; every entry carries
        the ``endpoint`` that reported it.
    """
    config = get_config()

    # 1. Ollama endpoints are queried via /api/ps
    ollama_endpoints = [ep for ep in config.endpoints if "/v1" not in ep]
    # 2. llama-server endpoints via /v1/models, de-duplicated in configuration
    # order so the reply order does not depend on hash randomization
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))

    # Probe both groups concurrently; every call carries its own timeout.
    ollama_loaded, llama_loaded = await asyncio.gather(
        asyncio.gather(*[
            fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
            for ep in ollama_endpoints
        ]),
        asyncio.gather(*[
            fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
            for ep in llama_endpoints
        ]),
    )

    models: list[dict] = []

    # Add Ollama models with endpoint info
    for endpoint, modellist in zip(ollama_endpoints, ollama_loaded):
        for model in modellist:
            if isinstance(model, dict):
                model_with_endpoint = dict(model)
                model_with_endpoint["endpoint"] = endpoint
                models.append(model_with_endpoint)

    # Add llama-server models with endpoint info and full status metadata.
    # Collect (endpoint, raw_id) pairs so /props can be fetched in parallel.
    props_requests: list[tuple[str, str]] = []
    llama_models_pending: list[dict] = []

    for endpoint, modellist in zip(llama_endpoints, llama_loaded):
        # Include sleeping models too so _fetch_llama_props can unload them
        loaded_models = [item for item in modellist if _is_llama_model_loaded_or_sleeping(item)]
        for item in loaded_models:
            if not (isinstance(item, dict) and item.get("id")):
                continue
            raw_id = item["id"]
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            model_with_endpoint = {
                "name": normalized,
                "id": normalized,
                "original_name": raw_id,
                "digest": "",
                "details": {"quantization_level": quant} if quant else {},
                "endpoint": endpoint,
                "status": item.get("status"),
                "created": item.get("created"),
                "owned_by": item.get("owned_by"),
            }
            # Include full llama-server status details (args, preset)
            status_info = item.get("status", {})
            if isinstance(status_info, dict):
                model_with_endpoint["llama_status_args"] = status_info.get("args")
                model_with_endpoint["llama_status_preset"] = status_info.get("preset")
            llama_models_pending.append(model_with_endpoint)
            props_requests.append((endpoint, raw_id))

    async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]:
        """Return ``(n_ctx, is_sleeping, is_generation)`` from llama-server ``/props``.

        A sleeping model is additionally sent ``POST /models/unload``. Any
        failure or non-200 reply yields ``(None, False, False)``.
        """
        client: aiohttp.ClientSession = get_probe_session(endpoint)
        base_url = endpoint.rstrip("/").removesuffix("/v1")
        props_url = f"{base_url}/props"
        request_timeout = aiohttp.ClientTimeout(total=5)
        headers = None
        api_key = config.api_keys.get(endpoint)
        if api_key:
            headers = {"Authorization": f"Bearer {api_key}"}
        try:
            # params= percent-encodes the model id; interpolating it into the
            # URL broke ids containing characters such as '&', '#', '+' or spaces.
            async with client.get(
                props_url,
                params={"model": model_id},
                headers=headers,
                timeout=request_timeout,
            ) as resp:
                if resp.status == 200:
                    data = await resp.json()
                    dgs = data.get("default_generation_settings", {})
                    n_ctx = dgs.get("n_ctx")
                    is_sleeping = data.get("is_sleeping", False)
                    # Embedding models have no sampling params in default_generation_settings
                    is_generation = "temperature" in dgs

                    if is_sleeping:
                        unload_url = f"{base_url}/models/unload"
                        try:
                            # Bounded like the /props probe so one stuck
                            # server cannot stall the whole listing.
                            async with client.post(
                                unload_url,
                                json={"model": model_id},
                                headers=headers,
                                timeout=request_timeout,
                            ) as unload_resp:
                                print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
                        except Exception as ue:
                            print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")

                    return n_ctx, is_sleeping, is_generation
        except Exception as e:
            print(f"[ps_details] Failed to fetch props from {props_url}?model={model_id}: {e}")
        # Reached on errors and on non-200 replies alike.
        return None, False, False

    # Fetch /props for each llama-server model to get context length (n_ctx)
    # and unload sleeping models automatically
    props_results = await asyncio.gather(
        *[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
    )

    for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in zip(
        props_requests, llama_models_pending, props_results
    ):
        if n_ctx is not None:
            model_dict["context_length"] = n_ctx
            if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
                normalized = _normalize_llama_model_name(raw_id)
                _endpoint_nctx[(ep, normalized)] = n_ctx
                print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
        if not is_sleeping:
            models.append(model_dict)

    return JSONResponse(content={"models": models}, status_code=200)
|