diff --git a/api/ollama.py b/api/ollama.py index 0fc98aa..6ae6027 100644 --- a/api/ollama.py +++ b/api/ollama.py @@ -13,6 +13,7 @@ import asyncio import re import time from typing import Optional +from urllib.parse import quote import aiohttp import ollama @@ -976,6 +977,43 @@ async def _fetch_llama_swap_running(endpoint: str) -> list[dict]: ) +# Match the context size in a llama-swap worker's `cmd` string, e.g. +# "llama-server --port 5818 -hf ... --ctx-size 131072 ...". llama.cpp accepts +# both --ctx-size and the short -c alias. +_CTX_SIZE_CMD_RE = re.compile(r"(?:--ctx-size|-c)[=\s]+(\d+)") + + +def _ctx_size_from_cmd(cmd: str) -> int | None: + """Extract n_ctx from a llama-swap worker `cmd` string, or None if absent.""" + if not cmd: + return None + m = _CTX_SIZE_CMD_RE.search(cmd) + return int(m.group(1)) if m else None + + +async def _fetch_llama_swap_nctx(endpoint: str, model_id: str) -> int | None: + """Fallback when a worker's `cmd` lacks --ctx-size: ask the underlying + llama-server via llama-swap's /upstream//props route (plain /props?model= + is not routed by llama-swap and 404s). Returns n_ctx or None on any failure. + """ + config = get_config() + base_url = endpoint.rstrip("/").removesuffix("/v1") + props_url = f"{base_url}/upstream/{quote(model_id, safe='')}/props" + headers = None + api_key = config.api_keys.get(endpoint) + if api_key: + headers = {"Authorization": f"Bearer {api_key}"} + try: + client: aiohttp.ClientSession = get_probe_session(endpoint) + async with client.get(props_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp: + if resp.status == 200: + data = await resp.json() + return data.get("default_generation_settings", {}).get("n_ctx") + except Exception as e: + print(f"[ps_details] Failed to fetch props from {props_url}: {e}") + return None + + @router.get("/api/ps") async def ps_proxy(request: Request): """ @@ -1161,6 +1199,7 @@ async def ps_details_proxy(request: Request): swap_running = await asyncio.gather( *[_fetch_llama_swap_running(ep) for ep in config.llama_swap_endpoints] ) + swap_nctx_fallbacks: list[tuple[str, str, dict]] = [] for endpoint, runlist in zip(config.llama_swap_endpoints, swap_running): for item in runlist: if not isinstance(item, dict) or item.get("state") != "ready": @@ -1170,7 +1209,7 @@ async def ps_details_proxy(request: Request): continue normalized = _normalize_llama_model_name(raw_id) quant = _extract_llama_quant(raw_id) - models.append({ + swap_model = { "name": normalized, "id": normalized, "original_name": raw_id, @@ -1180,6 +1219,29 @@ async def ps_details_proxy(request: Request): "state": item.get("state"), "ttl": item.get("ttl"), "proxy": item.get("proxy"), - }) + } + # llama-swap omits n_ctx from /running, but the worker's launch + # command carries --ctx-size, so parse it from there (no extra + # request). Workers whose cmd lacks the flag fall back to an + # /upstream//props probe below. + n_ctx = _ctx_size_from_cmd(item.get("cmd", "")) + if n_ctx is not None: + swap_model["context_length"] = n_ctx + if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, normalized)] = n_ctx + else: + swap_nctx_fallbacks.append((endpoint, raw_id, swap_model)) + models.append(swap_model) + + # Resolve ctx for workers whose cmd lacked --ctx-size via /upstream props. + if swap_nctx_fallbacks: + fallback_results = await asyncio.gather( + *[_fetch_llama_swap_nctx(ep, rid) for ep, rid, _ in swap_nctx_fallbacks] + ) + for (ep, _rid, swap_model), n_ctx in zip(swap_nctx_fallbacks, fallback_results): + if n_ctx is not None: + swap_model["context_length"] = n_ctx + if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(ep, swap_model["id"])] = n_ctx return JSONResponse(content={"models": models}, status_code=200) diff --git a/test/test_llama_swap.py b/test/test_llama_swap.py index d0427bf..74197d4 100644 --- a/test/test_llama_swap.py +++ b/test/test_llama_swap.py @@ -6,6 +6,7 @@ import pytest import router import backends.control as control import api.openai as openai_api +import api.ollama as ollama_api SWAP_EP = "http://swap:8080/v1" SERVER_EP = "http://server:8080/v1" @@ -107,3 +108,24 @@ class TestUpstreamResolution: with patch.object(openai_api, "get_config", lambda: cfg): ep = await openai_api._resolve_llama_swap_endpoint("any") assert ep is None + + +class TestCtxSizeFromCmd: + """ctx-size parsing from a /running worker's launch `cmd` string.""" + + def test_parses_long_flag(self): + cmd = ("llama-server --port 5818\n -hf unsloth/gpt-oss-20b-GGUF:F16\n" + " --ctx-size 131072\n --temp 1.0\n") + assert ollama_api._ctx_size_from_cmd(cmd) == 131072 + + def test_parses_short_flag(self): + assert ollama_api._ctx_size_from_cmd("llama-server -c 8192 --port 1") == 8192 + + def test_parses_equals_form(self): + assert ollama_api._ctx_size_from_cmd("llama-server --ctx-size=4096") == 4096 + + def test_returns_none_when_absent(self): + assert ollama_api._ctx_size_from_cmd("llama-server --port 5818") is None + + def test_returns_none_for_empty(self): + assert ollama_api._ctx_size_from_cmd("") is None