1106 lines
50 KiB
Python
1106 lines
50 KiB
Python
"""Ollama-native API routes (``/api/*``).
|
||
|
||
These are the ``/api/generate``, ``/api/chat``, ``/api/embed(dings)`` and the
|
||
model-management routes (``/api/create``, ``/api/show``, ``/api/copy``,
|
||
``/api/delete``, ``/api/pull``, ``/api/push``, ``/api/version``,
|
||
``/api/tags``, ``/api/ps``, ``/api/ps_details``) that the Ollama clients
|
||
expect. The chat/generate handlers also serve OpenAI-compatible endpoints
|
||
when ``is_openai_compatible(endpoint)`` is true — in that case they
|
||
translate the request to the OpenAI Chat Completions / Completions API and
|
||
``rechunk`` the response back into Ollama wire format.
|
||
"""
|
||
import asyncio
|
||
import re
|
||
import time
|
||
from typing import Optional
|
||
|
||
import aiohttp
|
||
import ollama
|
||
import orjson
|
||
from fastapi import APIRouter, HTTPException, Request
|
||
from starlette.responses import JSONResponse, Response, StreamingResponse
|
||
|
||
from cache import get_llm_cache
|
||
from config import get_config
|
||
from context_window import (
|
||
_count_message_tokens,
|
||
_trim_messages_for_context,
|
||
_calibrated_trim_target,
|
||
_endpoint_nctx,
|
||
_CTX_TRIM_SMALL_LIMIT,
|
||
)
|
||
from fingerprint import _conversation_fingerprint
|
||
from state import token_queue, default_headers
|
||
from backends.health import (
|
||
_is_backend_connection_error,
|
||
_is_llama_model_loaded,
|
||
_is_llama_model_loaded_or_sleeping,
|
||
_mark_backend_unhealthy,
|
||
)
|
||
from backends.normalize import (
|
||
dedupe_on_keys,
|
||
is_openai_compatible,
|
||
_normalize_llama_model_name,
|
||
_extract_llama_quant,
|
||
)
|
||
from backends.probe import fetch
|
||
from backends.sessions import _make_openai_client, get_session
|
||
from requests.chat import _make_moe_requests
|
||
from requests.messages import (
|
||
transform_images_to_data_urls,
|
||
transform_tool_calls_to_openai,
|
||
_strip_assistant_prefill,
|
||
_strip_images_from_messages,
|
||
_accumulate_openai_tc_delta,
|
||
_build_ollama_tool_calls,
|
||
)
|
||
from requests.rechunk import rechunk
|
||
from routing import choose_endpoint, decrement_usage
|
||
|
||
|
||
# Router for the Ollama-native ``/api/*`` handlers in this module; each handler
# below registers on it via ``@router.post`` / ``@router.delete``.
router = APIRouter()
|
||
|
||
|
||
@router.post("/api/generate")
async def proxy(request: Request):
    """
    Proxy a generate request to Ollama and stream the response back to the client.

    OpenAI-compatible endpoints receive a Completions API request built from the
    prompt and the supported ``options``; their responses are converted back to
    Ollama's wire format with ``rechunk.openai_completion2ollama``. When the
    request sets ``nomyo.cache``, responses are served from / stored in the LLM
    cache keyed by model, prompt and system prompt.

    Raises:
        HTTPException: 400 for invalid JSON, a missing ``model`` or ``prompt``,
            or a non-object ``options``.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        suffix = payload.get("suffix")
        system = payload.get("system")
        template = payload.get("template")
        context = payload.get("context")
        stream = payload.get("stream")
        think = payload.get("think")
        raw = payload.get("raw")
        _format = payload.get("format")
        images = payload.get("images")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")
        # "nomyo" may be absent, null or malformed; only a JSON object can
        # switch the response cache on.
        _nomyo = payload.get("nomyo")
        _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
        raise HTTPException(status_code=400, detail=error_msg) from e

    # Cache lookup — before endpoint selection so no slot is wasted on a hit.
    # The OpenAI branch below strips ":latest" from `model`, so the cache key is
    # normalised up front; get_generate and set_generate must see the same name.
    _cache = get_llm_cache()
    _cache_model = (
        model[: -len(":latest")]
        if isinstance(model, str) and model.endswith(":latest")
        else model
    )
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_generate(_cache_model, prompt, system or "")
        if _cached is not None:
            async def _serve_cached_generate():
                yield _cached

            return StreamingResponse(_serve_cached_generate(), media_type="application/json")

    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)

    params: dict = {}
    oclient = None
    client = None
    # choose_endpoint reserved a usage slot; release it if client setup fails,
    # because the generator's `finally` never runs in that case.
    try:
        if use_openai:
            if ":latest" in model:
                model = model.split(":latest")[0]
            params = {"prompt": prompt, "model": model}
            if stream is not None:
                params["stream"] = stream
            # Map Ollama sampling options onto their OpenAI names; unset or null
            # options are omitted so the backend defaults apply.
            for ollama_key, openai_key in (
                ("num_predict", "max_tokens"),
                ("frequency_penalty", "frequency_penalty"),
                ("presence_penalty", "presence_penalty"),
                ("seed", "seed"),
                ("stop", "stop"),
                ("top_p", "top_p"),
                ("temperature", "temperature"),
            ):
                value = (options or {}).get(ollama_key)
                if value is not None:
                    params[openai_key] = value
            if suffix is not None:
                params["suffix"] = suffix
            oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
        else:
            client = ollama.AsyncClient(host=endpoint)
    except Exception:
        await decrement_usage(endpoint, tracking_model)
        raise

    # Async generator that streams data and decrements the counter
    async def stream_generate_response():
        try:
            if use_openai:
                start_ts = time.perf_counter()
                async_gen = await oclient.completions.create(**params)
            else:
                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)

            if stream:
                content_parts: list[str] = []
                async for chunk in async_gen:
                    if use_openai:
                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                    prompt_tok = getattr(chunk, "prompt_eval_count", None) or 0
                    comp_tok = getattr(chunk, "eval_count", None) or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    # model_dump_json() returns str, orjson.dumps() already bytes.
                    if hasattr(chunk, "model_dump_json"):
                        line = chunk.model_dump_json().encode("utf-8")
                    else:
                        line = orjson.dumps(chunk)
                    # Accumulate and store cache on done chunk — before yield so it always runs
                    if _cache is not None and _cache_enabled:
                        if getattr(chunk, "response", None):
                            content_parts.append(chunk.response)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "response": "".join(content_parts),
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_generate(_cache_model, prompt, system or "", assembled)
                            except Exception as _ce:
                                print(f"[cache] set_generate (streaming) failed: {_ce}")
                    yield line + b"\n"
            else:
                # Token counts live on the Ollama-format object; a raw OpenAI
                # Completion has no prompt_eval_count / eval_count attributes.
                if use_openai:
                    result = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
                else:
                    result = async_gen
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if hasattr(result, "model_dump_json"):
                    cache_bytes = result.model_dump_json().encode("utf-8") + b"\n"
                else:
                    cache_bytes = orjson.dumps(result) + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_generate(_cache_model, prompt, system or "", cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_generate (non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_generate_response(),
        media_type="application/json",
    )
|
||
|
||
|
||
@router.post("/api/chat")
async def chat_proxy(request: Request):
    """
    Proxy a chat request to Ollama and stream the endpoint reply.

    ``moe-<model>`` requests on Ollama-native endpoints are delegated to
    ``_make_moe_requests``. For OpenAI-compatible endpoints the request is
    translated to the Chat Completions API: images become data URLs, tool calls
    are converted, assistant prefill is stripped and ``format`` becomes
    ``response_format``. Responses are converted back to Ollama's wire format
    with ``rechunk.openai_chat_completion2ollama``.

    When the backend reports a context-size overflow, the call is retried with
    the oldest messages trimmed. If it still overflows, it is retried once more
    with the tools removed as well. Responses are cached when ``nomyo.cache``
    is set, except for ``moe-`` requests.

    Raises:
        HTTPException: 400 for invalid JSON, a missing ``model``, a non-list
            ``messages`` or a non-object ``options``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        messages = payload.get("messages")
        tools = payload.get("tools")
        stream = payload.get("stream")
        think = payload.get("think")
        _format = payload.get("format")
        keep_alive = payload.get("keep_alive")
        options = payload.get("options")
        logprobs = payload.get("logprobs")
        top_logprobs = payload.get("top_logprobs")
        # "nomyo" may be absent, null or malformed; only a JSON object can
        # switch the response cache on.
        _nomyo = payload.get("nomyo")
        _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(messages, list):
            raise HTTPException(
                status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Cache lookup — before endpoint selection, always bypassed for MOE
    _is_moe = model.startswith("moe-")
    _cache = get_llm_cache()
    # Normalise the model name for the cache key: the OpenAI branch below strips
    # ":latest", so strip it here too and get_chat/set_chat use the same string.
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    # Snapshot original messages before any OpenAI-format transformation so that
    # get_chat and set_chat always use the same key regardless of backend type.
    _cache_messages = messages
    if _cache is not None and not _is_moe and _cache_enabled:
        _cached = await _cache.get_chat("ollama_chat", _cache_model, messages)
        if _cached is not None:
            async def _serve_cached_chat():
                yield _cached

            return StreamingResponse(
                _serve_cached_chat(),
                media_type="application/x-ndjson" if stream else "application/json",
            )

    # 2. Endpoint logic
    if _is_moe:
        # Drop only the leading "moe-" marker; split("moe-")[1] would truncate
        # names that contain "moe-" again further on.
        model = model[len("moe-"):]
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)

    params: dict = {}
    oclient = None
    client = None
    start_ts = None
    async_gen = None
    # choose_endpoint reserved a usage slot; if request preparation fails before
    # the first backend call, release it here (the generator never runs).
    try:
        if use_openai and ":latest" in model:
            model = model.split(":latest")[0]
        # Key for the per-(endpoint, model) n_ctx cache. Lookups and stores below
        # must use the same name, so llama-server names are normalised once here.
        _nctx_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model

        if use_openai:
            if messages:
                if any("images" in m for m in messages):
                    messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
                messages = transform_tool_calls_to_openai(messages)
                messages = _strip_assistant_prefill(messages)

            # Ollama's `format` is either "json" or a bare JSON schema, whereas
            # OpenAI expects {"type": "json_object"} or a named {"schema": ...}
            # wrapper under "json_schema".
            if _format is None:
                response_format = None
            elif _format == "json":
                response_format = {"type": "json_object"}
            elif isinstance(_format, dict) and "schema" not in _format:
                response_format = {"type": "json_schema", "json_schema": {"name": "response", "schema": _format}}
            else:
                # Already wrapped (or an unexpected shape): pass it through.
                response_format = {"type": "json_schema", "json_schema": _format}

            _opts = options or {}
            params = {"messages": messages, "model": model}
            optional_params = {
                "tools": tools,
                "stream": stream,
                "stream_options": {"include_usage": True} if stream else None,
                "logprobs": logprobs if logprobs is not None else _opts.get("logprobs"),
                "top_logprobs": top_logprobs if top_logprobs is not None else _opts.get("top_logprobs"),
                "response_format": response_format,
            }
            # Map Ollama sampling options onto their OpenAI names.
            for ollama_key, openai_key in (
                ("num_predict", "max_tokens"),
                ("frequency_penalty", "frequency_penalty"),
                ("presence_penalty", "presence_penalty"),
                ("seed", "seed"),
                ("stop", "stop"),
                ("top_p", "top_p"),
                ("temperature", "temperature"),
            ):
                optional_params[openai_key] = _opts.get(ollama_key)
            params.update({k: v for k, v in optional_params.items() if v is not None})
            oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))

            start_ts = time.perf_counter()
            # Proactive trim: only for small-ctx models we've already seen run out of space
            _known_nctx = _endpoint_nctx.get((endpoint, _nctx_model))
            if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
                # Target three quarters of n_ctx divided by 1.2 (headroom —
                # presumably for the reply and token-estimate error).
                _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
                _pre_msgs = params.get("messages", [])
                _pre_est = _count_message_tokens(_pre_msgs)
                if _pre_est > _pre_target:
                    _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                    _dropped = len(_pre_msgs) - len(_pre_trimmed)
                    print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                    params = {**params, "messages": _pre_trimmed}
        else:
            client = ollama.AsyncClient(host=endpoint)
    except Exception:
        await decrement_usage(endpoint, tracking_model)
        raise

    # For OpenAI endpoints: make the API call in handler scope
    # (try/except inside async generators is unreliable with Starlette's streaming)
    if use_openai:
        try:
            async_gen = await oclient.chat.completions.create(**params)
        except Exception as e:
            _e_str = str(e)
            print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}")
            if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
                # The OpenAI SDK usually stores the already-unwrapped error object
                # in `e.body`; also accept a raw {"error": {...}} envelope.
                err_body = getattr(e, "body", {}) or {}
                err_detail = err_body.get("error", err_body) if isinstance(err_body, dict) else {}
                if not isinstance(err_detail, dict):
                    err_detail = {}
                n_ctx_limit = err_detail.get("n_ctx", 0)
                actual_tokens = err_detail.get("n_prompt_tokens", 0)
                # Fall back to scraping the stringified error body.
                if not n_ctx_limit:
                    _m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
                    if _m:
                        n_ctx_limit = int(_m.group(1))
                if not actual_tokens:
                    _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                    if _m:
                        actual_tokens = int(_m.group(1))
                if not n_ctx_limit:
                    await decrement_usage(endpoint, tracking_model)
                    raise
                if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                    _endpoint_nctx[(endpoint, _nctx_model)] = n_ctx_limit
                msgs_to_trim = params.get("messages", [])
                cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
                try:
                    async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed})
                except Exception as e2:
                    _e2_str = str(e2)
                    if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
                        print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools")
                        params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                        try:
                            async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
                        except Exception:
                            await decrement_usage(endpoint, tracking_model)
                            raise
                    else:
                        await decrement_usage(endpoint, tracking_model)
                        raise
            elif _is_backend_connection_error(e):
                print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, _e_str)
                await decrement_usage(endpoint, tracking_model)
                raise
            elif "image input is not supported" in _e_str:
                print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
                try:
                    params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
                    async_gen = await oclient.chat.completions.create(**params)
                except Exception:
                    await decrement_usage(endpoint, tracking_model)
                    raise
            else:
                await decrement_usage(endpoint, tracking_model)
                raise

    # 3. Async generator that streams chat data and decrements the counter
    async def stream_chat_response():
        try:
            if use_openai:
                _async_gen = async_gen  # established in handler scope above
            elif _is_moe:
                # Use the dedicated MOE helper function
                _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
            else:
                _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)

            if stream:
                tc_acc = {}  # accumulate OpenAI tool-call deltas across chunks
                content_parts: list[str] = []
                # The streamed cache entry stores text only; a reply that carried
                # tool calls must not be cached as a plain (possibly empty) message.
                saw_tool_calls = False
                async for chunk in _async_gen:
                    if use_openai:
                        _accumulate_openai_tc_delta(chunk, tc_acc)
                        chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                        # Inject fully-accumulated tool calls only into the final chunk
                        if chunk.done and tc_acc and chunk.message:
                            chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
                    prompt_tok = getattr(chunk, "prompt_eval_count", None) or 0
                    comp_tok = getattr(chunk, "eval_count", None) or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    # model_dump_json() returns str, orjson.dumps() already bytes.
                    if hasattr(chunk, "model_dump_json"):
                        line = chunk.model_dump_json().encode("utf-8")
                    else:
                        line = orjson.dumps(chunk)
                    _msg = getattr(chunk, "message", None)
                    if _msg is not None and getattr(_msg, "tool_calls", None):
                        saw_tool_calls = True

                    if getattr(chunk, "done", False):
                        # Detect context exhaustion mid-generation for small-ctx models
                        _dr = getattr(chunk, "done_reason", None)
                        # Only infer n_ctx when no max_tokens limit was set — otherwise
                        # finish_reason=length might just mean max_tokens was hit,
                        # not that the context window was exhausted.
                        _req_max_tok = (
                            params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
                            if use_openai else
                            (options.get("num_predict") if options else None)
                        )
                        if _dr == "length" and not _req_max_tok:
                            _pt = getattr(chunk, "prompt_eval_count", 0) or 0
                            _ct = getattr(chunk, "eval_count", 0) or 0
                            _inferred_nctx = _pt + _ct
                            if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                                _endpoint_nctx[(endpoint, _nctx_model)] = _inferred_nctx
                                print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)

                    # Accumulate and store cache on done chunk — before yield so it always runs.
                    # Chunks are already in Ollama format here for both backend types.
                    if _cache is not None and not _is_moe and _cache_enabled:
                        if _msg and getattr(_msg, "content", None):
                            content_parts.append(_msg.content)
                        if getattr(chunk, "done", False) and not saw_tool_calls:
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)),
                                    "message": {"role": "assistant", "content": "".join(content_parts)},
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
                            except Exception as _ce:
                                print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
                    yield line + b"\n"
            else:
                # Token counts live on the Ollama-format object; a raw OpenAI
                # ChatCompletion has no prompt_eval_count / eval_count attributes.
                if use_openai:
                    result = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
                else:
                    result = _async_gen
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if hasattr(result, "model_dump_json"):
                    cache_bytes = result.model_dump_json().encode("utf-8") + b"\n"
                else:
                    cache_bytes = orjson.dumps(result) + b"\n"
                yield cache_bytes
                # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends)
                if _cache is not None and not _is_moe and _cache_enabled:
                    try:
                        await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    media_type = "application/x-ndjson" if stream else "application/json"
    return StreamingResponse(
        stream_chat_response(),
        media_type=media_type,
    )
|
||
|
||
|
||
@router.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """
    Proxy an embedding request to Ollama and reply with embeddings.

    This handles the single-``prompt`` embeddings API. OpenAI-compatible
    endpoints are called through the Embeddings API, and the result is
    converted with ``rechunk.openai_embeddings2ollama``.

    Raises:
        HTTPException: 400 for invalid JSON or a missing ``model``/``prompt``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")[0]
        client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that produces the single embedding reply and
    #    decrements the usage counter
    async def stream_embedding_response():
        try:
            if use_openai:
                result = await client.embeddings.create(input=prompt, model=model)
                result = rechunk.openai_embeddings2ollama(result)
            else:
                result = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
            # model_dump_json() returns str; orjson.dumps() already returns bytes,
            # so only the pydantic path needs encoding.
            if hasattr(result, "model_dump_json"):
                body = result.model_dump_json().encode("utf-8")
            else:
                body = orjson.dumps(result)
            yield body + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
|
||
|
||
|
||
@router.post("/api/embed")
async def embed_proxy(request: Request):
    """
    Proxy an embed request to Ollama and reply with embeddings.

    Takes ``input`` (plus optional ``truncate``). OpenAI-compatible endpoints
    are called through the Embeddings API, and the result is converted with
    ``rechunk.openai_embed2ollama``.

    Raises:
        HTTPException: 400 for invalid JSON or a missing ``model``/``input``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        _input = payload.get("input")
        truncate = payload.get("truncate")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not _input:
            raise HTTPException(
                status_code=400, detail="Missing required field 'input'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")[0]
        client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that produces the single embed reply and decrements
    #    the usage counter
    async def stream_embedding_response():
        try:
            if use_openai:
                result = await client.embeddings.create(input=_input, model=model)
                result = rechunk.openai_embed2ollama(result, model)
            else:
                result = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
            # model_dump_json() returns str; orjson.dumps() already returns bytes,
            # so only the pydantic path needs encoding.
            if hasattr(result, "model_dump_json"):
                body = result.model_dump_json().encode("utf-8")
            else:
                body = orjson.dumps(result)
            yield body + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
|
||
|
||
|
||
@router.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with deduplicated status.

    OpenAI-compatible endpoints (URLs containing ``/v1``) are skipped, as in
    the copy/delete/pull handlers; the Ollama client cannot create models there.

    Returns:
        dict: the (field, value) pairs of every endpoint's create response,
        deduplicated and merged, with later distinct values for the same field
        replacing earlier ones.

    Raises:
        HTTPException: 400 for invalid JSON, a missing ``model``, or when
            neither ``from`` nor ``files`` is given.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        quantize = payload.get("quantize")
        from_ = payload.get("from")
        files = payload.get("files")
        adapters = payload.get("adapters")
        template = payload.get("template")
        license_ = payload.get("license")  # trailing "_" avoids shadowing the builtin
        system = payload.get("system")
        parameters = payload.get("parameters")
        messages = payload.get("messages")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not from_ and not files:
            raise HTTPException(
                status_code=400, detail="You need to provide either from_ or files parameter!"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    status_lists = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license_, system=system, parameters=parameters, messages=messages, stream=False)
        status_lists.append(create)

    # Each response iterates as (field, value) pairs; collect them all.
    combined_status = []
    for status_list in status_lists:
        combined_status += status_list

    # Drop duplicate pairs (keeping first-seen order), then fold into a dict.
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
|
||
|
||
|
||
@router.post("/api/show")
async def show_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model show request to Ollama and reply with ShowResponse.

    The model name is taken from the ``model`` query parameter when present,
    otherwise from the ``model`` field of the JSON body.

    Raises:
        HTTPException: 400 for invalid JSON or when no model name is supplied.
    """
    try:
        raw_body = await request.body()
        if not model:
            model = orjson.loads(raw_body.decode("utf-8")).get("model")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {exc}") from exc

    # reserve=False: no usage slot is taken, so nothing has to be decremented.
    target, _ = await choose_endpoint(model, reserve=False)
    return await ollama.AsyncClient(host=target).show(model=model)
|
||
|
||
|
||
@router.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """
    Proxy a model copy request to each Ollama endpoint and reply with Status Code.

    ``source``/``destination`` come from the query string when either is
    given, otherwise from the JSON body. OpenAI-compatible endpoints (URLs
    containing ``/v1``) are skipped. A failing endpoint does not stop the copy
    on the remaining ones.

    Returns:
        Response: 200 when every endpoint copied the model, 404 when at least
        one endpoint reported a failure.

    Raises:
        HTTPException: 400 for invalid JSON or a missing source/destination.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()

        if not source and not destination:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            src = payload.get("source")
            dst = payload.get("destination")
        else:
            src = source
            dst = destination

        if not src:
            raise HTTPException(
                status_code=400, detail="Missing required field 'source'"
            )
        if not dst:
            raise HTTPException(
                status_code=400, detail="Missing required field 'destination'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Copy the model on every Ollama-native endpoint, remembering failures
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            copy = await client.copy(source=src, destination=dst)
        except ollama.ResponseError as e:
            # ollama-python raises on any non-2xx reply (e.g. 404 when the
            # source model is missing); record it and keep going.
            print(f"[copy_proxy] {endpoint}: copy {src} -> {dst} failed with HTTP {e.status_code}: {e}")
            any_failed = True
            continue
        # The client reports the outcome as a status string ("success"/"error");
        # 404 is also accepted in case a client version returns the HTTP code.
        if getattr(copy, "status", None) in ("error", 404):
            any_failed = True

    # 3. Return with 200 OK if all went well, 404 if a single endpoint failed
    return Response(status_code=404 if any_failed else 200)
|
||
|
||
|
||
@router.delete("/api/delete")
async def delete_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model delete request to each Ollama endpoint and reply with Status Code.

    The model name comes from the ``model`` query parameter when given,
    otherwise from the JSON body. OpenAI-compatible endpoints (URLs containing
    ``/v1``) are skipped. A failing endpoint does not stop the delete on the
    remaining ones.

    Returns:
        Response: 200 when every endpoint deleted the model, 404 when at least
        one endpoint reported a failure.

    Raises:
        HTTPException: 400 for invalid JSON or a missing model name.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()

        if not model:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            model = payload.get("model")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Delete the model on every Ollama-native endpoint, remembering failures
    any_failed = False
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            deleted = await client.delete(model=model)
        except ollama.ResponseError as e:
            # ollama-python raises on any non-2xx reply (e.g. 404 when the
            # model is not present on this endpoint); record it and keep going.
            print(f"[delete_proxy] {endpoint}: delete {model} failed with HTTP {e.status_code}: {e}")
            any_failed = True
            continue
        # The client reports the outcome as a status string ("success"/"error");
        # 404 is also accepted in case a client version returns the HTTP code.
        if getattr(deleted, "status", None) in ("error", 404):
            any_failed = True

    # 3. Return 200 OK; if a single endpoint fails, respond with 404
    return Response(status_code=404 if any_failed else 200)
|
||
|
||
|
||
@router.post("/api/pull")
async def pull_proxy(request: Request, model: Optional[str] = None):
    """
    Fan a model pull out to every native Ollama endpoint and merge the replies.

    A ``model`` query parameter takes precedence. Only when it is absent is
    the JSON body parsed for ``"model"`` and ``"insecure"``. Endpoints whose
    URL contains ``/v1`` are skipped. Pulls run one endpoint at a time with
    ``stream=False``.

    The replies are flattened into (key, value) items. Identical items are
    dropped (first occurrence kept), and the rest are folded into one dict.
    NOTE(review): this assumes each reply iterates as (field, value) pairs,
    as ollama-python's pydantic response models do. Confirm against the
    pinned ollama version.

    Raises:
        HTTPException(400): the body is not valid JSON or no model was given.
    """
    config = get_config()

    insecure = None
    try:
        raw_body = await request.body()

        if not model:
            body = orjson.loads(raw_body.decode("utf-8"))
            model = body.get("model")
            insecure = body.get("insecure")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    replies = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        ollama_client = ollama.AsyncClient(host=endpoint)
        replies.append(
            await ollama_client.pull(model=model, insecure=insecure, stream=False)
        )

    flattened = [item for reply in replies for item in reply]
    distinct_items = list(dict.fromkeys(flattened))
    return dict(distinct_items)
|
||
|
||
|
||
@router.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to Ollama and respond the deduplicated Ollama endpoint replies.

    The JSON body must contain ``"model"``; ``"insecure"`` is optional.
    Only native Ollama endpoints are used: URLs containing ``/v1``
    (OpenAI-compatible) are skipped, as in the other model-management routes.

    The replies are flattened into (key, value) items. Identical items are
    dropped (first occurrence kept), and the rest are folded into one dict.
    NOTE(review): this assumes each reply iterates as (field, value) pairs,
    as ollama-python's pydantic response models do.

    Raises:
        HTTPException(400): the body is not valid JSON or has no model.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        insecure = payload.get("insecure")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Push from every native Ollama endpoint
    status_list = []
    for endpoint in config.endpoints:
        # OpenAI-compatible endpoints have no /api/push; an Ollama client
        # pointed at them would only error out.
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        status_list.append(
            await client.push(model=model, insecure=insecure, stream=False)
        )

    # 3. Report a deduplicated status
    combined_status = [item for status in status_list for item in status]
    return dict(list(dict.fromkeys(combined_status)))
|
||
|
||
|
||
@router.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply lowest version of all endpoints.

    Only native Ollama endpoints (URLs without ``/v1``) are queried. Replies
    that are not non-empty strings (e.g. failed or timed-out fetches) are
    ignored.

    Returns:
        JSONResponse ``{"version": <lowest version string>}``.

    Raises:
        HTTPException(503): no endpoint returned a usable version string.
    """
    config = get_config()

    # 1. Query all endpoints for version
    tasks = [
        fetch.endpoint_details(ep, "/api/version", "version")
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    all_versions_raw = await asyncio.gather(*tasks)

    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]

    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v: str) -> tuple[int, ...]:
        # Compare on the first run of digits in each dot-separated part, so
        # suffixed versions such as "0.6.5-rc2" or "v0.5.0" sort numerically
        # instead of making int() raise ValueError. Parts without digits
        # count as 0.
        key = []
        for part in v.split("."):
            digits = re.search(r"\d+", part)
            key.append(int(digits.group()) if digits else 0)
        return tuple(key)

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
|
||
|
||
|
||
@router.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to Ollama endpoints and reply with a unique list of all models.

    Three kinds of endpoint are queried:
      * native Ollama endpoints, via ``/api/tags``;
      * OpenAI-compatible (``/v1``) endpoints, via ``/models``;
      * llama-server endpoints not already listed in ``config.endpoints``,
        via ``/models``.

    Entries without a ``model`` key get ``model = id + ":latest"``, and
    entries without a ``name`` key get ``name = model``. The entries are
    mutated in place.

    Returns:
        JSONResponse ``{"models": [...]}`` de-duplicated on digest/name/id.
    """
    config = get_config()

    # 1. Query all endpoints for models
    tasks = [
        fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    # .get(): an OpenAI-compatible endpoint with no configured API key must
    # not raise KeyError. This matches the llama-server lookup below.
    tasks += [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" in ep
    ]
    # Also query llama-server endpoints not already covered by config.endpoints
    tasks += [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in config.llama_server_endpoints
        if ep not in config.endpoints
    ]

    all_models = await asyncio.gather(*tasks)

    merged: list = []
    for modellist in all_models:
        for model in modellist:
            if "model" not in model:
                # Relabel OpenAI models with an Ollama-style ``model`` from ``id``.
                model["model"] = model["id"] + ":latest"
            else:
                model["id"] = model["model"]
            if "name" not in model:
                # Ensure both ``model`` and ``name`` keys are present.
                model["name"] = model["model"]
            else:
                model["id"] = model["model"]
        merged.extend(modellist)

    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(merged, ["digest", "name", "id"])},
        status_code=200,
    )
|
||
|
||
|
||
@router.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models.

    For Ollama endpoints: queries /api/ps and passes the entries through.
    For llama-server endpoints: queries /models, keeps the items that
    ``_is_llama_model_loaded`` accepts, and converts each one into an
    Ollama-like record (``name``/``id`` normalized, empty ``digest``,
    quantization level in ``details`` when one can be extracted).

    Returns:
        JSONResponse ``{"models": [...]}`` de-duplicated on ``name``.
    """
    config = get_config()

    # 1. Query Ollama endpoints for running models via /api/ps
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" not in ep
    ]

    # 2. Query llama-server endpoints (whether or not they are also listed in
    #    config.endpoints) via /models. An ordered de-duplication is used
    #    instead of a set: set iteration order varies between processes, which
    #    would change which duplicate survives dedupe_on_keys below.
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_endpoints
    ]

    # Run both fan-outs concurrently instead of one after the other.
    ollama_loaded, llama_loaded = await asyncio.gather(
        asyncio.gather(*ollama_tasks),
        asyncio.gather(*llama_tasks),
    )

    models: list = []

    # Add Ollama models as-is
    for modellist in ollama_loaded:
        models.extend(modellist)

    # Add llama-server models (loaded only), converted to an Ollama-like format
    for modellist in llama_loaded:
        for item in modellist:
            if not _is_llama_model_loaded(item):
                continue
            raw_id = item.get("id", "")
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            models.append({
                "name": normalized,
                "id": normalized,
                "digest": "",
                "status": item.get("status"),
                "details": {"quantization_level": quant} if quant else {},
            })

    # 3. Return a JSONResponse with deduplicated currently deployed models
    # Deduplicate on 'name' rather than 'digest': llama-server models always
    # have digest="" so deduping on digest collapses all of them to one entry.
    return JSONResponse(
        content={"models": dedupe_on_keys(models, ["name"])},
        status_code=200,
    )
|
||
|
||
|
||
@router.get("/api/ps_details")
async def ps_details_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply with per-endpoint instances.

    This keeps /api/ps backward compatible while providing richer data.

    For Ollama endpoints: queries /api/ps and copies each dict entry with an
    added ``endpoint`` key.

    For llama-server endpoints:
      * queries /models and keeps loaded or sleeping items that have an ``id``;
      * fetches ``/props`` for each kept model to read ``n_ctx``;
      * asks sleeping models to unload (best effort) and leaves them out of
        the reply;
      * records small generation contexts (``0 < n_ctx <=
        _CTX_TRIM_SMALL_LIMIT``) in ``_endpoint_nctx``.

    Returns:
        JSONResponse ``{"models": [...]}`` with one entry per (endpoint, model).
    """
    config = get_config()

    # 1./2. Query Ollama (/api/ps) and llama-server (/models) concurrently.
    # llama-server endpoints are de-duplicated in order (not via a set) so the
    # reply order is stable across processes.
    ollama_endpoints = [ep for ep in config.endpoints if "/v1" not in ep]
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))

    ollama_loaded, llama_loaded = await asyncio.gather(
        asyncio.gather(*[
            fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
            for ep in ollama_endpoints
        ]),
        asyncio.gather(*[
            fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
            for ep in llama_endpoints
        ]),
    )

    models: list[dict] = []

    # Add Ollama models with endpoint info
    for endpoint, modellist in zip(ollama_endpoints, ollama_loaded):
        for model in modellist:
            if isinstance(model, dict):
                model_with_endpoint = dict(model)
                model_with_endpoint["endpoint"] = endpoint
                models.append(model_with_endpoint)

    # Build llama-server records now; /props results are merged in below.
    # props_requests[i] is the (endpoint, raw_id) for llama_models_pending[i].
    props_requests: list[tuple[str, str]] = []
    llama_models_pending: list[dict] = []

    for endpoint, modellist in zip(llama_endpoints, llama_loaded):
        for item in modellist:
            # Include sleeping models too so _fetch_llama_props can unload them
            if not _is_llama_model_loaded_or_sleeping(item):
                continue
            if not (isinstance(item, dict) and item.get("id")):
                continue
            raw_id = item["id"]
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            model_with_endpoint = {
                "name": normalized,
                "id": normalized,
                "original_name": raw_id,
                "digest": "",
                "details": {"quantization_level": quant} if quant else {},
                "endpoint": endpoint,
                "status": item.get("status"),
                "created": item.get("created"),
                "owned_by": item.get("owned_by"),
            }
            # Include full llama-server status details (args, preset)
            status_info = item.get("status", {})
            if isinstance(status_info, dict):
                model_with_endpoint["llama_status_args"] = status_info.get("args")
                model_with_endpoint["llama_status_preset"] = status_info.get("preset")
            llama_models_pending.append(model_with_endpoint)
            props_requests.append((endpoint, raw_id))

    async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]:
        """Return ``(n_ctx, is_sleeping, is_generation)`` from llama-server ``/props``.

        Any failure (transport error, non-200 reply, unexpected payload)
        yields ``(None, False, False)``, so the caller's tuple unpacking
        always succeeds. A model reported as sleeping is asked to unload
        (best effort).
        """
        client: aiohttp.ClientSession = get_session(endpoint)
        base_url = endpoint.rstrip("/").removesuffix("/v1")
        props_url = f"{base_url}/props"
        api_key = config.api_keys.get(endpoint)
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
        try:
            # params= lets aiohttp URL-encode the model id; raw interpolation
            # breaks on ids containing characters such as '+', '&' or '#'.
            async with client.get(
                props_url,
                params={"model": model_id},
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                if resp.status != 200:
                    print(f"[ps_details] {props_url} returned HTTP {resp.status} for {model_id} on {endpoint}")
                    return None, False, False
                data = await resp.json()

            dgs = data.get("default_generation_settings", {})
            n_ctx = dgs.get("n_ctx")
            is_sleeping = data.get("is_sleeping", False)
            # Embedding models have no sampling params in default_generation_settings
            is_generation = "temperature" in dgs

            if is_sleeping:
                unload_url = f"{base_url}/models/unload"
                try:
                    # Bounded so an unresponsive server cannot stall the listing.
                    async with client.post(
                        unload_url,
                        json={"model": model_id},
                        headers=headers,
                        timeout=aiohttp.ClientTimeout(total=10),
                    ) as unload_resp:
                        print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
                except Exception as ue:
                    print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")

            return n_ctx, is_sleeping, is_generation
        except Exception as e:
            print(f"[ps_details] Failed to fetch props from {props_url} (model={model_id}): {e}")
            return None, False, False

    # Fetch /props for each llama-server model in parallel to get context
    # length (n_ctx) and unload sleeping models automatically.
    props_results = await asyncio.gather(
        *[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
    )

    for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in zip(
        props_requests, llama_models_pending, props_results
    ):
        if n_ctx is not None:
            model_dict["context_length"] = n_ctx
            # The isinstance guard keeps a non-numeric n_ctx from raising in the comparison.
            if is_generation and isinstance(n_ctx, (int, float)) and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
                normalized = _normalize_llama_model_name(raw_id)
                _endpoint_nctx[(ep, normalized)] = n_ctx
                print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
        if not is_sleeping:
            models.append(model_dict)

    return JSONResponse(content={"models": models}, status_code=200)
|