fix: exclude embedding models from preemptive context shift caches
This commit is contained in:
parent
21d6835253
commit
1e9996c393
1 changed files with 7 additions and 5 deletions
12
router.py
12
router.py
|
|
@ -2796,7 +2796,7 @@ async def ps_details_proxy(request: Request):
|
|||
|
||||
# Fetch /props for each llama-server model to get context length (n_ctx)
|
||||
# and unload sleeping models automatically
|
||||
async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool]:
|
||||
async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]:
|
||||
client: aiohttp.ClientSession = app_state["session"]
|
||||
base_url = endpoint.rstrip("/").removesuffix("/v1")
|
||||
props_url = f"{base_url}/props?model={model_id}"
|
||||
|
|
@ -2811,6 +2811,8 @@ async def ps_details_proxy(request: Request):
|
|||
dgs = data.get("default_generation_settings", {})
|
||||
n_ctx = dgs.get("n_ctx")
|
||||
is_sleeping = data.get("is_sleeping", False)
|
||||
# Embedding models have no sampling params in default_generation_settings
|
||||
is_generation = "temperature" in dgs
|
||||
|
||||
if is_sleeping:
|
||||
unload_url = f"{base_url}/models/unload"
|
||||
|
|
@ -2824,19 +2826,19 @@ async def ps_details_proxy(request: Request):
|
|||
except Exception as ue:
|
||||
print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")
|
||||
|
||||
return n_ctx, is_sleeping
|
||||
return n_ctx, is_sleeping, is_generation
|
||||
except Exception as e:
|
||||
print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
|
||||
return None, False
|
||||
return None, False, False
|
||||
|
||||
props_results = await asyncio.gather(
|
||||
*[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
|
||||
)
|
||||
|
||||
for (ep, raw_id), model_dict, (n_ctx, is_sleeping) in zip(props_requests, llama_models_pending, props_results):
|
||||
for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in zip(props_requests, llama_models_pending, props_results):
|
||||
if n_ctx is not None:
|
||||
model_dict["context_length"] = n_ctx
|
||||
if 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT:
|
||||
normalized = _normalize_llama_model_name(raw_id)
|
||||
_endpoint_nctx[(ep, normalized)] = n_ctx
|
||||
print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue