commit a5a0bd51c0
2 changed files with 279 additions and 89 deletions
router.py
@@ -53,6 +53,9 @@ _loaded_error_cache_lock = asyncio.Lock()
 _inflight_available_models: dict[str, asyncio.Task] = {}
 _inflight_loaded_models: dict[str, asyncio.Task] = {}
 _inflight_lock = asyncio.Lock()
+_bg_refresh_available: dict[str, asyncio.Task] = {}
+_bg_refresh_loaded: dict[str, asyncio.Task] = {}
+_bg_refresh_lock = asyncio.Lock()

 # ------------------------------------------------------------------
 # Queues
@@ -414,6 +417,30 @@ def is_openai_compatible(endpoint: str) -> bool:
     """
     return "/v1" in endpoint or endpoint in config.llama_server_endpoints

+def get_tracking_model(endpoint: str, model: str) -> str:
+    """
+    Normalize model name for tracking purposes so it matches the PS table key.
+
+    - For llama-server endpoints: strips HF prefix and quantization suffix
+    - For Ollama endpoints: appends ":latest" if no version suffix is present
+    - For external OpenAI endpoints: returns as-is (not shown in PS)
+
+    This ensures consistent model naming across all routes for usage tracking.
+    """
+    # External OpenAI endpoints are not shown in PS, keep as-is
+    if is_ext_openai_endpoint(endpoint):
+        return model
+
+    # llama-server endpoints use normalized names in PS
+    if endpoint in config.llama_server_endpoints:
+        return _normalize_llama_model_name(model)
+
+    # Ollama endpoints: append ":latest" if no version suffix
+    if ":" not in model:
+        return model + ":latest"
+
+    return model
+
 async def token_worker() -> None:
     try:
         while True:
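For reference, this is how the normalization above is intended to behave. The endpoint URLs below are illustrative assumptions, not values from the repository's config, and the llama-server case depends on whatever _normalize_llama_model_name() does:

# Hypothetical inputs; only the ":latest" rule is fully spelled out by the code above.
get_tracking_model("http://ollama-box:11434", "qwen2.5")        # -> "qwen2.5:latest"
get_tracking_model("http://ollama-box:11434", "qwen2.5:7b")     # -> "qwen2.5:7b" (already versioned, unchanged)
get_tracking_model("https://api.openai.com/v1", "gpt-4o-mini")  # -> "gpt-4o-mini" (external endpoint, returned as-is)
# For an endpoint listed in config.llama_server_endpoints the result is whatever
# _normalize_llama_model_name() returns (HF prefix and quantization suffix stripped).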
@@ -605,12 +632,23 @@ class fetch:
         """
         Background task to refresh available models cache without blocking the caller.
         Used for stale-while-revalidate pattern.
+        Deduplicates: only one background refresh runs per endpoint at a time.
         """
+        async with _bg_refresh_lock:
+            if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
+                return  # A refresh is already running for this endpoint
+            task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
+            _bg_refresh_available[endpoint] = task
+
         try:
-            await fetch._fetch_available_models_internal(endpoint, api_key)
+            await task
         except Exception as e:
             # Silently fail - cache will remain stale but functional
             print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
+        finally:
+            async with _bg_refresh_lock:
+                if _bg_refresh_available.get(endpoint) is task:
+                    _bg_refresh_available.pop(endpoint, None)

     async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
         """
@@ -634,7 +672,7 @@ class fetch:
         if endpoint in _models_cache:
             models, cached_at = _models_cache[endpoint]

-            # FRESH: < 300s old - return immediately
+            # FRESH: <= 300s old - return immediately
             if _is_fresh(cached_at, 300):
                 return models

@@ -649,7 +687,7 @@ class fetch:
         # Check error cache with lock protection
         async with _available_error_cache_lock:
             if endpoint in _available_error_cache:
-                if _is_fresh(_available_error_cache[endpoint], 30):
+                if _is_fresh(_available_error_cache[endpoint], 300):
                     # Still within the short error TTL – pretend nothing is available
                     return set()
                 # Error expired – remove it
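All of these TTL checks go through _is_fresh, which is defined elsewhere in router.py and not shown in this diff. A minimal sketch of the behaviour the calls above assume (the function name is real, the body and time source here are assumptions):

import time

def _is_fresh(cached_at: float, ttl_seconds: int) -> bool:
    # Assumed shape: a cache entry is fresh while its age is within the TTL.
    return (time.time() - cached_at) <= ttl_seconds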
@@ -735,12 +773,23 @@ class fetch:
         """
         Background task to refresh loaded models cache without blocking the caller.
         Used for stale-while-revalidate pattern.
+        Deduplicates: only one background refresh runs per endpoint at a time.
         """
+        async with _bg_refresh_lock:
+            if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
+                return  # A refresh is already running for this endpoint
+            task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
+            _bg_refresh_loaded[endpoint] = task
+
         try:
-            await fetch._fetch_loaded_models_internal(endpoint)
+            await task
         except Exception as e:
             # Silently fail - cache will remain stale but functional
             print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
+        finally:
+            async with _bg_refresh_lock:
+                if _bg_refresh_loaded.get(endpoint) is task:
+                    _bg_refresh_loaded.pop(endpoint, None)

     async def loaded_models(endpoint: str) -> Set[str]:
         """
@@ -760,7 +809,7 @@ class fetch:
             models, cached_at = _loaded_models_cache[endpoint]

             # FRESH: < 10s old - return immediately
-            if _is_fresh(cached_at, 30):
+            if _is_fresh(cached_at, 10):
                 return models

             # STALE: 10-60s old - return stale data and refresh in background
@@ -775,7 +824,7 @@ class fetch:
         # Check error cache with lock protection
         async with _loaded_error_cache_lock:
             if endpoint in _loaded_error_cache:
-                if _is_fresh(_loaded_error_cache[endpoint], 30):
+                if _is_fresh(_loaded_error_cache[endpoint], 300):
                     return set()
                 # Error expired - remove it
                 del _loaded_error_cache[endpoint]
@@ -813,7 +862,7 @@ class fetch:
         if not skip_error_cache:
             async with _available_error_cache_lock:
                 if endpoint in _available_error_cache:
-                    if _is_fresh(_available_error_cache[endpoint], 30):
+                    if _is_fresh(_available_error_cache[endpoint], 300):
                         return []

         client: aiohttp.ClientSession = app_state["session"]
@@ -910,7 +959,9 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
     else:
         client = ollama.AsyncClient(host=endpoint)

-    await increment_usage(endpoint, model)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
+    await increment_usage(endpoint, tracking_model)

     try:
         if use_openai:
@@ -923,11 +974,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
                 async for chunk in response:
                     chunks.append(chunk)
                     _accumulate_openai_tc_delta(chunk, tc_acc)
+                    prompt_tok = 0
+                    comp_tok = 0
                     if chunk.usage is not None:
                         prompt_tok = chunk.usage.prompt_tokens or 0
                         comp_tok = chunk.usage.completion_tokens or 0
-                        if prompt_tok != 0 or comp_tok != 0:
-                            await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    else:
+                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                        if llama_usage:
+                            prompt_tok, comp_tok = llama_usage
+                    if prompt_tok != 0 or comp_tok != 0:
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 # Convert to Ollama format
                 if chunks:
                     response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
@@ -935,10 +992,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
                     if tc_acc and response.message:
                         response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
             else:
-                prompt_tok = response.usage.prompt_tokens or 0
-                comp_tok = response.usage.completion_tokens or 0
+                prompt_tok = 0
+                comp_tok = 0
+                if response.usage is not None:
+                    prompt_tok = response.usage.prompt_tokens or 0
+                    comp_tok = response.usage.completion_tokens or 0
+                else:
+                    llama_usage = rechunk.extract_usage_from_llama_timings(response)
+                    if llama_usage:
+                        prompt_tok, comp_tok = llama_usage
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
         else:
             response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
@@ -950,18 +1014,18 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
                     if prompt_tok != 0 or comp_tok != 0:
-                        await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 if chunks:
                     response = chunks[-1]
             else:
                 prompt_tok = response.prompt_eval_count or 0
                 comp_tok = response.eval_count or 0
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

         return response
     finally:
-        await decrement_usage(endpoint, model)
+        await decrement_usage(endpoint, tracking_model)

 def get_last_user_content(messages):
     """
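With these changes every branch enqueues (endpoint, tracking_model, prompt_tok, comp_tok), so increments, decrements and token counts all land under the same key that the PS table reports. The consumer is token_worker, whose body is only partially visible in this diff; a rough sketch of the aggregation it presumably performs (the token_counts structure is an assumption for illustration):

async def token_worker() -> None:
    # Assumed consumer loop; the real implementation lives in router.py.
    while True:
        endpoint, tracking_model, prompt_tok, comp_tok = await token_queue.get()
        per_model = token_counts.setdefault(endpoint, {}).setdefault(
            tracking_model, {"prompt": 0, "completion": 0})
        per_model["prompt"] += prompt_tok        # prompt tokens billed to this model
        per_model["completion"] += comp_tok      # completion tokens billed to this model
        token_queue.task_done()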
@@ -1317,6 +1381,34 @@ class rechunk:
                         eval_duration=None,
                         embeddings=[chunk.data[0].embedding])
         return rechunk

+    def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
+        """Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
+
+        llama-server returns a ``timings`` dict instead of the standard OpenAI
+        ``usage`` field::
+
+            "timings": {
+                "cache_n": 236,     // prompt tokens reused from cache
+                "prompt_n": 1,      // prompt tokens processed
+                "predicted_n": 35   // predicted (completion) tokens
+            }
+
+        prompt_tokens = prompt_n + cache_n
+        completion_tokens = predicted_n
+
+        Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
+        timings are found.
+        """
+        timings = getattr(obj, "timings", None)
+        if timings is None:
+            return None
+        if isinstance(timings, dict):
+            prompt_n = timings.get("prompt_n", 0) or 0
+            cache_n = timings.get("cache_n", 0) or 0
+            predicted_n = timings.get("predicted_n", 0) or 0
+            return (prompt_n + cache_n, predicted_n)
+        return None
+
     # ------------------------------------------------------------------
     # SSE Helpser
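Using the sample numbers from the docstring above, the conversion works out as follows:

timings = {"cache_n": 236, "prompt_n": 1, "predicted_n": 35}
prompt_tokens = timings["prompt_n"] + timings["cache_n"]   # 1 + 236 = 237
completion_tokens = timings["predicted_n"]                 # 35
# extract_usage_from_llama_timings() would return (237, 35) for an object carrying these timings.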
@@ -1433,27 +1525,26 @@ async def choose_endpoint(model: str) -> str:

     # Protect all reads of usage_counts with the lock
     async with usage_lock:
-        # Helper: get current usage count for (endpoint, model)
-        def current_usage(ep: str) -> int:
-            return usage_counts.get(ep, {}).get(model, 0)
+        # Helper: current usage for (endpoint, model) using the same normalized key
+        # that increment_usage/decrement_usage store — raw model names differ from
+        # tracking names for llama-server (HF prefix / quant suffix stripped).
+        def tracking_usage(ep: str) -> int:
+            return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)

         # 3️⃣ Endpoints that have the model loaded *and* a free slot
         loaded_and_free = [
             ep for ep, models in zip(candidate_endpoints, loaded_sets)
-            if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
+            if model in models and tracking_usage(ep) < config.max_concurrent_connections
         ]

         if loaded_and_free:
-            # Sort by per-model usage in DESCENDING order to ensure model affinity
-            # Endpoints with higher usage (already handling this model) should be preferred
-            # until they reach max_concurrent_connections
-            loaded_and_free.sort(
-                key=lambda ep: -usage_counts.get(ep, {}).get(model, 0)  # Negative for descending order
-            )
+            # Sort ascending for load balancing — all endpoints here already have the
+            # model loaded, so there is no model-switching cost to optimise for.
+            loaded_and_free.sort(key=tracking_usage)

-            # If all endpoints have zero usage for this model, randomize to distribute
-            # different models across different endpoints for better resource utilization
-            if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free):
+            # When all candidates are equally idle, randomise to avoid always picking
+            # the first entry in a stable sort.
+            if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
                 return random.choice(loaded_and_free)

             return loaded_and_free[0]
@@ -1461,30 +1552,22 @@ async def choose_endpoint(model: str) -> str:
         # 4️⃣ Endpoints among the candidates that simply have a free slot
         endpoints_with_free_slot = [
             ep for ep in candidate_endpoints
-            if usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
+            if tracking_usage(ep) < config.max_concurrent_connections
         ]

         if endpoints_with_free_slot:
-            # Sort by per-model usage (descending) first to ensure model affinity
-            # Even if the model isn't showing as "loaded" in /api/ps yet (e.g., during initial loading),
-            # we want to send subsequent requests to the endpoint that already has connections for this model
-            # Then by total endpoint usage (ascending) to balance idle endpoints
+            # Sort by total endpoint load (ascending) to prefer idle endpoints.
             endpoints_with_free_slot.sort(
-                key=lambda ep: (
-                    #-usage_counts.get(ep, {}).get(model, 0), # Primary: per-model usage (descending - prefer endpoints with connections)
-                    sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints)
-                )
+                key=lambda ep: sum(usage_counts.get(ep, {}).values())
             )

-            # If all endpoints have zero usage for this specific model, randomize to distribute
-            # different models across different endpoints for better resource utilization
-            if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot):
+            if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
                 return random.choice(endpoints_with_free_slot)

             return endpoints_with_free_slot[0]

-        # 5️⃣ All candidate endpoints are saturated – pick one with lowest usages count (will queue)
-        ep = min(candidate_endpoints, key=current_usage)
+        # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
+        ep = min(candidate_endpoints, key=tracking_usage)
         return ep

 # -------------------------------------------------------------
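Taken together, endpoint selection now uses one normalized usage key throughout. A condensed sketch of the three tiers above (simplified: the real function also randomizes ties, sorts tier 4 by total endpoint load, and holds usage_lock while reading the counters):

def _select(model, candidates, loaded_sets, tracking_usage, max_conn):
    # 3️⃣ endpoints that already have the model loaded and a free slot
    loaded_and_free = [ep for ep, loaded in zip(candidates, loaded_sets)
                       if model in loaded and tracking_usage(ep) < max_conn]
    if loaded_and_free:
        return min(loaded_and_free, key=tracking_usage)
    # 4️⃣ any endpoint with a free slot
    free = [ep for ep in candidates if tracking_usage(ep) < max_conn]
    if free:
        return min(free, key=tracking_usage)
    # 5️⃣ everything saturated: pick the least busy one and let it queue
    return min(candidates, key=tracking_usage)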
@@ -1528,6 +1611,8 @@ async def proxy(request: Request):

     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1552,7 +1637,7 @@ async def proxy(request: Request):
         oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)

     # 4. Async generator that streams data and decrements the counter
     async def stream_generate_response():
@@ -1569,7 +1654,7 @@ async def proxy(request: Request):
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
                     if prompt_tok != 0 or comp_tok != 0:
-                        await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
@@ -1584,7 +1669,7 @@ async def proxy(request: Request):
                 prompt_tok = async_gen.prompt_eval_count or 0
                 comp_tok = async_gen.eval_count or 0
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 json_line = (
                     response
                     if hasattr(async_gen, "model_dump_json")
@@ -1594,7 +1679,7 @@ async def proxy(request: Request):

         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 5. Return a StreamingResponse backed by the generator
     return StreamingResponse(
@@ -1649,6 +1734,8 @@ async def chat_proxy(request: Request):
         opt = False
     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1679,7 +1766,7 @@ async def chat_proxy(request: Request):
         oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     # 3. Async generator that streams chat data and decrements the counter
     async def stream_chat_response():
         try:
@@ -1706,7 +1793,7 @@ async def chat_proxy(request: Request):
                     prompt_tok = chunk.prompt_eval_count or 0
                     comp_tok = chunk.eval_count or 0
                     if prompt_tok != 0 or comp_tok != 0:
-                        await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                     if hasattr(chunk, "model_dump_json"):
                         json_line = chunk.model_dump_json()
                     else:
@@ -1721,7 +1808,7 @@ async def chat_proxy(request: Request):
                 prompt_tok = async_gen.prompt_eval_count or 0
                 comp_tok = async_gen.eval_count or 0
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 json_line = (
                     response
                     if hasattr(async_gen, "model_dump_json")
@@ -1731,7 +1818,7 @@ async def chat_proxy(request: Request):

         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 4. Return a StreamingResponse backed by the generator
     media_type = "application/x-ndjson" if stream else "application/json"
@@ -1773,6 +1860,8 @@ async def embedding_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1780,7 +1869,7 @@ async def embedding_proxy(request: Request):
         client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     # 3. Async generator that streams embedding data and decrements the counter
     async def stream_embedding_response():
         try:
@@ -1797,7 +1886,7 @@ async def embedding_proxy(request: Request):
                 yield json_line.encode("utf-8") + b"\n"
         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 5. Return a StreamingResponse backed by the generator
     return StreamingResponse(
@@ -1839,6 +1928,8 @@ async def embed_proxy(request: Request):
     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
     use_openai = is_openai_compatible(endpoint)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
     if use_openai:
         if ":latest" in model:
             model = model.split(":latest")
@@ -1846,7 +1937,7 @@ async def embed_proxy(request: Request):
         client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
     else:
         client = ollama.AsyncClient(host=endpoint)
-    await increment_usage(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     # 3. Async generator that streams embed data and decrements the counter
     async def stream_embedding_response():
         try:
@@ -1863,7 +1954,7 @@ async def embed_proxy(request: Request):
                 yield json_line.encode("utf-8") + b"\n"
         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 4. Return a StreamingResponse backed by the generator
     return StreamingResponse(
@@ -2226,7 +2317,13 @@ async def version_proxy(request: Request):
     """
     # 1. Query all endpoints for version
    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
-    all_versions = await asyncio.gather(*tasks)
+    all_versions_raw = await asyncio.gather(*tasks)
+
+    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
+    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]
+
+    if not all_versions:
+        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

     def version_key(v):
         return tuple(map(int, v.split('.')))
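version_key makes the comparison numeric rather than lexical, which matters once a version component reaches two digits:

version_key("0.9.0")     # -> (0, 9, 0)
version_key("0.10.1")    # -> (0, 10, 1)
max(["0.9.0", "0.10.1"], key=version_key)   # -> "0.10.1" (plain string comparison would pick "0.9.0")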
@@ -2364,6 +2461,10 @@ async def ps_details_proxy(request: Request):

     # Add llama-server models with endpoint info and full status metadata (if any)
     if llama_loaded:
+        # Collect (endpoint, raw_id) pairs to fetch /props in parallel
+        props_requests: list[tuple[str, str]] = []
+        llama_models_pending: list[dict] = []
+
         for (endpoint, modellist) in zip([ep for ep, _ in llama_tasks], llama_loaded):
             # Filter for loaded models only
             loaded_models = [item for item in modellist if _is_llama_model_loaded(item)]
@@ -2388,7 +2489,53 @@ async def ps_details_proxy(request: Request):
                 if isinstance(status_info, dict):
                     model_with_endpoint["llama_status_args"] = status_info.get("args")
                     model_with_endpoint["llama_status_preset"] = status_info.get("preset")
-                models.append(model_with_endpoint)
+                llama_models_pending.append(model_with_endpoint)
+                props_requests.append((endpoint, raw_id))
+
+        # Fetch /props for each llama-server model to get context length (n_ctx)
+        # and unload sleeping models automatically
+        async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool]:
+            client: aiohttp.ClientSession = app_state["session"]
+            base_url = endpoint.rstrip("/").removesuffix("/v1")
+            props_url = f"{base_url}/props?model={model_id}"
+            headers = None
+            api_key = config.api_keys.get(endpoint)
+            if api_key:
+                headers = {"Authorization": f"Bearer {api_key}"}
+            try:
+                async with client.get(props_url, headers=headers) as resp:
+                    if resp.status == 200:
+                        data = await resp.json()
+                        dgs = data.get("default_generation_settings", {})
+                        n_ctx = dgs.get("n_ctx")
+                        is_sleeping = data.get("is_sleeping", False)
+
+                        if is_sleeping:
+                            unload_url = f"{base_url}/models/unload"
+                            try:
+                                async with client.post(
+                                    unload_url,
+                                    json={"model": model_id},
+                                    headers=headers,
+                                ) as unload_resp:
+                                    print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
+                            except Exception as ue:
+                                print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")
+
+                        return n_ctx, is_sleeping
+            except Exception as e:
+                print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
+            return None, False
+
+        props_results = await asyncio.gather(
+            *[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
+        )
+
+        for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results):
+            if n_ctx is not None:
+                model_dict["context_length"] = n_ctx
+            if not is_sleeping:
+                models.append(model_dict)

     return JSONResponse(content={"models": models}, status_code=200)

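The probe above reads only two things out of llama-server's /props response. A trimmed example of the payload shape this code expects (values are illustrative):

# Illustrative payload; only the fields read by _fetch_llama_props are shown.
props = {
    "default_generation_settings": {"n_ctx": 32768},   # surfaced as "context_length" in the /api/ps output
    "is_sleeping": False,                               # True triggers a POST to {base_url}/models/unload
}
n_ctx = props.get("default_generation_settings", {}).get("n_ctx")   # 32768
is_sleeping = props.get("is_sleeping", False)                        # False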
@@ -2479,7 +2626,9 @@ async def openai_embedding_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
-    await increment_usage(endpoint, model)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     if is_openai_compatible(endpoint):
         api_key = config.api_keys.get(endpoint, "no-key")
     else:
@@ -2490,8 +2639,8 @@ async def openai_embedding_proxy(request: Request):

     # 3. Async generator that streams embedding data and decrements the counter
     async_gen = await oclient.embeddings.create(input=doc, model=model)

-    await decrement_usage(endpoint, model)
-
+    await decrement_usage(endpoint, tracking_model)
+
     # 5. Return a StreamingResponse backed by the generator
     return async_gen
@@ -2568,7 +2717,9 @@ async def openai_chat_completions_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
-    await increment_usage(endpoint, model)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
     # 3. Async generator that streams completions data and decrements the counter
@@ -2593,23 +2744,42 @@ async def openai_chat_completions_proxy(request: Request):
                         else orjson.dumps(chunk)
                     )
                     if chunk.choices:
-                        if chunk.choices[0].delta.content is not None:
+                        delta = chunk.choices[0].delta
+                        has_content = delta.content is not None
+                        has_reasoning = (
+                            getattr(delta, "reasoning_content", None) is not None
+                            or getattr(delta, "reasoning", None) is not None
+                        )
+                        has_tool_calls = getattr(delta, "tool_calls", None) is not None
+                        if has_content or has_reasoning or has_tool_calls:
                             yield f"data: {data}\n\n".encode("utf-8")
+                    elif chunk.usage is not None:
+                        # Forward the usage-only final chunk (e.g. from llama-server)
+                        yield f"data: {data}\n\n".encode("utf-8")
+                    prompt_tok = 0
+                    comp_tok = 0
                     if chunk.usage is not None:
                         prompt_tok = chunk.usage.prompt_tokens or 0
                         comp_tok = chunk.usage.completion_tokens or 0
-                        if prompt_tok != 0 or comp_tok != 0:
-                            local_model = model
-                            if not is_ext_openai_endpoint(endpoint):
-                                if not ":" in model:
-                                    local_model = model if ":" in model else model + ":latest"
-                            await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
+                    else:
+                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                        if llama_usage:
+                            prompt_tok, comp_tok = llama_usage
+                    if prompt_tok != 0 or comp_tok != 0:
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 yield b"data: [DONE]\n\n"
             else:
-                prompt_tok = async_gen.usage.prompt_tokens or 0
-                comp_tok = async_gen.usage.completion_tokens or 0
+                prompt_tok = 0
+                comp_tok = 0
+                if async_gen.usage is not None:
+                    prompt_tok = async_gen.usage.prompt_tokens or 0
+                    comp_tok = async_gen.usage.completion_tokens or 0
+                else:
+                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
+                    if llama_usage:
+                        prompt_tok, comp_tok = llama_usage
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 json_line = (
                     async_gen.model_dump_json()
                     if hasattr(async_gen, "model_dump_json")
@@ -2619,12 +2789,12 @@ async def openai_chat_completions_proxy(request: Request):

         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 4. Return a StreamingResponse backed by the generator
     return StreamingResponse(
         stream_ochat_response(),
-        media_type="application/json",
+        media_type="text/event-stream" if stream else "application/json",
     )

 # -------------------------------------------------------------
@@ -2693,7 +2863,9 @@ async def openai_completions_proxy(request: Request):

     # 2. Endpoint logic
     endpoint = await choose_endpoint(model)
-    await increment_usage(endpoint, model)
+    # Normalize model name for tracking so it matches the PS table key
+    tracking_model = get_tracking_model(endpoint, model)
+    await increment_usage(endpoint, tracking_model)
     base_url = ep2base(endpoint)
     oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))

@@ -2710,24 +2882,42 @@ async def openai_completions_proxy(request: Request):
                         else orjson.dumps(chunk)
                     )
                     if chunk.choices:
-                        if chunk.choices[0].finish_reason == None:
+                        choice = chunk.choices[0]
+                        has_text = getattr(choice, "text", None) is not None
+                        has_reasoning = (
+                            getattr(choice, "reasoning_content", None) is not None
+                            or getattr(choice, "reasoning", None) is not None
+                        )
+                        if has_text or has_reasoning or choice.finish_reason is not None:
                             yield f"data: {data}\n\n".encode("utf-8")
+                    elif chunk.usage is not None:
+                        # Forward the usage-only final chunk (e.g. from llama-server)
+                        yield f"data: {data}\n\n".encode("utf-8")
+                    prompt_tok = 0
+                    comp_tok = 0
                     if chunk.usage is not None:
-                        prompt_tok = chunk.usage.prompt_tokens or 0
-                        comp_tok = chunk.usage.completion_tokens or 0
-                        if prompt_tok != 0 or comp_tok != 0:
-                            local_model = model
-                            if not is_ext_openai_endpoint(endpoint):
-                                if not ":" in model:
-                                    local_model = model if ":" in model else model + ":latest"
-                            await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
+                        prompt_tok = chunk.usage.prompt_tokens or 0
+                        comp_tok = chunk.usage.completion_tokens or 0
+                    else:
+                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
+                        if llama_usage:
+                            prompt_tok, comp_tok = llama_usage
+                    if prompt_tok != 0 or comp_tok != 0:
+                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 # Final DONE event
                 yield b"data: [DONE]\n\n"
             else:
-                prompt_tok = async_gen.usage.prompt_tokens or 0
-                comp_tok = async_gen.usage.completion_tokens or 0
+                prompt_tok = 0
+                comp_tok = 0
+                if async_gen.usage is not None:
+                    prompt_tok = async_gen.usage.prompt_tokens or 0
+                    comp_tok = async_gen.usage.completion_tokens or 0
+                else:
+                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
+                    if llama_usage:
+                        prompt_tok, comp_tok = llama_usage
                 if prompt_tok != 0 or comp_tok != 0:
-                    await token_queue.put((endpoint, model, prompt_tok, comp_tok))
+                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                 json_line = (
                     async_gen.model_dump_json()
                     if hasattr(async_gen, "model_dump_json")
@@ -2737,12 +2927,12 @@ async def openai_completions_proxy(request: Request):

         finally:
             # Ensure counter is decremented even if an exception occurs
-            await decrement_usage(endpoint, model)
+            await decrement_usage(endpoint, tracking_model)

     # 4. Return a StreamingResponse backed by the generator
     return StreamingResponse(
         stream_ocompletions_response(),
-        media_type="application/json",
+        media_type="text/event-stream" if stream else "application/json",
     )

 # -------------------------------------------------------------
@@ -928,7 +928,7 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
     const uniqueEndpoints = Array.from(new Set(endpoints));
     const endpointsData = encodeURIComponent(JSON.stringify(uniqueEndpoints));
     return `<tr data-model="${modelName}" data-endpoints="${endpointsData}">
-        <td class="model">${modelName} <a href="#" class="stats-link" data-model="${originalName}">stats</a></td>
+        <td class="model">${modelName} <a href="#" class="stats-link" data-model="${modelName}">stats</a></td>
         <td>${renderInstanceList(endpoints)}</td>
         <td>${params}</td>
         <td>${quant}</td>