Merge pull request #27 from nomyo-ai/dev-v0.6.X

Dev v0.6.x to prod
This commit is contained in:
Alpha Nerd 2026-02-25 13:08:15 +01:00 committed by GitHub
commit a5a0bd51c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 279 additions and 89 deletions

366
router.py
View file

@ -53,6 +53,9 @@ _loaded_error_cache_lock = asyncio.Lock()
_inflight_available_models: dict[str, asyncio.Task] = {}
_inflight_loaded_models: dict[str, asyncio.Task] = {}
_inflight_lock = asyncio.Lock()
_bg_refresh_available: dict[str, asyncio.Task] = {}
_bg_refresh_loaded: dict[str, asyncio.Task] = {}
_bg_refresh_lock = asyncio.Lock()
# ------------------------------------------------------------------
# Queues
@ -414,6 +417,30 @@ def is_openai_compatible(endpoint: str) -> bool:
"""
return "/v1" in endpoint or endpoint in config.llama_server_endpoints
def get_tracking_model(endpoint: str, model: str) -> str:
    """
    Normalize a model name for usage tracking so it matches the PS table key.

    - External OpenAI endpoints: returned unchanged (not shown in PS).
    - llama-server endpoints: HF prefix and quantization suffix stripped.
    - Ollama endpoints: ":latest" appended when no version tag is present.

    Ensures consistent model naming across all routes for usage tracking.
    """
    if is_ext_openai_endpoint(endpoint):
        # Not tracked in the PS table, so no normalization is needed.
        return model
    if endpoint in config.llama_server_endpoints:
        # llama-server entries appear in PS under their normalized name.
        return _normalize_llama_model_name(model)
    # Ollama: a bare name implies the ":latest" tag.
    return model if ":" in model else f"{model}:latest"
async def token_worker() -> None:
try:
while True:
@ -605,12 +632,23 @@ class fetch:
"""
Background task to refresh available models cache without blocking the caller.
Used for stale-while-revalidate pattern.
Deduplicates: only one background refresh runs per endpoint at a time.
"""
async with _bg_refresh_lock:
if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
return # A refresh is already running for this endpoint
task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
_bg_refresh_available[endpoint] = task
try:
await fetch._fetch_available_models_internal(endpoint, api_key)
await task
except Exception as e:
# Silently fail - cache will remain stale but functional
print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
finally:
async with _bg_refresh_lock:
if _bg_refresh_available.get(endpoint) is task:
_bg_refresh_available.pop(endpoint, None)
async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
"""
@ -634,7 +672,7 @@ class fetch:
if endpoint in _models_cache:
models, cached_at = _models_cache[endpoint]
# FRESH: < 300s old - return immediately
# FRESH: <= 300s old - return immediately
if _is_fresh(cached_at, 300):
return models
@ -649,7 +687,7 @@ class fetch:
# Check error cache with lock protection
async with _available_error_cache_lock:
if endpoint in _available_error_cache:
if _is_fresh(_available_error_cache[endpoint], 30):
if _is_fresh(_available_error_cache[endpoint], 300):
# Still within the short error TTL — pretend nothing is available
return set()
# Error expired remove it
@ -735,12 +773,23 @@ class fetch:
"""
Background task to refresh loaded models cache without blocking the caller.
Used for stale-while-revalidate pattern.
Deduplicates: only one background refresh runs per endpoint at a time.
"""
async with _bg_refresh_lock:
if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
return # A refresh is already running for this endpoint
task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
_bg_refresh_loaded[endpoint] = task
try:
await fetch._fetch_loaded_models_internal(endpoint)
await task
except Exception as e:
# Silently fail - cache will remain stale but functional
print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
finally:
async with _bg_refresh_lock:
if _bg_refresh_loaded.get(endpoint) is task:
_bg_refresh_loaded.pop(endpoint, None)
async def loaded_models(endpoint: str) -> Set[str]:
"""
@ -760,7 +809,7 @@ class fetch:
models, cached_at = _loaded_models_cache[endpoint]
# FRESH: < 10s old - return immediately
if _is_fresh(cached_at, 30):
if _is_fresh(cached_at, 10):
return models
# STALE: 10-60s old - return stale data and refresh in background
@ -775,7 +824,7 @@ class fetch:
# Check error cache with lock protection
async with _loaded_error_cache_lock:
if endpoint in _loaded_error_cache:
if _is_fresh(_loaded_error_cache[endpoint], 30):
if _is_fresh(_loaded_error_cache[endpoint], 300):
return set()
# Error expired - remove it
del _loaded_error_cache[endpoint]
@ -813,7 +862,7 @@ class fetch:
if not skip_error_cache:
async with _available_error_cache_lock:
if endpoint in _available_error_cache:
if _is_fresh(_available_error_cache[endpoint], 30):
if _is_fresh(_available_error_cache[endpoint], 300):
return []
client: aiohttp.ClientSession = app_state["session"]
@ -910,7 +959,9 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
try:
if use_openai:
@ -923,11 +974,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
async for chunk in response:
chunks.append(chunk)
_accumulate_openai_tc_delta(chunk, tc_acc)
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Convert to Ollama format
if chunks:
response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
@ -935,10 +992,17 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
if tc_acc and response.message:
response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
else:
prompt_tok = response.usage.prompt_tokens or 0
comp_tok = response.usage.completion_tokens or 0
prompt_tok = 0
comp_tok = 0
if response.usage is not None:
prompt_tok = response.usage.prompt_tokens or 0
comp_tok = response.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(response)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
else:
response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
@ -950,18 +1014,18 @@ async def _make_chat_request(endpoint: str, model: str, messages: list, tools=No
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if chunks:
response = chunks[-1]
else:
prompt_tok = response.prompt_eval_count or 0
comp_tok = response.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
return response
finally:
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
def get_last_user_content(messages):
"""
@ -1317,6 +1381,34 @@ class rechunk:
eval_duration=None,
embeddings=[chunk.data[0].embedding])
return rechunk
def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
"""Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
llama-server returns a ``timings`` dict instead of the standard OpenAI
``usage`` field::
"timings": {
"cache_n": 236, // prompt tokens reused from cache
"prompt_n": 1, // prompt tokens processed
"predicted_n": 35 // predicted (completion) tokens
}
prompt_tokens = prompt_n + cache_n
completion_tokens = predicted_n
Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
timings are found.
"""
timings = getattr(obj, "timings", None)
if timings is None:
return None
if isinstance(timings, dict):
prompt_n = timings.get("prompt_n", 0) or 0
cache_n = timings.get("cache_n", 0) or 0
predicted_n = timings.get("predicted_n", 0) or 0
return (prompt_n + cache_n, predicted_n)
return None
# ------------------------------------------------------------------
# SSE Helpers
@ -1433,27 +1525,26 @@ async def choose_endpoint(model: str) -> str:
# Protect all reads of usage_counts with the lock
async with usage_lock:
# Helper: get current usage count for (endpoint, model)
def current_usage(ep: str) -> int:
return usage_counts.get(ep, {}).get(model, 0)
# Helper: current usage for (endpoint, model) using the same normalized key
# that increment_usage/decrement_usage store — raw model names differ from
# tracking names for llama-server (HF prefix / quant suffix stripped).
def tracking_usage(ep: str) -> int:
return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)
# 3⃣ Endpoints that have the model loaded *and* a free slot
loaded_and_free = [
ep for ep, models in zip(candidate_endpoints, loaded_sets)
if model in models and usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
if model in models and tracking_usage(ep) < config.max_concurrent_connections
]
if loaded_and_free:
# Sort by per-model usage in DESCENDING order to ensure model affinity
# Endpoints with higher usage (already handling this model) should be preferred
# until they reach max_concurrent_connections
loaded_and_free.sort(
key=lambda ep: -usage_counts.get(ep, {}).get(model, 0) # Negative for descending order
)
# Sort ascending for load balancing — all endpoints here already have the
# model loaded, so there is no model-switching cost to optimise for.
loaded_and_free.sort(key=tracking_usage)
# If all endpoints have zero usage for this model, randomize to distribute
# different models across different endpoints for better resource utilization
if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in loaded_and_free):
# When all candidates are equally idle, randomise to avoid always picking
# the first entry in a stable sort.
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
return random.choice(loaded_and_free)
return loaded_and_free[0]
@ -1461,30 +1552,22 @@ async def choose_endpoint(model: str) -> str:
# 4⃣ Endpoints among the candidates that simply have a free slot
endpoints_with_free_slot = [
ep for ep in candidate_endpoints
if usage_counts.get(ep, {}).get(model, 0) < config.max_concurrent_connections
if tracking_usage(ep) < config.max_concurrent_connections
]
if endpoints_with_free_slot:
# Sort by per-model usage (descending) first to ensure model affinity
# Even if the model isn't showing as "loaded" in /api/ps yet (e.g., during initial loading),
# we want to send subsequent requests to the endpoint that already has connections for this model
# Then by total endpoint usage (ascending) to balance idle endpoints
# Sort by total endpoint load (ascending) to prefer idle endpoints.
endpoints_with_free_slot.sort(
key=lambda ep: (
#-usage_counts.get(ep, {}).get(model, 0), # Primary: per-model usage (descending - prefer endpoints with connections)
sum(usage_counts.get(ep, {}).values()) # Secondary: total endpoint usage (ascending - prefer idle endpoints)
)
key=lambda ep: sum(usage_counts.get(ep, {}).values())
)
# If all endpoints have zero usage for this specific model, randomize to distribute
# different models across different endpoints for better resource utilization
if all(usage_counts.get(ep, {}).get(model, 0) == 0 for ep in endpoints_with_free_slot):
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
return random.choice(endpoints_with_free_slot)
return endpoints_with_free_slot[0]
# 5⃣ All candidate endpoints are saturated pick one with lowest usages count (will queue)
ep = min(candidate_endpoints, key=current_usage)
# 5⃣ All candidate endpoints are saturated — pick the least-busy one (will queue)
ep = min(candidate_endpoints, key=tracking_usage)
return ep
# -------------------------------------------------------------
@ -1528,6 +1611,8 @@ async def proxy(request: Request):
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1552,7 +1637,7 @@ async def proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 4. Async generator that streams data and decrements the counter
async def stream_generate_response():
@ -1569,7 +1654,7 @@ async def proxy(request: Request):
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
@ -1584,7 +1669,7 @@ async def proxy(request: Request):
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
@ -1594,7 +1679,7 @@ async def proxy(request: Request):
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 5. Return a StreamingResponse backed by the generator
return StreamingResponse(
@ -1649,6 +1734,8 @@ async def chat_proxy(request: Request):
opt = False
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1679,7 +1766,7 @@ async def chat_proxy(request: Request):
oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams chat data and decrements the counter
async def stream_chat_response():
try:
@ -1706,7 +1793,7 @@ async def chat_proxy(request: Request):
prompt_tok = chunk.prompt_eval_count or 0
comp_tok = chunk.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
if hasattr(chunk, "model_dump_json"):
json_line = chunk.model_dump_json()
else:
@ -1721,7 +1808,7 @@ async def chat_proxy(request: Request):
prompt_tok = async_gen.prompt_eval_count or 0
comp_tok = async_gen.eval_count or 0
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
response
if hasattr(async_gen, "model_dump_json")
@ -1731,7 +1818,7 @@ async def chat_proxy(request: Request):
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
media_type = "application/x-ndjson" if stream else "application/json"
@ -1773,6 +1860,8 @@ async def embedding_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1780,7 +1869,7 @@ async def embedding_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embedding data and decrements the counter
async def stream_embedding_response():
try:
@ -1797,7 +1886,7 @@ async def embedding_proxy(request: Request):
yield json_line.encode("utf-8") + b"\n"
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 5. Return a StreamingResponse backed by the generator
return StreamingResponse(
@ -1839,6 +1928,8 @@ async def embed_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
use_openai = is_openai_compatible(endpoint)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
if use_openai:
if ":latest" in model:
model = model.split(":latest")
@ -1846,7 +1937,7 @@ async def embed_proxy(request: Request):
client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
else:
client = ollama.AsyncClient(host=endpoint)
await increment_usage(endpoint, model)
await increment_usage(endpoint, tracking_model)
# 3. Async generator that streams embed data and decrements the counter
async def stream_embedding_response():
try:
@ -1863,7 +1954,7 @@ async def embed_proxy(request: Request):
yield json_line.encode("utf-8") + b"\n"
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
@ -2226,7 +2317,13 @@ async def version_proxy(request: Request):
"""
# 1. Query all endpoints for version
tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
all_versions = await asyncio.gather(*tasks)
all_versions_raw = await asyncio.gather(*tasks)
# Filter out non-string values (e.g., empty lists from failed/timeout responses)
all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]
if not all_versions:
raise HTTPException(status_code=503, detail="No valid version response from any endpoint")
def version_key(v):
    """Sort key: convert a dotted version string into a tuple of ints."""
    return tuple(int(part) for part in v.split("."))
@ -2364,6 +2461,10 @@ async def ps_details_proxy(request: Request):
# Add llama-server models with endpoint info and full status metadata (if any)
if llama_loaded:
# Collect (endpoint, raw_id) pairs to fetch /props in parallel
props_requests: list[tuple[str, str]] = []
llama_models_pending: list[dict] = []
for (endpoint, modellist) in zip([ep for ep, _ in llama_tasks], llama_loaded):
# Filter for loaded models only
loaded_models = [item for item in modellist if _is_llama_model_loaded(item)]
@ -2388,7 +2489,53 @@ async def ps_details_proxy(request: Request):
if isinstance(status_info, dict):
model_with_endpoint["llama_status_args"] = status_info.get("args")
model_with_endpoint["llama_status_preset"] = status_info.get("preset")
models.append(model_with_endpoint)
llama_models_pending.append(model_with_endpoint)
props_requests.append((endpoint, raw_id))
# Fetch /props for each llama-server model to get context length (n_ctx)
# and unload sleeping models automatically
async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool]:
    """Fetch /props for one llama-server model.

    Returns ``(n_ctx, is_sleeping)`` where ``n_ctx`` is the context length
    taken from ``default_generation_settings`` (None on any failure or
    non-200 response) and ``is_sleeping`` is the server's sleep flag.
    Side effect: a sleeping model is unloaded via POST /models/unload so it
    drops out of the PS view; unload failures are logged, never raised.
    """
    client: aiohttp.ClientSession = app_state["session"]
    # /props lives on the bare host, not under the /v1 API prefix.
    base_url = endpoint.rstrip("/").removesuffix("/v1")
    props_url = f"{base_url}/props?model={model_id}"
    headers = None
    api_key = config.api_keys.get(endpoint)
    if api_key:
        headers = {"Authorization": f"Bearer {api_key}"}
    try:
        async with client.get(props_url, headers=headers) as resp:
            if resp.status == 200:
                data = await resp.json()
                dgs = data.get("default_generation_settings", {})
                n_ctx = dgs.get("n_ctx")
                is_sleeping = data.get("is_sleeping", False)
                if is_sleeping:
                    # Best-effort unload of a sleeping model; the result is
                    # only logged so a failure cannot break the PS response.
                    unload_url = f"{base_url}/models/unload"
                    try:
                        async with client.post(
                            unload_url,
                            json={"model": model_id},
                            headers=headers,
                        ) as unload_resp:
                            print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
                    except Exception as ue:
                        print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")
                return n_ctx, is_sleeping
    except Exception as e:
        print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
    # Non-200 response or request failure: no context info, assume awake.
    return None, False
props_results = await asyncio.gather(
*[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
)
for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results):
if n_ctx is not None:
model_dict["context_length"] = n_ctx
if not is_sleeping:
models.append(model_dict)
return JSONResponse(content={"models": models}, status_code=200)
@ -2479,7 +2626,9 @@ async def openai_embedding_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
if is_openai_compatible(endpoint):
api_key = config.api_keys.get(endpoint, "no-key")
else:
@ -2490,8 +2639,8 @@ async def openai_embedding_proxy(request: Request):
# 3. Async generator that streams embedding data and decrements the counter
async_gen = await oclient.embeddings.create(input=doc, model=model)
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 5. Return a StreamingResponse backed by the generator
return async_gen
@ -2568,7 +2717,9 @@ async def openai_chat_completions_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
# 3. Async generator that streams completions data and decrements the counter
@ -2593,23 +2744,42 @@ async def openai_chat_completions_proxy(request: Request):
else orjson.dumps(chunk)
)
if chunk.choices:
if chunk.choices[0].delta.content is not None:
delta = chunk.choices[0].delta
has_content = delta.content is not None
has_reasoning = (
getattr(delta, "reasoning_content", None) is not None
or getattr(delta, "reasoning", None) is not None
)
has_tool_calls = getattr(delta, "tool_calls", None) is not None
if has_content or has_reasoning or has_tool_calls:
yield f"data: {data}\n\n".encode("utf-8")
elif chunk.usage is not None:
# Forward the usage-only final chunk (e.g. from llama-server)
yield f"data: {data}\n\n".encode("utf-8")
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
if prompt_tok != 0 or comp_tok != 0:
local_model = model
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model if ":" in model else model + ":latest"
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
yield b"data: [DONE]\n\n"
else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
prompt_tok = 0
comp_tok = 0
if async_gen.usage is not None:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
@ -2619,12 +2789,12 @@ async def openai_chat_completions_proxy(request: Request):
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
stream_ochat_response(),
media_type="application/json",
media_type="text/event-stream" if stream else "application/json",
)
# -------------------------------------------------------------
@ -2693,7 +2863,9 @@ async def openai_completions_proxy(request: Request):
# 2. Endpoint logic
endpoint = await choose_endpoint(model)
await increment_usage(endpoint, model)
# Normalize model name for tracking so it matches the PS table key
tracking_model = get_tracking_model(endpoint, model)
await increment_usage(endpoint, tracking_model)
base_url = ep2base(endpoint)
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
@ -2710,24 +2882,42 @@ async def openai_completions_proxy(request: Request):
else orjson.dumps(chunk)
)
if chunk.choices:
if chunk.choices[0].finish_reason == None:
choice = chunk.choices[0]
has_text = getattr(choice, "text", None) is not None
has_reasoning = (
getattr(choice, "reasoning_content", None) is not None
or getattr(choice, "reasoning", None) is not None
)
if has_text or has_reasoning or choice.finish_reason is not None:
yield f"data: {data}\n\n".encode("utf-8")
elif chunk.usage is not None:
# Forward the usage-only final chunk (e.g. from llama-server)
yield f"data: {data}\n\n".encode("utf-8")
prompt_tok = 0
comp_tok = 0
if chunk.usage is not None:
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
if prompt_tok != 0 or comp_tok != 0:
local_model = model
if not is_ext_openai_endpoint(endpoint):
if not ":" in model:
local_model = model if ":" in model else model + ":latest"
await token_queue.put((endpoint, local_model, prompt_tok, comp_tok))
prompt_tok = chunk.usage.prompt_tokens or 0
comp_tok = chunk.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
# Final DONE event
yield b"data: [DONE]\n\n"
else:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
prompt_tok = 0
comp_tok = 0
if async_gen.usage is not None:
prompt_tok = async_gen.usage.prompt_tokens or 0
comp_tok = async_gen.usage.completion_tokens or 0
else:
llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
if llama_usage:
prompt_tok, comp_tok = llama_usage
if prompt_tok != 0 or comp_tok != 0:
await token_queue.put((endpoint, model, prompt_tok, comp_tok))
await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
json_line = (
async_gen.model_dump_json()
if hasattr(async_gen, "model_dump_json")
@ -2737,12 +2927,12 @@ async def openai_completions_proxy(request: Request):
finally:
# Ensure counter is decremented even if an exception occurs
await decrement_usage(endpoint, model)
await decrement_usage(endpoint, tracking_model)
# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
stream_ocompletions_response(),
media_type="application/json",
media_type="text/event-stream" if stream else "application/json",
)
# -------------------------------------------------------------

View file

@ -928,7 +928,7 @@ function renderTimeSeriesChart(timeSeriesData, chart, minutes) {
const uniqueEndpoints = Array.from(new Set(endpoints));
const endpointsData = encodeURIComponent(JSON.stringify(uniqueEndpoints));
return `<tr data-model="${modelName}" data-endpoints="${endpointsData}">
<td class="model">${modelName} <a href="#" class="stats-link" data-model="${originalName}">stats</a></td>
<td class="model">${modelName} <a href="#" class="stats-link" data-model="${modelName}">stats</a></td>
<td>${renderInstanceList(endpoints)}</td>
<td>${params}</td>
<td>${quant}</td>