refac: modularize sse, routing, db and token handling V
This commit is contained in:
parent
3a9854c5db
commit
8355bf9a1e
5 changed files with 555 additions and 441 deletions
11
db.py
11
db.py
|
|
@ -4,6 +4,17 @@ from pathlib import Path
|
|||
from datetime import datetime, timezone
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def get_db() -> "TokenDatabase":
|
||||
"""Return the live TokenDatabase instance held by router.py.
|
||||
|
||||
Resolved lazily so submodules can access it without import cycles, and
|
||||
so test patches of ``router.db`` flow through to all callers.
|
||||
"""
|
||||
import router # lazy to avoid module-load circular import
|
||||
return router.db
|
||||
|
||||
|
||||
class TokenDatabase:
|
||||
def __init__(self, db_path: str = "token_counts.db"):
|
||||
self.db_path = db_path
|
||||
|
|
|
|||
452
router.py
452
router.py
|
|
@ -233,175 +233,13 @@ from backends.normalize import (
|
|||
get_tracking_model,
|
||||
)
|
||||
|
||||
async def token_worker() -> None:
|
||||
try:
|
||||
while True:
|
||||
endpoint, model, prompt, comp = await token_queue.get()
|
||||
# Calculate timestamp once before acquiring lock
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
|
||||
|
||||
# Accumulate counts in memory buffer (protected by lock)
|
||||
async with buffer_lock:
|
||||
token_buffer[endpoint][model] = (
|
||||
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
|
||||
token_buffer[endpoint].get(model, (0, 0))[1] + comp
|
||||
)
|
||||
|
||||
# Add to time series buffer with timestamp (UTC)
|
||||
time_series_buffer.append({
|
||||
'endpoint': endpoint,
|
||||
'model': model,
|
||||
'input_tokens': prompt,
|
||||
'output_tokens': comp,
|
||||
'total_tokens': prompt + comp,
|
||||
'timestamp': timestamp
|
||||
})
|
||||
|
||||
# Update in-memory counts for immediate reporting
|
||||
async with token_usage_lock:
|
||||
token_usage_counts[endpoint][model] += (prompt + comp)
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
except asyncio.CancelledError:
|
||||
# Gracefully handle task cancellation during shutdown
|
||||
print("[token_worker] Task cancelled, processing remaining queue items...")
|
||||
# Process any remaining items in the queue before exiting
|
||||
while not token_queue.empty():
|
||||
try:
|
||||
endpoint, model, prompt, comp = token_queue.get_nowait()
|
||||
# Calculate timestamp once before acquiring lock
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
|
||||
|
||||
async with buffer_lock:
|
||||
token_buffer[endpoint][model] = (
|
||||
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
|
||||
token_buffer[endpoint].get(model, (0, 0))[1] + comp
|
||||
)
|
||||
time_series_buffer.append({
|
||||
'endpoint': endpoint,
|
||||
'model': model,
|
||||
'input_tokens': prompt,
|
||||
'output_tokens': comp,
|
||||
'total_tokens': prompt + comp,
|
||||
'timestamp': timestamp
|
||||
})
|
||||
async with token_usage_lock:
|
||||
token_usage_counts[endpoint][model] += (prompt + comp)
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
print("[token_worker] Task cancelled, remaining items processed.")
|
||||
raise
|
||||
|
||||
async def flush_buffer() -> None:
|
||||
"""Periodically flush accumulated token counts to the database."""
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(FLUSH_INTERVAL)
|
||||
|
||||
# Flush token counts and time series (protected by lock)
|
||||
async with buffer_lock:
|
||||
if token_buffer:
|
||||
# Copy buffer before releasing lock for DB operation
|
||||
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
|
||||
token_buffer.clear()
|
||||
else:
|
||||
buffer_copy = None
|
||||
|
||||
if time_series_buffer:
|
||||
ts_copy = list(time_series_buffer)
|
||||
time_series_buffer.clear()
|
||||
else:
|
||||
ts_copy = None
|
||||
|
||||
# Perform DB operations outside the lock to avoid blocking
|
||||
if buffer_copy:
|
||||
await db.update_batched_counts(buffer_copy)
|
||||
if ts_copy:
|
||||
await db.add_batched_time_series(ts_copy)
|
||||
except asyncio.CancelledError:
|
||||
# Gracefully handle task cancellation during shutdown
|
||||
print("[flush_buffer] Task cancelled, flushing remaining buffers...")
|
||||
# Flush any remaining data before exiting
|
||||
try:
|
||||
async with buffer_lock:
|
||||
if token_buffer:
|
||||
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
|
||||
token_buffer.clear()
|
||||
else:
|
||||
buffer_copy = None
|
||||
if time_series_buffer:
|
||||
ts_copy = list(time_series_buffer)
|
||||
time_series_buffer.clear()
|
||||
else:
|
||||
ts_copy = None
|
||||
if buffer_copy:
|
||||
await db.update_batched_counts(buffer_copy)
|
||||
if ts_copy:
|
||||
await db.add_batched_time_series(ts_copy)
|
||||
print("[flush_buffer] Task cancelled, remaining buffers flushed.")
|
||||
except Exception as e:
|
||||
print(f"[flush_buffer] Error during shutdown flush: {e}")
|
||||
raise
|
||||
|
||||
async def flush_remaining_buffers() -> None:
|
||||
"""
|
||||
Flush any in-memory buffers to the database on shutdown.
|
||||
This is designed to be safely invoked during shutdown and should not raise.
|
||||
"""
|
||||
try:
|
||||
flushed_entries = 0
|
||||
async with buffer_lock:
|
||||
if token_buffer:
|
||||
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
|
||||
flushed_entries += sum(len(v) for v in token_buffer.values())
|
||||
token_buffer.clear()
|
||||
else:
|
||||
buffer_copy = None
|
||||
if time_series_buffer:
|
||||
ts_copy = list(time_series_buffer)
|
||||
flushed_entries += len(time_series_buffer)
|
||||
time_series_buffer.clear()
|
||||
else:
|
||||
ts_copy = None
|
||||
# Perform DB operations outside the lock
|
||||
if buffer_copy:
|
||||
await db.update_batched_counts(buffer_copy)
|
||||
if ts_copy:
|
||||
await db.add_batched_time_series(ts_copy)
|
||||
if flushed_entries:
|
||||
print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
|
||||
else:
|
||||
print("[shutdown] No in-memory entries to flush on shutdown.")
|
||||
except Exception as e:
|
||||
# Do not raise during shutdown – log and continue teardown
|
||||
print(f"[shutdown] Error flushing remaining buffers: {e}")
|
||||
from tokens import token_worker, flush_buffer, flush_remaining_buffers
|
||||
|
||||
from backends.probe import fetch
|
||||
|
||||
|
||||
async def increment_usage(endpoint: str, model: str) -> None:
|
||||
async with usage_lock:
|
||||
usage_counts[endpoint][model] += 1
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
from routing import increment_usage, decrement_usage
|
||||
|
||||
async def decrement_usage(endpoint: str, model: str) -> None:
|
||||
async with usage_lock:
|
||||
# Avoid negative counts
|
||||
current = usage_counts[endpoint].get(model, 0)
|
||||
if current > 0:
|
||||
usage_counts[endpoint][model] = current - 1
|
||||
# Optionally, clean up zero entries
|
||||
if usage_counts[endpoint].get(model, 0) == 0:
|
||||
usage_counts[endpoint].pop(model, None)
|
||||
#if not usage_counts[endpoint]:
|
||||
# usage_counts.pop(endpoint, None)
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
|
||||
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
|
||||
"""
|
||||
|
|
@ -878,287 +716,19 @@ class rechunk:
|
|||
return (prompt_n + cache_n, predicted_n)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE Helpser
|
||||
# ------------------------------------------------------------------
|
||||
def _capture_snapshot() -> str:
|
||||
"""Capture current usage counts as a JSON string. Caller must hold at least one of usage_lock/token_usage_lock."""
|
||||
return orjson.dumps({
|
||||
"usage_counts": dict(usage_counts),
|
||||
"token_usage_counts": dict(token_usage_counts)
|
||||
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||
|
||||
async def _distribute_snapshot(snapshot: str) -> None:
|
||||
"""Push a pre-captured snapshot to all SSE subscribers. Must be called outside any usage lock."""
|
||||
async with _subscribers_lock:
|
||||
for q in _subscribers:
|
||||
if q.full():
|
||||
try:
|
||||
await q.get()
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
await q.put(snapshot)
|
||||
|
||||
async def close_all_sse_queues():
|
||||
for q in list(_subscribers):
|
||||
# sentinel value that the generator will recognise
|
||||
await q.put(None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Subscriber helpers
|
||||
# ------------------------------------------------------------------
|
||||
async def subscribe() -> asyncio.Queue:
|
||||
"""
|
||||
Returns a new Queue that will receive every snapshot.
|
||||
"""
|
||||
q: asyncio.Queue = asyncio.Queue(maxsize=10)
|
||||
async with _subscribers_lock:
|
||||
_subscribers.add(q)
|
||||
return q
|
||||
|
||||
async def unsubscribe(q: asyncio.Queue):
|
||||
async with _subscribers_lock:
|
||||
_subscribers.discard(q)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Convenience wrapper – returns the current snapshot (for the proxy)
|
||||
# ------------------------------------------------------------------
|
||||
async def get_usage_counts() -> Dict:
|
||||
return dict(usage_counts) # shallow copy
|
||||
from sse import (
|
||||
_capture_snapshot,
|
||||
_distribute_snapshot,
|
||||
close_all_sse_queues,
|
||||
subscribe,
|
||||
unsubscribe,
|
||||
get_usage_counts,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 5. Endpoint selection logic (respecting the configurable limit)
|
||||
# -------------------------------------------------------------
|
||||
def get_max_connections(ep: str) -> int:
|
||||
"""Per-endpoint max_concurrent_connections, falling back to the global value."""
|
||||
return config.endpoint_config.get(ep, {}).get(
|
||||
"max_concurrent_connections", config.max_concurrent_connections
|
||||
)
|
||||
|
||||
async def choose_endpoint(model: str, reserve: bool = True,
|
||||
affinity_key: Optional[str] = None) -> tuple[str, str]:
|
||||
"""
|
||||
Determine which endpoint to use for the given model while respecting
|
||||
the `max_concurrent_connections` per endpoint‑model pair **and**
|
||||
ensuring that the chosen endpoint actually *advertises* the model.
|
||||
|
||||
The selection algorithm:
|
||||
|
||||
1️⃣ Query every endpoint for its advertised models (`/api/tags`).
|
||||
2️⃣ Build a list of endpoints that contain the requested model.
|
||||
2️⃣.5 If conversation affinity is enabled and the caller passes
|
||||
``affinity_key``, prefer the endpoint that previously served the
|
||||
same conversation — but only when it still has the model loaded
|
||||
and a free slot. Otherwise fall through to the standard logic.
|
||||
3️⃣ For those endpoints, find those that have the model loaded
|
||||
(`/api/ps`) *and* still have a free slot.
|
||||
4️⃣ If none are both loaded and free, fall back to any endpoint
|
||||
from the filtered list that simply has a free slot and randomly
|
||||
select one.
|
||||
5️⃣ If all are saturated, pick any endpoint from the filtered list
|
||||
(the request will queue on that endpoint).
|
||||
6️⃣ If no endpoint advertises the model at all, raise an error.
|
||||
"""
|
||||
# 1️⃣ Gather advertised‑model sets for all endpoints concurrently
|
||||
# Include both config.endpoints and config.llama_server_endpoints
|
||||
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
|
||||
all_endpoints = config.endpoints + llama_eps_extra
|
||||
|
||||
tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)]
|
||||
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in config.endpoints if is_openai_compatible(ep)]
|
||||
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra]
|
||||
advertised_sets = await asyncio.gather(*tag_tasks)
|
||||
|
||||
# 2️⃣ Filter endpoints that advertise the requested model
|
||||
candidate_endpoints = [
|
||||
ep for ep, models in zip(all_endpoints, advertised_sets)
|
||||
if model in models
|
||||
]
|
||||
|
||||
# 6️⃣
|
||||
if not candidate_endpoints:
|
||||
if ":latest" in model: #ollama naming convention not applicable to openai/llama-server
|
||||
model_without_latest = model.split(":latest")[0]
|
||||
candidate_endpoints = [
|
||||
ep for ep, models in zip(all_endpoints, advertised_sets)
|
||||
if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints)
|
||||
]
|
||||
if not candidate_endpoints:
|
||||
# Only add :latest suffix if model doesn't already have a version suffix
|
||||
if ":" not in model:
|
||||
model = model + ":latest"
|
||||
candidate_endpoints = [
|
||||
ep for ep, models in zip(all_endpoints, advertised_sets)
|
||||
if model in models
|
||||
]
|
||||
if not candidate_endpoints:
|
||||
raise RuntimeError(
|
||||
f"None of the configured endpoints ({', '.join(all_endpoints)}) "
|
||||
f"advertise the model '{model}'."
|
||||
)
|
||||
# 3️⃣ Among the candidates, find those that have the model *loaded*
|
||||
# (concurrently, but only for the filtered list)
|
||||
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
|
||||
loaded_sets = await asyncio.gather(*load_tasks)
|
||||
|
||||
# 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing
|
||||
# recently. Without this filter, an endpoint where `/api/ps` returns 5xx
|
||||
# would appear with an empty loaded set but pass through to the
|
||||
# free-slot fallback (step 4) — sending completion calls to an
|
||||
# unhealthy backend. See issue #83.
|
||||
async with _loaded_error_cache_lock:
|
||||
unhealthy = {
|
||||
ep for ep, ts in _loaded_error_cache.items()
|
||||
if _is_fresh(ts, 300)
|
||||
}
|
||||
if unhealthy:
|
||||
filtered = [
|
||||
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
if ep not in unhealthy
|
||||
]
|
||||
if filtered:
|
||||
candidate_endpoints = [ep for ep, _ in filtered]
|
||||
loaded_sets = [models for _, models in filtered]
|
||||
# If *every* candidate is unhealthy we still fall through with the
|
||||
# original list — refusing to route is worse than retrying a
|
||||
# possibly-recovered backend.
|
||||
|
||||
# 3️⃣.6 Exclude (endpoint, model) pairs whose completion path has recently
|
||||
# failed with a backend connection error (e.g. llama-server in router mode
|
||||
# whose delegated worker for *this* model died). /v1/models keeps reporting
|
||||
# OK in that case, so the probe-level filter above cannot catch it.
|
||||
async with _completion_error_cache_lock:
|
||||
completion_broken = {
|
||||
ep for (ep, m), ts in _completion_error_cache.items()
|
||||
if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL)
|
||||
}
|
||||
if completion_broken:
|
||||
filtered = [
|
||||
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
if ep not in completion_broken
|
||||
]
|
||||
if filtered:
|
||||
candidate_endpoints = [ep for ep, _ in filtered]
|
||||
loaded_sets = [models for _, models in filtered]
|
||||
# Same fallback: if every candidate is broken for this model, fall
|
||||
# through and let the upstream retry — possibly the operator restarted
|
||||
# the dead worker.
|
||||
|
||||
# Look up a possible affinity hint *before* taking usage_lock. The two
|
||||
# locks are never held together to avoid lock-ordering issues.
|
||||
affine_ep: Optional[str] = None
|
||||
if config.conversation_affinity and affinity_key:
|
||||
async with _affinity_lock:
|
||||
entry = _affinity_map.get(affinity_key)
|
||||
if entry is not None:
|
||||
ep, _stored_model, expires_at = entry
|
||||
if expires_at < time.monotonic():
|
||||
_affinity_map.pop(affinity_key, None)
|
||||
else:
|
||||
affine_ep = ep
|
||||
|
||||
# Protect all reads/writes of usage_counts with the lock so that selection
|
||||
# and reservation are atomic — concurrent callers see each other's pending load.
|
||||
async with usage_lock:
|
||||
# Helper: current usage for (endpoint, model) using the same normalized key
|
||||
# that increment_usage/decrement_usage store — raw model names differ from
|
||||
# tracking names for llama-server (HF prefix / quant suffix stripped).
|
||||
def tracking_usage(ep: str) -> int:
|
||||
return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)
|
||||
|
||||
def utilization_ratio(ep: str) -> float:
|
||||
return tracking_usage(ep) / get_max_connections(ep)
|
||||
|
||||
# Priority map: position in all_endpoints list (lower = higher priority)
|
||||
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
|
||||
|
||||
selected: Optional[str] = None
|
||||
|
||||
# 2️⃣.5 Conversation affinity preference — only honour the hint when
|
||||
# the affine endpoint still advertises the model loaded *and* has a
|
||||
# free slot. Otherwise fall back to the standard algorithm.
|
||||
if affine_ep:
|
||||
ep_loaded = {
|
||||
ep: set(models)
|
||||
for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
}
|
||||
if (affine_ep in candidate_endpoints
|
||||
and model in ep_loaded.get(affine_ep, set())
|
||||
and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
|
||||
selected = affine_ep
|
||||
|
||||
if selected is None:
|
||||
# 3️⃣ Endpoints that have the model loaded *and* a free slot
|
||||
loaded_and_free = [
|
||||
ep for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
if model in models and tracking_usage(ep) < get_max_connections(ep)
|
||||
]
|
||||
|
||||
if loaded_and_free:
|
||||
if config.priority_routing:
|
||||
# WRR: sort by config order first (stable), then by utilization ratio.
|
||||
# Stable sort preserves priority for equal-ratio endpoints.
|
||||
loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
|
||||
loaded_and_free.sort(key=utilization_ratio)
|
||||
selected = loaded_and_free[0]
|
||||
else:
|
||||
# Sort ascending for load balancing — all endpoints here already have the
|
||||
# model loaded, so there is no model-switching cost to optimise for.
|
||||
loaded_and_free.sort(key=tracking_usage)
|
||||
# When all candidates are equally idle, randomise to avoid always picking
|
||||
# the first entry in a stable sort.
|
||||
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
|
||||
selected = random.choice(loaded_and_free)
|
||||
else:
|
||||
selected = loaded_and_free[0]
|
||||
else:
|
||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||
endpoints_with_free_slot = [
|
||||
ep for ep in candidate_endpoints
|
||||
if tracking_usage(ep) < get_max_connections(ep)
|
||||
]
|
||||
|
||||
if endpoints_with_free_slot:
|
||||
if config.priority_routing:
|
||||
endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
|
||||
endpoints_with_free_slot.sort(key=utilization_ratio)
|
||||
selected = endpoints_with_free_slot[0]
|
||||
else:
|
||||
# Sort by total endpoint load (ascending) to prefer idle endpoints.
|
||||
endpoints_with_free_slot.sort(
|
||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||
)
|
||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||
selected = random.choice(endpoints_with_free_slot)
|
||||
else:
|
||||
selected = endpoints_with_free_slot[0]
|
||||
else:
|
||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||
if config.priority_routing:
|
||||
selected = min(
|
||||
candidate_endpoints,
|
||||
key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
|
||||
)
|
||||
else:
|
||||
selected = min(candidate_endpoints, key=tracking_usage)
|
||||
|
||||
tracking_model = get_tracking_model(selected, model)
|
||||
snapshot = None
|
||||
if reserve:
|
||||
usage_counts[selected][tracking_model] += 1
|
||||
snapshot = _capture_snapshot()
|
||||
if snapshot is not None:
|
||||
await _distribute_snapshot(snapshot)
|
||||
# Record / refresh affinity *after* releasing usage_lock.
|
||||
if reserve and config.conversation_affinity and affinity_key:
|
||||
expires_at = time.monotonic() + config.conversation_affinity_ttl
|
||||
async with _affinity_lock:
|
||||
_affinity_map[affinity_key] = (selected, model, expires_at)
|
||||
if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
|
||||
now = time.monotonic()
|
||||
for k in [k for k, v in _affinity_map.items() if v[2] < now]:
|
||||
_affinity_map.pop(k, None)
|
||||
return selected, tracking_model
|
||||
from routing import get_max_connections, choose_endpoint
|
||||
|
||||
# -------------------------------------------------------------
|
||||
# 6. API route – Generate
|
||||
|
|
|
|||
294
routing.py
Normal file
294
routing.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
"""Endpoint selection (load-balancing + conversation affinity).
|
||||
|
||||
``choose_endpoint`` is the heart of routing — it picks an endpoint that
|
||||
advertises the model, prefers ones with the model already loaded and a free
|
||||
slot, applies the conversation-affinity hint when available, and honors
|
||||
config-order priority routing when ``priority_routing`` is set.
|
||||
|
||||
``increment_usage`` / ``decrement_usage`` keep the per-(endpoint, model)
|
||||
counter that drives utilization-based selection; they fan out an SSE
|
||||
snapshot on every change.
|
||||
"""
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from config import get_config
|
||||
from state import (
|
||||
usage_counts,
|
||||
usage_lock,
|
||||
_loaded_error_cache,
|
||||
_loaded_error_cache_lock,
|
||||
_completion_error_cache,
|
||||
_completion_error_cache_lock,
|
||||
_COMPLETION_ERROR_TTL,
|
||||
_affinity_map,
|
||||
_affinity_lock,
|
||||
_AFFINITY_MAX_ENTRIES,
|
||||
)
|
||||
from sse import _capture_snapshot, _distribute_snapshot
|
||||
from backends.health import _is_fresh
|
||||
from backends.normalize import (
|
||||
is_ext_openai_endpoint,
|
||||
is_openai_compatible,
|
||||
get_tracking_model,
|
||||
)
|
||||
from backends.probe import fetch
|
||||
|
||||
|
||||
async def increment_usage(endpoint: str, model: str) -> None:
|
||||
async with usage_lock:
|
||||
usage_counts[endpoint][model] += 1
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
|
||||
|
||||
async def decrement_usage(endpoint: str, model: str) -> None:
|
||||
async with usage_lock:
|
||||
# Avoid negative counts
|
||||
current = usage_counts[endpoint].get(model, 0)
|
||||
if current > 0:
|
||||
usage_counts[endpoint][model] = current - 1
|
||||
# Optionally, clean up zero entries
|
||||
if usage_counts[endpoint].get(model, 0) == 0:
|
||||
usage_counts[endpoint].pop(model, None)
|
||||
#if not usage_counts[endpoint]:
|
||||
# usage_counts.pop(endpoint, None)
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
|
||||
|
||||
def get_max_connections(ep: str) -> int:
|
||||
"""Per-endpoint max_concurrent_connections, falling back to the global value."""
|
||||
cfg = get_config()
|
||||
return cfg.endpoint_config.get(ep, {}).get(
|
||||
"max_concurrent_connections", cfg.max_concurrent_connections
|
||||
)
|
||||
|
||||
|
||||
async def choose_endpoint(model: str, reserve: bool = True,
|
||||
affinity_key: Optional[str] = None) -> tuple[str, str]:
|
||||
"""
|
||||
Determine which endpoint to use for the given model while respecting
|
||||
the `max_concurrent_connections` per endpoint‑model pair **and**
|
||||
ensuring that the chosen endpoint actually *advertises* the model.
|
||||
|
||||
The selection algorithm:
|
||||
|
||||
1️⃣ Query every endpoint for its advertised models (`/api/tags`).
|
||||
2️⃣ Build a list of endpoints that contain the requested model.
|
||||
2️⃣.5 If conversation affinity is enabled and the caller passes
|
||||
``affinity_key``, prefer the endpoint that previously served the
|
||||
same conversation — but only when it still has the model loaded
|
||||
and a free slot. Otherwise fall through to the standard logic.
|
||||
3️⃣ For those endpoints, find those that have the model loaded
|
||||
(`/api/ps`) *and* still have a free slot.
|
||||
4️⃣ If none are both loaded and free, fall back to any endpoint
|
||||
from the filtered list that simply has a free slot and randomly
|
||||
select one.
|
||||
5️⃣ If all are saturated, pick any endpoint from the filtered list
|
||||
(the request will queue on that endpoint).
|
||||
6️⃣ If no endpoint advertises the model at all, raise an error.
|
||||
"""
|
||||
config = get_config()
|
||||
# 1️⃣ Gather advertised‑model sets for all endpoints concurrently
|
||||
# Include both config.endpoints and config.llama_server_endpoints
|
||||
llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
|
||||
all_endpoints = config.endpoints + llama_eps_extra
|
||||
|
||||
tag_tasks = [fetch.available_models(ep) for ep in config.endpoints if not is_openai_compatible(ep)]
|
||||
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in config.endpoints if is_openai_compatible(ep)]
|
||||
tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra]
|
||||
advertised_sets = await asyncio.gather(*tag_tasks)
|
||||
|
||||
# 2️⃣ Filter endpoints that advertise the requested model
|
||||
candidate_endpoints = [
|
||||
ep for ep, models in zip(all_endpoints, advertised_sets)
|
||||
if model in models
|
||||
]
|
||||
|
||||
# 6️⃣
|
||||
if not candidate_endpoints:
|
||||
if ":latest" in model: #ollama naming convention not applicable to openai/llama-server
|
||||
model_without_latest = model.split(":latest")[0]
|
||||
candidate_endpoints = [
|
||||
ep for ep, models in zip(all_endpoints, advertised_sets)
|
||||
if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints)
|
||||
]
|
||||
if not candidate_endpoints:
|
||||
# Only add :latest suffix if model doesn't already have a version suffix
|
||||
if ":" not in model:
|
||||
model = model + ":latest"
|
||||
candidate_endpoints = [
|
||||
ep for ep, models in zip(all_endpoints, advertised_sets)
|
||||
if model in models
|
||||
]
|
||||
if not candidate_endpoints:
|
||||
raise RuntimeError(
|
||||
f"None of the configured endpoints ({', '.join(all_endpoints)}) "
|
||||
f"advertise the model '{model}'."
|
||||
)
|
||||
# 3️⃣ Among the candidates, find those that have the model *loaded*
|
||||
# (concurrently, but only for the filtered list)
|
||||
load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
|
||||
loaded_sets = await asyncio.gather(*load_tasks)
|
||||
|
||||
# 3️⃣.5 Exclude endpoints whose loaded-model probe has been failing
|
||||
# recently. Without this filter, an endpoint where `/api/ps` returns 5xx
|
||||
# would appear with an empty loaded set but pass through to the
|
||||
# free-slot fallback (step 4) — sending completion calls to an
|
||||
# unhealthy backend. See issue #83.
|
||||
async with _loaded_error_cache_lock:
|
||||
unhealthy = {
|
||||
ep for ep, ts in _loaded_error_cache.items()
|
||||
if _is_fresh(ts, 300)
|
||||
}
|
||||
if unhealthy:
|
||||
filtered = [
|
||||
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
if ep not in unhealthy
|
||||
]
|
||||
if filtered:
|
||||
candidate_endpoints = [ep for ep, _ in filtered]
|
||||
loaded_sets = [models for _, models in filtered]
|
||||
# If *every* candidate is unhealthy we still fall through with the
|
||||
# original list — refusing to route is worse than retrying a
|
||||
# possibly-recovered backend.
|
||||
|
||||
# 3️⃣.6 Exclude (endpoint, model) pairs whose completion path has recently
|
||||
# failed with a backend connection error (e.g. llama-server in router mode
|
||||
# whose delegated worker for *this* model died). /v1/models keeps reporting
|
||||
# OK in that case, so the probe-level filter above cannot catch it.
|
||||
async with _completion_error_cache_lock:
|
||||
completion_broken = {
|
||||
ep for (ep, m), ts in _completion_error_cache.items()
|
||||
if m == model and _is_fresh(ts, _COMPLETION_ERROR_TTL)
|
||||
}
|
||||
if completion_broken:
|
||||
filtered = [
|
||||
(ep, models) for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
if ep not in completion_broken
|
||||
]
|
||||
if filtered:
|
||||
candidate_endpoints = [ep for ep, _ in filtered]
|
||||
loaded_sets = [models for _, models in filtered]
|
||||
# Same fallback: if every candidate is broken for this model, fall
|
||||
# through and let the upstream retry — possibly the operator restarted
|
||||
# the dead worker.
|
||||
|
||||
# Look up a possible affinity hint *before* taking usage_lock. The two
|
||||
# locks are never held together to avoid lock-ordering issues.
|
||||
affine_ep: Optional[str] = None
|
||||
if config.conversation_affinity and affinity_key:
|
||||
async with _affinity_lock:
|
||||
entry = _affinity_map.get(affinity_key)
|
||||
if entry is not None:
|
||||
ep, _stored_model, expires_at = entry
|
||||
if expires_at < time.monotonic():
|
||||
_affinity_map.pop(affinity_key, None)
|
||||
else:
|
||||
affine_ep = ep
|
||||
|
||||
# Protect all reads/writes of usage_counts with the lock so that selection
|
||||
# and reservation are atomic — concurrent callers see each other's pending load.
|
||||
async with usage_lock:
|
||||
# Helper: current usage for (endpoint, model) using the same normalized key
|
||||
# that increment_usage/decrement_usage store — raw model names differ from
|
||||
# tracking names for llama-server (HF prefix / quant suffix stripped).
|
||||
def tracking_usage(ep: str) -> int:
|
||||
return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)
|
||||
|
||||
def utilization_ratio(ep: str) -> float:
|
||||
return tracking_usage(ep) / get_max_connections(ep)
|
||||
|
||||
# Priority map: position in all_endpoints list (lower = higher priority)
|
||||
ep_priority = {ep: i for i, ep in enumerate(all_endpoints)}
|
||||
|
||||
selected: Optional[str] = None
|
||||
|
||||
# 2️⃣.5 Conversation affinity preference — only honour the hint when
|
||||
# the affine endpoint still advertises the model loaded *and* has a
|
||||
# free slot. Otherwise fall back to the standard algorithm.
|
||||
if affine_ep:
|
||||
ep_loaded = {
|
||||
ep: set(models)
|
||||
for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
}
|
||||
if (affine_ep in candidate_endpoints
|
||||
and model in ep_loaded.get(affine_ep, set())
|
||||
and tracking_usage(affine_ep) < get_max_connections(affine_ep)):
|
||||
selected = affine_ep
|
||||
|
||||
if selected is None:
|
||||
# 3️⃣ Endpoints that have the model loaded *and* a free slot
|
||||
loaded_and_free = [
|
||||
ep for ep, models in zip(candidate_endpoints, loaded_sets)
|
||||
if model in models and tracking_usage(ep) < get_max_connections(ep)
|
||||
]
|
||||
|
||||
if loaded_and_free:
|
||||
if config.priority_routing:
|
||||
# WRR: sort by config order first (stable), then by utilization ratio.
|
||||
# Stable sort preserves priority for equal-ratio endpoints.
|
||||
loaded_and_free.sort(key=lambda ep: ep_priority.get(ep, 999))
|
||||
loaded_and_free.sort(key=utilization_ratio)
|
||||
selected = loaded_and_free[0]
|
||||
else:
|
||||
# Sort ascending for load balancing — all endpoints here already have the
|
||||
# model loaded, so there is no model-switching cost to optimise for.
|
||||
loaded_and_free.sort(key=tracking_usage)
|
||||
# When all candidates are equally idle, randomise to avoid always picking
|
||||
# the first entry in a stable sort.
|
||||
if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
|
||||
selected = random.choice(loaded_and_free)
|
||||
else:
|
||||
selected = loaded_and_free[0]
|
||||
else:
|
||||
# 4️⃣ Endpoints among the candidates that simply have a free slot
|
||||
endpoints_with_free_slot = [
|
||||
ep for ep in candidate_endpoints
|
||||
if tracking_usage(ep) < get_max_connections(ep)
|
||||
]
|
||||
|
||||
if endpoints_with_free_slot:
|
||||
if config.priority_routing:
|
||||
endpoints_with_free_slot.sort(key=lambda ep: ep_priority.get(ep, 999))
|
||||
endpoints_with_free_slot.sort(key=utilization_ratio)
|
||||
selected = endpoints_with_free_slot[0]
|
||||
else:
|
||||
# Sort by total endpoint load (ascending) to prefer idle endpoints.
|
||||
endpoints_with_free_slot.sort(
|
||||
key=lambda ep: sum(usage_counts.get(ep, {}).values())
|
||||
)
|
||||
if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
|
||||
selected = random.choice(endpoints_with_free_slot)
|
||||
else:
|
||||
selected = endpoints_with_free_slot[0]
|
||||
else:
|
||||
# 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
|
||||
if config.priority_routing:
|
||||
selected = min(
|
||||
candidate_endpoints,
|
||||
key=lambda ep: (utilization_ratio(ep), ep_priority.get(ep, 999)),
|
||||
)
|
||||
else:
|
||||
selected = min(candidate_endpoints, key=tracking_usage)
|
||||
|
||||
tracking_model = get_tracking_model(selected, model)
|
||||
snapshot = None
|
||||
if reserve:
|
||||
usage_counts[selected][tracking_model] += 1
|
||||
snapshot = _capture_snapshot()
|
||||
if snapshot is not None:
|
||||
await _distribute_snapshot(snapshot)
|
||||
# Record / refresh affinity *after* releasing usage_lock.
|
||||
if reserve and config.conversation_affinity and affinity_key:
|
||||
expires_at = time.monotonic() + config.conversation_affinity_ttl
|
||||
async with _affinity_lock:
|
||||
_affinity_map[affinity_key] = (selected, model, expires_at)
|
||||
if len(_affinity_map) > _AFFINITY_MAX_ENTRIES:
|
||||
now = time.monotonic()
|
||||
for k in [k for k, v in _affinity_map.items() if v[2] < now]:
|
||||
_affinity_map.pop(k, None)
|
||||
return selected, tracking_model
|
||||
62
sse.py
Normal file
62
sse.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
"""Server-sent-events plumbing.
|
||||
|
||||
Captures the current ``usage_counts`` / ``token_usage_counts`` snapshot and
|
||||
fan-outs it to every subscribed asyncio.Queue. Routes that need a live
|
||||
dashboard feed call ``subscribe`` / ``unsubscribe`` to obtain a queue.
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
import orjson
|
||||
|
||||
from state import (
|
||||
usage_counts,
|
||||
token_usage_counts,
|
||||
_subscribers,
|
||||
_subscribers_lock,
|
||||
)
|
||||
|
||||
|
||||
def _capture_snapshot() -> str:
|
||||
"""Capture current usage counts as a JSON string. Caller must hold at least one of usage_lock/token_usage_lock."""
|
||||
return orjson.dumps({
|
||||
"usage_counts": dict(usage_counts),
|
||||
"token_usage_counts": dict(token_usage_counts)
|
||||
}, option=orjson.OPT_SORT_KEYS).decode("utf-8")
|
||||
|
||||
|
||||
async def _distribute_snapshot(snapshot: str) -> None:
|
||||
"""Push a pre-captured snapshot to all SSE subscribers. Must be called outside any usage lock."""
|
||||
async with _subscribers_lock:
|
||||
for q in _subscribers:
|
||||
if q.full():
|
||||
try:
|
||||
await q.get()
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
await q.put(snapshot)
|
||||
|
||||
|
||||
async def close_all_sse_queues():
|
||||
for q in list(_subscribers):
|
||||
# sentinel value that the generator will recognise
|
||||
await q.put(None)
|
||||
|
||||
|
||||
async def subscribe() -> asyncio.Queue:
|
||||
"""
|
||||
Returns a new Queue that will receive every snapshot.
|
||||
"""
|
||||
q: asyncio.Queue = asyncio.Queue(maxsize=10)
|
||||
async with _subscribers_lock:
|
||||
_subscribers.add(q)
|
||||
return q
|
||||
|
||||
|
||||
async def unsubscribe(q: asyncio.Queue):
|
||||
async with _subscribers_lock:
|
||||
_subscribers.discard(q)
|
||||
|
||||
|
||||
async def get_usage_counts() -> Dict:
|
||||
return dict(usage_counts) # shallow copy
|
||||
177
tokens.py
Normal file
177
tokens.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""Token-count write-behind pipeline.
|
||||
|
||||
``token_worker`` drains ``token_queue`` into the in-memory buffer (and into
|
||||
``token_usage_counts`` for immediate SSE reporting). ``flush_buffer``
|
||||
periodically persists the buffer to SQLite via ``TokenDatabase``.
|
||||
``flush_remaining_buffers`` is invoked on shutdown to drain whatever is left.
|
||||
|
||||
The lock order is ``buffer_lock`` then ``token_usage_lock`` — see
|
||||
choose_endpoint for why we never combine them with usage_lock.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from state import (
|
||||
token_queue,
|
||||
token_buffer,
|
||||
time_series_buffer,
|
||||
buffer_lock,
|
||||
token_usage_counts,
|
||||
token_usage_lock,
|
||||
FLUSH_INTERVAL,
|
||||
)
|
||||
from sse import _capture_snapshot, _distribute_snapshot
|
||||
from db import get_db
|
||||
|
||||
|
||||
async def token_worker() -> None:
|
||||
try:
|
||||
while True:
|
||||
endpoint, model, prompt, comp = await token_queue.get()
|
||||
# Calculate timestamp once before acquiring lock
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
|
||||
|
||||
# Accumulate counts in memory buffer (protected by lock)
|
||||
async with buffer_lock:
|
||||
token_buffer[endpoint][model] = (
|
||||
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
|
||||
token_buffer[endpoint].get(model, (0, 0))[1] + comp
|
||||
)
|
||||
|
||||
# Add to time series buffer with timestamp (UTC)
|
||||
time_series_buffer.append({
|
||||
'endpoint': endpoint,
|
||||
'model': model,
|
||||
'input_tokens': prompt,
|
||||
'output_tokens': comp,
|
||||
'total_tokens': prompt + comp,
|
||||
'timestamp': timestamp
|
||||
})
|
||||
|
||||
# Update in-memory counts for immediate reporting
|
||||
async with token_usage_lock:
|
||||
token_usage_counts[endpoint][model] += (prompt + comp)
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
except asyncio.CancelledError:
|
||||
# Gracefully handle task cancellation during shutdown
|
||||
print("[token_worker] Task cancelled, processing remaining queue items...")
|
||||
# Process any remaining items in the queue before exiting
|
||||
while not token_queue.empty():
|
||||
try:
|
||||
endpoint, model, prompt, comp = token_queue.get_nowait()
|
||||
# Calculate timestamp once before acquiring lock
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())
|
||||
|
||||
async with buffer_lock:
|
||||
token_buffer[endpoint][model] = (
|
||||
token_buffer[endpoint].get(model, (0, 0))[0] + prompt,
|
||||
token_buffer[endpoint].get(model, (0, 0))[1] + comp
|
||||
)
|
||||
time_series_buffer.append({
|
||||
'endpoint': endpoint,
|
||||
'model': model,
|
||||
'input_tokens': prompt,
|
||||
'output_tokens': comp,
|
||||
'total_tokens': prompt + comp,
|
||||
'timestamp': timestamp
|
||||
})
|
||||
async with token_usage_lock:
|
||||
token_usage_counts[endpoint][model] += (prompt + comp)
|
||||
snapshot = _capture_snapshot()
|
||||
await _distribute_snapshot(snapshot)
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
print("[token_worker] Task cancelled, remaining items processed.")
|
||||
raise
|
||||
|
||||
|
||||
async def flush_buffer() -> None:
|
||||
"""Periodically flush accumulated token counts to the database."""
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(FLUSH_INTERVAL)
|
||||
|
||||
# Flush token counts and time series (protected by lock)
|
||||
async with buffer_lock:
|
||||
if token_buffer:
|
||||
# Copy buffer before releasing lock for DB operation
|
||||
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
|
||||
token_buffer.clear()
|
||||
else:
|
||||
buffer_copy = None
|
||||
|
||||
if time_series_buffer:
|
||||
ts_copy = list(time_series_buffer)
|
||||
time_series_buffer.clear()
|
||||
else:
|
||||
ts_copy = None
|
||||
|
||||
# Perform DB operations outside the lock to avoid blocking
|
||||
db = get_db()
|
||||
if buffer_copy:
|
||||
await db.update_batched_counts(buffer_copy)
|
||||
if ts_copy:
|
||||
await db.add_batched_time_series(ts_copy)
|
||||
except asyncio.CancelledError:
|
||||
# Gracefully handle task cancellation during shutdown
|
||||
print("[flush_buffer] Task cancelled, flushing remaining buffers...")
|
||||
# Flush any remaining data before exiting
|
||||
try:
|
||||
async with buffer_lock:
|
||||
if token_buffer:
|
||||
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
|
||||
token_buffer.clear()
|
||||
else:
|
||||
buffer_copy = None
|
||||
if time_series_buffer:
|
||||
ts_copy = list(time_series_buffer)
|
||||
time_series_buffer.clear()
|
||||
else:
|
||||
ts_copy = None
|
||||
db = get_db()
|
||||
if buffer_copy:
|
||||
await db.update_batched_counts(buffer_copy)
|
||||
if ts_copy:
|
||||
await db.add_batched_time_series(ts_copy)
|
||||
print("[flush_buffer] Task cancelled, remaining buffers flushed.")
|
||||
except Exception as e:
|
||||
print(f"[flush_buffer] Error during shutdown flush: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def flush_remaining_buffers() -> None:
|
||||
"""
|
||||
Flush any in-memory buffers to the database on shutdown.
|
||||
This is designed to be safely invoked during shutdown and should not raise.
|
||||
"""
|
||||
try:
|
||||
flushed_entries = 0
|
||||
async with buffer_lock:
|
||||
if token_buffer:
|
||||
buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
|
||||
flushed_entries += sum(len(v) for v in token_buffer.values())
|
||||
token_buffer.clear()
|
||||
else:
|
||||
buffer_copy = None
|
||||
if time_series_buffer:
|
||||
ts_copy = list(time_series_buffer)
|
||||
flushed_entries += len(time_series_buffer)
|
||||
time_series_buffer.clear()
|
||||
else:
|
||||
ts_copy = None
|
||||
# Perform DB operations outside the lock
|
||||
db = get_db()
|
||||
if buffer_copy:
|
||||
await db.update_batched_counts(buffer_copy)
|
||||
if ts_copy:
|
||||
await db.add_batched_time_series(ts_copy)
|
||||
if flushed_entries:
|
||||
print(f"[shutdown] Flushed {flushed_entries} in-memory entries to DB on shutdown.")
|
||||
else:
|
||||
print("[shutdown] No in-memory entries to flush on shutdown.")
|
||||
except Exception as e:
|
||||
# Do not raise during shutdown – log and continue teardown
|
||||
print(f"[shutdown] Error flushing remaining buffers: {e}")
|
||||
Loading…
Add table
Add a link
Reference in a new issue