"""Ollama-native API routes (``/api/*``).

These are the ``/api/generate``, ``/api/chat``, ``/api/embed(dings)`` and the
model-management routes (``/api/create``, ``/api/show``, ``/api/copy``,
``/api/delete``, ``/api/pull``, ``/api/push``, ``/api/version``, ``/api/tags``,
``/api/ps``, ``/api/ps_details``) that the Ollama clients expect. The
chat/generate handlers also serve OpenAI-compatible endpoints when
``is_openai_compatible(endpoint)`` is true — in that case they translate the
request to the OpenAI Chat Completions / Completions API and ``rechunk`` the
response back into Ollama wire format.
"""
import asyncio
import re
import time
from typing import Optional

import aiohttp
import ollama
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, Response, StreamingResponse

from cache import get_llm_cache
from config import get_config
from context_window import (
    _count_message_tokens,
    _trim_messages_for_context,
    _calibrated_trim_target,
    _endpoint_nctx,
    _CTX_TRIM_SMALL_LIMIT,
)
from fingerprint import _conversation_fingerprint
from state import token_queue, default_headers
from backends.health import (
    _is_backend_connection_error,
    _is_llama_model_loaded,
    _is_llama_model_loaded_or_sleeping,
    _mark_backend_unhealthy,
)
from backends.normalize import (
    dedupe_on_keys,
    is_openai_compatible,
    _normalize_llama_model_name,
    _extract_llama_quant,
)
from backends.probe import fetch
from backends.sessions import _make_openai_client, get_probe_session
from requests.chat import _make_moe_requests
from requests.messages import (
    transform_images_to_data_urls,
    transform_tool_calls_to_openai,
    _strip_assistant_prefill,
    _strip_images_from_messages,
    _accumulate_openai_tc_delta,
    _build_ollama_tool_calls,
)
from requests.rechunk import rechunk
from routing import choose_endpoint, decrement_usage

router = APIRouter()


@router.post("/api/generate")
async def proxy(request: Request):
    """
    Proxy a generate request to Ollama and stream the response back to the client.

    The JSON body is parsed and validated, the LLM cache is consulted (only
    when the request opts in through ``{"nomyo": {"cache": true}}``), an
    endpoint is picked with ``choose_endpoint`` and the request is forwarded
    either through ``ollama.AsyncClient.generate`` or, for OpenAI-compatible
    endpoints, through the Completions API whose replies are converted back
    with ``rechunk.openai_completion2ollama``.

    Args:
        request: Incoming request; the body must be a JSON object with at
            least ``model`` and ``prompt``.

    Returns:
        A ``StreamingResponse`` yielding newline-terminated JSON documents.

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON
            object, or lacks ``model`` / ``prompt``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
        raise HTTPException(status_code=400, detail=error_msg) from e
    if not isinstance(payload, dict):
        # A JSON array/scalar would otherwise blow up on ``payload.get`` (HTTP 500).
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    prompt = payload.get("prompt")
    suffix = payload.get("suffix")
    system = payload.get("system")
    template = payload.get("template")
    context = payload.get("context")
    stream = payload.get("stream")
    think = payload.get("think")
    raw = payload.get("raw")
    _format = payload.get("format")
    images = payload.get("images")
    options = payload.get("options")
    keep_alive = payload.get("keep_alive")
    # ``nomyo`` may be absent, null or malformed; only a dict can opt in to caching.
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )

    # 2. Cache lookup — before endpoint selection so no slot is wasted on a hit
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_generate(model, prompt, system or "")
        if _cached is not None:
            async def _serve_cached_generate():
                yield _cached

            return StreamingResponse(_serve_cached_generate(), media_type="application/json")

    # 3. Endpoint logic
    _affinity_key = _conversation_fingerprint(model, None, prompt)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")
            model = model[0]
        params = {
            "prompt": prompt,
            "model": model,
        }
        # Map Ollama ``options`` onto the OpenAI parameter names; unset ones are dropped.
        _opts = options if isinstance(options, dict) else {}
        optional_params = {
            "stream": stream,
            "max_tokens": _opts.get("num_predict"),
            "frequency_penalty": _opts.get("frequency_penalty"),
            "presence_penalty": _opts.get("presence_penalty"),
            "seed": _opts.get("seed"),
            "stop": _opts.get("stop"),
            "top_p": _opts.get("top_p"),
            "temperature": _opts.get("temperature"),
            "suffix": suffix,
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 4. Async generator that streams data and decrements the counter
    async def stream_generate_response():
        try:
            if use_openai:
                start_ts = time.perf_counter()
                async_gen = await oclient.completions.create(**params)
            else:
                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)

            if stream:
                content_parts: list[str] = []
                async for chunk in async_gen:
                    if use_openai:
                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    # ``model_dump_json`` returns str, ``orjson.dumps`` returns bytes.
                    if hasattr(chunk, "model_dump_json"):
                        json_bytes = chunk.model_dump_json().encode("utf-8")
                    else:
                        json_bytes = orjson.dumps(chunk)
                    # Accumulate and store cache on done chunk — before yield so it always runs
                    if _cache is not None and _cache_enabled:
                        if getattr(chunk, "response", None):
                            content_parts.append(chunk.response)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "response": "".join(content_parts),
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_generate(model, prompt, system or "", assembled)
                            except Exception as _ce:
                                print(f"[cache] set_generate (streaming) failed: {_ce}")
                    yield json_bytes + b"\n"
            else:
                # Token counts live on the Ollama-shaped object, so convert the
                # OpenAI reply first and read the counters from the result.
                if use_openai:
                    result = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
                else:
                    result = async_gen
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if hasattr(result, "model_dump_json"):
                    cache_bytes = result.model_dump_json().encode("utf-8") + b"\n"
                else:
                    cache_bytes = orjson.dumps(result) + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_generate(model, prompt, system or "", cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_generate (non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_generate_response(),
        media_type="application/json",
    )
@router.post("/api/chat")
async def chat_proxy(request: Request):
    """
    Proxy a chat request to Ollama and stream the endpoint reply.

    Flow: parse and validate the JSON body, consult the LLM cache (opt-in via
    ``{"nomyo": {"cache": true}}``, always bypassed for ``moe-`` models), pick
    an endpoint with ``choose_endpoint`` and forward the request. For
    OpenAI-compatible endpoints the call is issued in handler scope so that
    context-size, connection and image-support errors can be handled (trim
    and retry, mark unhealthy, strip images) before streaming starts; chunks
    are converted back to Ollama format with
    ``rechunk.openai_chat_completion2ollama``.

    Args:
        request: Incoming request; the body must be a JSON object with
            ``model`` and a ``messages`` list.

    Returns:
        A ``StreamingResponse`` (``application/x-ndjson`` when ``stream`` is
        truthy, ``application/json`` otherwise).

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing
            ``model``, non-list ``messages`` or non-object ``options``.
        Exception: backend errors from the OpenAI-compatible call that cannot
            be recovered are re-raised after the usage slot is released.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        # A JSON array/scalar would otherwise blow up on ``payload.get`` (HTTP 500).
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    messages = payload.get("messages")
    tools = payload.get("tools")
    stream = payload.get("stream")
    think = payload.get("think")
    _format = payload.get("format")
    keep_alive = payload.get("keep_alive")
    options = payload.get("options")
    logprobs = payload.get("logprobs")
    top_logprobs = payload.get("top_logprobs")
    # ``nomyo`` may be absent, null or malformed; only a dict can opt in to caching.
    _nomyo = payload.get("nomyo")
    _cache_enabled = _nomyo.get("cache", False) if isinstance(_nomyo, dict) else False

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not isinstance(messages, list):
        raise HTTPException(
            status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
        )
    if options is not None and not isinstance(options, dict):
        raise HTTPException(
            status_code=400, detail="`options` must be a JSON object"
        )

    # Cache lookup — before endpoint selection, always bypassed for MOE
    _is_moe = model.startswith("moe-")
    _cache = get_llm_cache()
    # Normalise model name for cache key: strip ":latest" suffix here so that
    # get_chat and set_chat use the same model string regardless of when the
    # strip happens further down (it is stripped for OpenAI endpoints below).
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    # Snapshot original messages before any OpenAI-format transformation so that
    # get_chat and set_chat always use the same key regardless of backend type.
    _cache_messages = messages
    if _cache is not None and not _is_moe and _cache_enabled:
        _cached = await _cache.get_chat("ollama_chat", _cache_model, messages)
        if _cached is not None:
            async def _serve_cached_chat():
                yield _cached

            return StreamingResponse(
                _serve_cached_chat(),
                media_type="application/x-ndjson" if stream else "application/json",
            )

    # 2. Endpoint logic
    if _is_moe:
        # Strip only the leading marker; ``split("moe-")[1]`` would truncate
        # names that contain "moe-" a second time.
        model = model[len("moe-"):]
    _affinity_key = _conversation_fingerprint(model, messages, None)
    endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")
            model = model[0]
        if messages:
            if any("images" in m for m in messages):
                messages = await asyncio.to_thread(transform_images_to_data_urls, messages)
            messages = transform_tool_calls_to_openai(messages)
            messages = _strip_assistant_prefill(messages)
        params = {
            "messages": messages,
            "model": model,
        }
        # Ollama's ``format`` is either the literal "json" or a JSON schema.
        if _format is None or _format == "":
            response_format = None
        elif _format == "json":
            response_format = {"type": "json_object"}
        else:
            response_format = {"type": "json_schema", "json_schema": _format}
        # ``options`` was validated above to be a dict or None.
        _opts = options or {}
        optional_params = {
            "tools": tools,
            "stream": stream,
            "stream_options": {"include_usage": True} if stream else None,
            "max_tokens": _opts.get("num_predict"),
            "frequency_penalty": _opts.get("frequency_penalty"),
            "presence_penalty": _opts.get("presence_penalty"),
            "seed": _opts.get("seed"),
            "stop": _opts.get("stop"),
            "top_p": _opts.get("top_p"),
            "temperature": _opts.get("temperature"),
            "logprobs": logprobs if logprobs is not None else _opts.get("logprobs"),
            "top_logprobs": top_logprobs if top_logprobs is not None else _opts.get("top_logprobs"),
            "response_format": response_format,
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # For OpenAI endpoints: make the API call in handler scope
    # (try/except inside async generators is unreliable with Starlette's streaming)
    start_ts = None
    async_gen = None
    if use_openai:
        start_ts = time.perf_counter()
        # Proactive trim: only for small-ctx models we've already seen run out of space
        _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model
        _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model))
        if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT:
            _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2)
            _pre_est = _count_message_tokens(params.get("messages", []))
            if _pre_est > _pre_target:
                _pre_msgs = params.get("messages", [])
                _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target)
                _dropped = len(_pre_msgs) - len(_pre_trimmed)
                print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True)
                params = {**params, "messages": _pre_trimmed}
        try:
            async_gen = await oclient.chat.completions.create(**params)
        except Exception as e:
            _e_str = str(e)
            print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}")
            if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str:
                err_body = getattr(e, "body", {}) or {}
                err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {}
                if not isinstance(err_detail, dict):
                    # ``error`` may be a plain string; fall back to the regex parse below.
                    err_detail = {}
                n_ctx_limit = err_detail.get("n_ctx", 0)
                actual_tokens = err_detail.get("n_prompt_tokens", 0)
                if not n_ctx_limit:
                    _m = re.search(r"'n_ctx':\s*(\d+)", _e_str)
                    if _m:
                        n_ctx_limit = int(_m.group(1))
                    _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str)
                    if _m:
                        actual_tokens = int(_m.group(1))
                if not n_ctx_limit:
                    await decrement_usage(endpoint, tracking_model)
                    raise
                if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT:
                    # NOTE(review): stored under the raw ``model`` while the proactive
                    # lookup above uses ``_lookup_model`` (normalized for llama-server
                    # endpoints) — confirm both are meant to hit the same key.
                    _endpoint_nctx[(endpoint, model)] = n_ctx_limit
                msgs_to_trim = params.get("messages", [])
                cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens)
                trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target)
                print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying")
                try:
                    async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed})
                except Exception as e2:
                    _e2_str = str(e2)
                    if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str:
                        print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools")
                        params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")}
                        try:
                            async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed})
                        except Exception:
                            await decrement_usage(endpoint, tracking_model)
                            raise
                    else:
                        await decrement_usage(endpoint, tracking_model)
                        raise
            elif _is_backend_connection_error(e):
                print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True)
                await _mark_backend_unhealthy(endpoint, model, _e_str)
                await decrement_usage(endpoint, tracking_model)
                raise
            elif "image input is not supported" in _e_str:
                print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages")
                try:
                    params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))}
                    async_gen = await oclient.chat.completions.create(**params)
                except Exception:
                    await decrement_usage(endpoint, tracking_model)
                    raise
            else:
                await decrement_usage(endpoint, tracking_model)
                raise

    # 3. Async generator that streams chat data and decrements the counter
    async def stream_chat_response():
        try:
            if use_openai:
                _async_gen = async_gen  # established in handler scope above
            elif _is_moe:
                # Use the dedicated MOE helper function
                _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
            else:
                _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)

            if stream:
                tc_acc = {}  # accumulate OpenAI tool-call deltas across chunks
                content_parts: list[str] = []
                async for chunk in _async_gen:
                    if use_openai:
                        _accumulate_openai_tc_delta(chunk, tc_acc)
                        chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                        # Inject fully-accumulated tool calls only into the final chunk
                        if chunk.done and tc_acc and chunk.message:
                            chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    # ``model_dump_json`` returns str, ``orjson.dumps`` returns bytes.
                    if hasattr(chunk, "model_dump_json"):
                        json_bytes = chunk.model_dump_json().encode("utf-8")
                    else:
                        json_bytes = orjson.dumps(chunk)
                    if getattr(chunk, "done", False):
                        # Detect context exhaustion mid-generation for small-ctx models
                        _dr = getattr(chunk, "done_reason", None)
                        # Only cache n_ctx when no max_tokens limit was set — otherwise
                        # finish_reason=length might just mean max_tokens was hit,
                        # not that the context window was exhausted.
                        _req_max_tok = (
                            params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict")
                            if use_openai
                            else (options.get("num_predict") if options else None)
                        )
                        if _dr == "length" and not _req_max_tok:
                            _pt = getattr(chunk, "prompt_eval_count", 0) or 0
                            _ct = getattr(chunk, "eval_count", 0) or 0
                            _inferred_nctx = _pt + _ct
                            if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT:
                                _endpoint_nctx[(endpoint, model)] = _inferred_nctx
                                print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True)
                    # Accumulate and store cache on done chunk — before yield so it always runs.
                    # Works for both Ollama-native and OpenAI-compatible backends; chunks are
                    # already converted to Ollama format by rechunk before this point.
                    if _cache is not None and not _is_moe and _cache_enabled:
                        if chunk.message and getattr(chunk.message, "content", None):
                            content_parts.append(chunk.message.content)
                        if getattr(chunk, "done", False):
                            _created_at = getattr(chunk, "created_at", None)
                            if hasattr(_created_at, "isoformat"):
                                _created_at = _created_at.isoformat()
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "created_at": _created_at,
                                    "message": {"role": "assistant", "content": "".join(content_parts)},
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
                            except Exception as _ce:
                                print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
                    yield json_bytes + b"\n"
            else:
                # Token counts live on the Ollama-shaped object, so convert the
                # OpenAI reply first and read the counters from the result.
                if use_openai:
                    result = rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts)
                else:
                    result = _async_gen
                prompt_tok = getattr(result, "prompt_eval_count", None) or 0
                comp_tok = getattr(result, "eval_count", None) or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if hasattr(result, "model_dump_json"):
                    cache_bytes = result.model_dump_json().encode("utf-8") + b"\n"
                else:
                    cache_bytes = orjson.dumps(result) + b"\n"
                yield cache_bytes
                # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends)
                if _cache is not None and not _is_moe and _cache_enabled:
                    try:
                        await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    media_type = "application/x-ndjson" if stream else "application/json"
    return StreamingResponse(
        stream_chat_response(),
        media_type=media_type,
    )
@router.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """
    Proxy an embedding request to Ollama and reply with embeddings.

    Args:
        request: Incoming request; the body must be a JSON object with
            ``model`` and ``prompt`` (``options`` and ``keep_alive`` optional).

    Returns:
        A ``StreamingResponse`` yielding one newline-terminated JSON document.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``model`` / ``prompt``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        # A JSON array/scalar would otherwise blow up on ``payload.get`` (HTTP 500).
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    prompt = payload.get("prompt")
    options = payload.get("options")
    keep_alive = payload.get("keep_alive")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not prompt:
        raise HTTPException(
            status_code=400, detail="Missing required field 'prompt'"
        )

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")
            model = model[0]
        client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that emits the embedding reply and decrements the counter
    async def stream_embedding_response():
        try:
            if use_openai:
                # OpenAI reply is converted to the Ollama embeddings shape.
                result = await client.embeddings.create(input=prompt, model=model)
                result = rechunk.openai_embeddings2ollama(result)
            else:
                result = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
            # ``model_dump_json`` returns str, ``orjson.dumps`` returns bytes.
            if hasattr(result, "model_dump_json"):
                json_bytes = result.model_dump_json().encode("utf-8")
            else:
                json_bytes = orjson.dumps(result)
            yield json_bytes + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
@router.post("/api/embed")
async def embed_proxy(request: Request):
    """
    Proxy an embed request to Ollama and reply with embeddings.

    Args:
        request: Incoming request; the body must be a JSON object with
            ``model`` and ``input`` (``truncate``, ``options`` and
            ``keep_alive`` optional).

    Returns:
        A ``StreamingResponse`` yielding one newline-terminated JSON document.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, or a missing
            ``model`` / ``input``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        # A JSON array/scalar would otherwise blow up on ``payload.get`` (HTTP 500).
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    _input = payload.get("input")
    truncate = payload.get("truncate")
    options = payload.get("options")
    keep_alive = payload.get("keep_alive")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not _input:
        raise HTTPException(
            status_code=400, detail="Missing required field 'input'"
        )

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        if ":latest" in model:
            model = model.split(":latest")
            model = model[0]
        client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that emits the embed reply and decrements the counter
    async def stream_embedding_response():
        try:
            if use_openai:
                # OpenAI reply is converted to the Ollama embed shape.
                result = await client.embeddings.create(input=_input, model=model)
                result = rechunk.openai_embed2ollama(result, model)
            else:
                result = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
            # ``model_dump_json`` returns str, ``orjson.dumps`` returns bytes.
            if hasattr(result, "model_dump_json"):
                json_bytes = result.model_dump_json().encode("utf-8")
            else:
                json_bytes = orjson.dumps(result)
            yield json_bytes + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
@router.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with deduplicated status.

    Args:
        request: Incoming request; the body must be a JSON object with
            ``model`` and either ``from`` or ``files``.

    Returns:
        A dict built from the deduplicated ``(field, value)`` pairs of every
        endpoint's create reply.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body, a missing
            ``model``, or when neither ``from`` nor ``files`` is given.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        # A JSON array/scalar would otherwise blow up on ``payload.get`` (HTTP 500).
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    quantize = payload.get("quantize")
    from_ = payload.get("from")
    files = payload.get("files")
    adapters = payload.get("adapters")
    template = payload.get("template")
    license_ = payload.get("license")  # trailing underscore: don't shadow the builtin
    system = payload.get("system")
    parameters = payload.get("parameters")
    messages = payload.get("messages")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not from_ and not files:
        raise HTTPException(
            status_code=400, detail="You need to provide either from_ or files parameter!"
        )

    # 2. Create the model on every Ollama endpoint
    combined_status = []
    for endpoint in config.endpoints:
        # OpenAI-style ("/v1") endpoints are skipped, as in copy/delete/pull.
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license_, system=system, parameters=parameters, messages=messages, stream=False)
        # Extending with the reply collects whatever items iterating it yields.
        combined_status += create

    # 3. Report back a deduplicated status message
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
@router.post("/api/show")
async def show_proxy(request: Request, model: Optional[str] = None):
    """
    Forward a model "show" request to one Ollama endpoint and return its reply.

    The model name is taken from the ``model`` argument when given, otherwise
    from the ``model`` field of the JSON body.
    """
    raw_body = await request.body()

    # Only fall back to the body when no model name was supplied directly.
    if not model:
        try:
            body = orjson.loads(raw_body.decode("utf-8"))
        except orjson.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
        model = body.get("model")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # Pick an endpoint without reserving a usage slot, then ask it directly.
    target, _unused = await choose_endpoint(model, reserve=False)
    show_client = ollama.AsyncClient(host=target)
    return await show_client.show(model=model)
@router.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """
    Proxy a model copy request to each Ollama endpoint and reply with Status Code.

    Args:
        request: Incoming request; its JSON body is used only when neither
            ``source`` nor ``destination`` is supplied as an argument.
        source: Name of the model to copy.
        destination: Name of the copy.

    Returns:
        An empty ``Response``: 404 if any endpoint answered 404, else 200.

    Raises:
        HTTPException: 400 for invalid JSON or a missing source/destination.
        ollama.ResponseError: for endpoint errors other than 404.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not source and not destination:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            src = payload.get("source")
            dst = payload.get("destination")
        else:
            src = source
            dst = destination
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    if not src:
        raise HTTPException(
            status_code=400, detail="Missing required field 'source'"
        )
    if not dst:
        raise HTTPException(
            status_code=400, detail="Missing required field 'destination'"
        )

    # 2. Iterate over all endpoints to copy the model on each endpoint
    status_list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            result = await client.copy(source=src, destination=dst)
        except ollama.ResponseError as exc:
            # The client raises on HTTP errors, so a 404 never shows up in
            # ``result.status``; record it here so the check below can see it.
            if exc.status_code != 404:
                raise
            status_list.append(404)
        else:
            status_list.append(result.status)

    # 3. Return with 200 OK if all went well, 404 if a single endpoint failed
    return Response(status_code=404 if 404 in status_list else 200)


@router.delete("/api/delete")
async def delete_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model delete request to each Ollama endpoint and reply with Status Code.

    Args:
        request: Incoming request; its JSON body is used only when ``model``
            is not supplied as an argument.
        model: Name of the model to delete.

    Returns:
        An empty ``Response``: 404 if any endpoint answered 404, else 200.

    Raises:
        HTTPException: 400 for invalid JSON or a missing ``model``.
        ollama.ResponseError: for endpoint errors other than 404.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            model = payload.get("model")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Iterate over all endpoints to delete the model on each endpoint
    status_list = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        try:
            result = await client.delete(model=model)
        except ollama.ResponseError as exc:
            # Same reasoning as in copy: 404 arrives as an exception.
            if exc.status_code != 404:
                raise
            status_list.append(404)
        else:
            status_list.append(result.status)

    # 3. Return 200 OK; if a single endpoint fails, respond with 404
    return Response(status_code=404 if 404 in status_list else 200)
@router.post("/api/pull")
async def pull_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a pull request to all Ollama endpoints and report status back.

    Args:
        request: Incoming request; its JSON body (``model``, ``insecure``) is
            used only when ``model`` is not supplied as an argument.
        model: Name of the model to pull.

    Returns:
        A dict built from the deduplicated ``(field, value)`` pairs of every
        endpoint's pull reply.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body or a missing
            ``model``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        if not model:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            if not isinstance(payload, dict):
                raise HTTPException(status_code=400, detail="Request body must be a JSON object")
            model = payload.get("model")
            insecure = payload.get("insecure")
        else:
            insecure = None
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Iterate over all endpoints to pull the model
    combined_status = []
    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple pull request
        pull = await client.pull(model=model, insecure=insecure, stream=False)
        # Extending with the reply collects whatever items iterating it yields.
        combined_status += pull

    # 4. Report back a deduplicated status message
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)


@router.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to Ollama and respond with the deduplicated Ollama endpoint replies.

    Args:
        request: Incoming request; the body must be a JSON object with
            ``model`` (``insecure`` optional).

    Returns:
        A dict built from the deduplicated ``(field, value)`` pairs of every
        endpoint's push reply.

    Raises:
        HTTPException: 400 for invalid JSON, a non-object body or a missing
            ``model``.
    """
    config = get_config()

    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    model = payload.get("model")
    insecure = payload.get("insecure")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )

    # 2. Iterate over all endpoints
    combined_status = []
    for endpoint in config.endpoints:
        # OpenAI-style ("/v1") endpoints are skipped, as in pull.
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple push request
        push = await client.push(model=model, insecure=insecure, stream=False)
        combined_status += push

    # 4. Report a deduplicated status
    final_status = list(dict.fromkeys(combined_status))
    return dict(final_status)
@router.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply with the lowest version of all endpoints.

    Args:
        request: Incoming request (unused).

    Returns:
        A ``JSONResponse`` of the form ``{"version": "<lowest version>"}``.

    Raises:
        HTTPException: 503 when no endpoint returned a usable version string.
    """
    config = get_config()

    # 1. Query all endpoints for version
    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
    all_versions_raw = await asyncio.gather(*tasks)

    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]

    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v):
        # Compare the numeric release part only: a suffix such as "-rc1" or
        # "+build" would make a plain ``int()`` on the dot-parts raise.
        release = re.split(r"[-+]", v, maxsplit=1)[0]
        return tuple(int(part) for part in re.findall(r"\d+", release))

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
@router.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to Ollama endpoints and reply with a unique list of all models.

    Ollama endpoints are queried via ``/api/tags``; ``/v1`` endpoints and
    llama-server endpoints not already listed in ``config.endpoints`` via
    ``/models``. Entries are relabelled so each carries ``model``, ``name``
    and ``id`` before deduplication.

    Args:
        request: Incoming request (unused).

    Returns:
        A ``JSONResponse`` of the form ``{"models": [...]}``.
    """
    config = get_config()

    # 1. Query all endpoints for models
    tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep]
    # ``.get`` so an endpoint without a configured key does not raise KeyError
    # (same lookup as the llama-server query below).
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" in ep]
    # Also query llama-server endpoints not already covered by config.endpoints
    llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps_for_tags]
    all_models = await asyncio.gather(*tasks)

    collected = []
    for modellist in all_models:
        for model in modellist:
            # Skip malformed entries instead of failing the whole listing.
            if not isinstance(model, dict) or ("model" not in model and "id" not in model):
                continue
            if "model" not in model:
                # Relabel OpenAI models with Ollama Model.model from Model.id
                model["model"] = model["id"] + ":latest"
            else:
                model["id"] = model["model"]
            if "name" not in model:
                # Relabel OpenAI models with Ollama Model.name from Model.model to have model,name keys
                model["name"] = model["model"]
            else:
                model["id"] = model["model"]
            collected.append(model)

    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(collected, ['digest', 'name', 'id'])},
        status_code=200,
    )
Query llama-server endpoints for loaded models via /v1/models # Also query endpoints from llama_server_endpoints that may not be in config.endpoints all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) llama_tasks = [ fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in all_llama_endpoints ] ollama_loaded = await asyncio.gather(*ollama_tasks) if ollama_tasks else [] llama_loaded = await asyncio.gather(*llama_tasks) if llama_tasks else [] models = {'models': []} # Add Ollama models (if any) if ollama_loaded: for modellist in ollama_loaded: models['models'] += modellist # Add llama-server models (filter for loaded only, if any) if llama_loaded: for modellist in llama_loaded: loaded_models = [item for item in modellist if _is_llama_model_loaded(item)] # Convert llama-server format to Ollama-like format for consistency for item in loaded_models: raw_id = item.get("id", "") normalized = _normalize_llama_model_name(raw_id) quant = _extract_llama_quant(raw_id) models['models'].append({ "name": normalized, "id": normalized, "digest": "", "status": item.get("status"), "details": {"quantization_level": quant} if quant else {} }) # 3. Return a JSONResponse with deduplicated currently deployed models # Deduplicate on 'name' rather than 'digest': llama-server models always # have digest="" so deduping on digest collapses all of them to one entry. return JSONResponse( content={"models": dedupe_on_keys(models['models'], ['name'])}, status_code=200, ) @router.get("/api/ps_details") async def ps_details_proxy(request: Request): """ Proxy a ps request to all Ollama and llama-server endpoints and reply with per-endpoint instances. This keeps /api/ps backward compatible while providing richer data. For Ollama endpoints: queries /api/ps For llama-server endpoints: queries /v1/models with status info """ config = get_config() # 1. 
Query Ollama endpoints via /api/ps ollama_tasks = [(ep, fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)) for ep in config.endpoints if "/v1" not in ep] # 2. Query llama-server endpoints via /v1/models # Also query endpoints from llama_server_endpoints that may not be in config.endpoints all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) llama_tasks = [ (ep, fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)) for ep in all_llama_endpoints ] ollama_loaded = await asyncio.gather(*[task for _, task in ollama_tasks]) if ollama_tasks else [] llama_loaded = await asyncio.gather(*[task for _, task in llama_tasks]) if llama_tasks else [] models: list[dict] = [] # Add Ollama models with endpoint info (if any) if ollama_loaded: for (endpoint, modellist) in zip([ep for ep, _ in ollama_tasks], ollama_loaded): for model in modellist: if isinstance(model, dict): model_with_endpoint = dict(model) model_with_endpoint["endpoint"] = endpoint models.append(model_with_endpoint) # Add llama-server models with endpoint info and full status metadata (if any) if llama_loaded: # Collect (endpoint, raw_id) pairs to fetch /props in parallel props_requests: list[tuple[str, str]] = [] llama_models_pending: list[dict] = [] for (endpoint, modellist) in zip([ep for ep, _ in llama_tasks], llama_loaded): # Include sleeping models too so _fetch_llama_props can unload them loaded_models = [item for item in modellist if _is_llama_model_loaded_or_sleeping(item)] for item in loaded_models: if isinstance(item, dict) and item.get("id"): raw_id = item["id"] normalized = _normalize_llama_model_name(raw_id) quant = _extract_llama_quant(raw_id) model_with_endpoint = { "name": normalized, "id": normalized, "original_name": raw_id, "digest": "", "details": {"quantization_level": quant} if quant else {}, "endpoint": endpoint, "status": 
item.get("status"), "created": item.get("created"), "owned_by": item.get("owned_by") } # Include full llama-server status details (args, preset) status_info = item.get("status", {}) if isinstance(status_info, dict): model_with_endpoint["llama_status_args"] = status_info.get("args") model_with_endpoint["llama_status_preset"] = status_info.get("preset") llama_models_pending.append(model_with_endpoint) props_requests.append((endpoint, raw_id)) # Fetch /props for each llama-server model to get context length (n_ctx) # and unload sleeping models automatically async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]: client: aiohttp.ClientSession = get_probe_session(endpoint) base_url = endpoint.rstrip("/").removesuffix("/v1") props_url = f"{base_url}/props?model={model_id}" headers = None api_key = config.api_keys.get(endpoint) if api_key: headers = {"Authorization": f"Bearer {api_key}"} try: async with client.get(props_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp: if resp.status == 200: data = await resp.json() dgs = data.get("default_generation_settings", {}) n_ctx = dgs.get("n_ctx") is_sleeping = data.get("is_sleeping", False) # Embedding models have no sampling params in default_generation_settings is_generation = "temperature" in dgs if is_sleeping: unload_url = f"{base_url}/models/unload" try: async with client.post( unload_url, json={"model": model_id}, headers=headers, ) as unload_resp: print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}") except Exception as ue: print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}") return n_ctx, is_sleeping, is_generation except Exception as e: print(f"[ps_details] Failed to fetch props from {props_url}: {e}") return None, False, False props_results = await asyncio.gather( *[_fetch_llama_props(ep, mid) for ep, mid in props_requests] ) for (ep, raw_id), model_dict, (n_ctx, is_sleeping, 
is_generation) in zip(props_requests, llama_models_pending, props_results): if n_ctx is not None: model_dict["context_length"] = n_ctx if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: normalized = _normalize_llama_model_name(raw_id) _endpoint_nctx[(ep, normalized)] = n_ctx print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True) if not is_sleeping: models.append(model_dict) return JSONResponse(content={"models": models}, status_code=200)