279 lines
9.3 KiB
Python
279 lines
9.3 KiB
Python
|
|
"""Management / observability routes.

Endpoints used by the dashboard and external monitoring:

* usage counters, token-count breakdown, and per-model token statistics,
* aggregation of stored time-series token data,
* conversation-affinity introspection,
* endpoint configuration and health summary,
* LLM-response cache stats and invalidation,
* SSE live-stream of usage updates,
* hostname and ``/health`` probe.

Most routes are read-only. ``/api/cache/invalidate`` and
``/api/aggregate_time_series_days`` modify state, and
``/api/affinity_stats`` evicts expired affinity pins as a side effect.
"""
|
|||
|
|
import asyncio
|
|||
|
|
import socket
|
|||
|
|
import time
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
import orjson
|
|||
|
|
from fastapi import APIRouter, HTTPException, Request
|
|||
|
|
from starlette.responses import JSONResponse, StreamingResponse
|
|||
|
|
|
|||
|
|
from cache import get_llm_cache
|
|||
|
|
from config import get_config
|
|||
|
|
from db import get_db
|
|||
|
|
from state import (
|
|||
|
|
usage_counts,
|
|||
|
|
token_usage_counts,
|
|||
|
|
_affinity_map,
|
|||
|
|
_affinity_lock,
|
|||
|
|
)
|
|||
|
|
from sse import subscribe, unsubscribe
|
|||
|
|
from backends.normalize import _normalize_llama_model_name
|
|||
|
|
from backends.probe import _endpoint_health
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Router collecting the management / observability endpoints defined below.
# NOTE(review): presumably mounted via app.include_router(...) elsewhere —
# confirm at the call site.
router = APIRouter()
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/api/token_counts")
async def token_counts_proxy():
    """Return the grand token total plus a per-row token breakdown.

    Each row yielded by the database's ``load_token_counts`` async iterator is
    re-emitted with only its endpoint, model and input/output/total token
    fields. The grand total is the sum of the rows' ``total_tokens`` values.
    """
    wanted = ("endpoint", "model", "input_tokens", "output_tokens", "total_tokens")
    rows = [
        {key: record[key] for key in wanted}
        async for record in get_db().load_token_counts()
    ]
    grand_total = sum(row["total_tokens"] for row in rows)
    return {"total_tokens": grand_total, "breakdown": rows}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Fallback window (in days) used when the request does not supply a usable value.
_AGGREGATE_DEFAULT_DAYS = 30
# Case-insensitive string spellings accepted for boolean flags in JSON bodies.
_TRUTHY_FLAG_STRINGS = frozenset({"1", "true", "yes", "on"})
_FALSY_FLAG_STRINGS = frozenset({"", "0", "false", "no", "off"})


def _parse_bool_flag(value: object) -> bool:
    """Interpret a decoded JSON value as a boolean flag.

    Booleans, numbers and ``None`` use normal truthiness. Strings are matched
    case-insensitively against common spellings, so ``"false"`` / ``"0"``
    yield False. Plain ``bool("false")`` would be True.

    Raises:
        ValueError: for a string that is not a recognised boolean spelling.
    """
    if isinstance(value, str):
        token = value.strip().lower()
        if token in _TRUTHY_FLAG_STRINGS:
            return True
        if token in _FALSY_FLAG_STRINGS:
            return False
        raise ValueError(f"unrecognised boolean flag: {value!r}")
    return bool(value)


def _parse_aggregate_params(body_bytes: bytes) -> tuple[int, bool]:
    """Extract ``(days, trim_old)`` from an optional JSON request body.

    Any of the following falls back to the conservative defaults
    ``(30, False)``, so malformed input can never switch trimming on:
    an empty body, malformed JSON, a payload that is not a JSON object,
    or field values that cannot be coerced.
    """
    defaults = (_AGGREGATE_DEFAULT_DAYS, False)
    if not body_bytes:
        return defaults
    try:
        # orjson accepts bytes directly and reports invalid UTF-8 as
        # orjson.JSONDecodeError, which is a ValueError subclass.
        payload = orjson.loads(body_bytes)
        if not isinstance(payload, dict):
            return defaults
        days = int(payload.get("days", _AGGREGATE_DEFAULT_DAYS))
        trim_old = _parse_bool_flag(payload.get("trim_old", False))
    except (ValueError, TypeError, OverflowError):
        return defaults
    return days, trim_old


@router.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
    """
    Aggregate time_series entries older than ``days`` days into daily
    aggregates by endpoint/model/date.

    Optional JSON body: ``{"days": <int>, "trim_old": <bool>}``. The defaults
    are 30 and false, and any unparseable body also falls back to them.
    ``trim_old`` accepts real booleans or strings such as ``"true"`` /
    ``"false"``. Both values are forwarded to the database's
    ``aggregate_time_series_older_than`` and echoed back in the response.
    """
    # The body read sits outside the parsing fallback on purpose. A read
    # failure (e.g. client disconnect) propagates instead of silently running
    # the aggregation job with default parameters.
    body_bytes = await request.body()
    days, trim_old = _parse_aggregate_params(body_bytes)
    aggregated = await get_db().aggregate_time_series_older_than(days, trim_old=trim_old)
    return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
    """
    Return token usage statistics for a specific model.

    The model name comes from the ``model`` query parameter when present,
    otherwise from the ``"model"`` field of a JSON-object request body.

    Returns a dict containing the model name, its input/output/total token
    counts, its time series, and its endpoint distribution, all read from
    the database.

    Raises:
        HTTPException(400): the body is not valid JSON (including invalid
            UTF-8), is not a JSON object, or no non-empty string model name
            was supplied.
        HTTPException(404): the database has no token data for the model.
    """
    if not model:
        body_bytes = await request.body()
        # An empty body simply means "no model given". It is reported as
        # a missing field below rather than as a JSON syntax error.
        if body_bytes:
            try:
                # Pass bytes straight to orjson. Invalid UTF-8 then surfaces as
                # JSONDecodeError instead of an uncaught UnicodeDecodeError (500).
                payload = orjson.loads(body_bytes)
            except orjson.JSONDecodeError as e:
                raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
            if not isinstance(payload, dict):
                raise HTTPException(
                    status_code=400, detail="Request body must be a JSON object"
                )
            model = payload.get("model")

    if not model:
        raise HTTPException(
            status_code=400, detail="Missing required field 'model'"
        )
    if not isinstance(model, str):
        raise HTTPException(
            status_code=400, detail="Field 'model' must be a string"
        )

    db = get_db()
    token_data = await db.get_token_counts_for_model(model)

    if not token_data:
        raise HTTPException(
            status_code=404, detail="No token data found for this model"
        )

    time_series = [
        entry async for entry in db.get_time_series_for_model(model)
    ]
    endpoint_distribution = await db.get_endpoint_distribution_for_model(model)

    return {
        'model': model,
        'input_tokens': token_data['input_tokens'],
        'output_tokens': token_data['output_tokens'],
        'total_tokens': token_data['total_tokens'],
        'time_series': time_series,
        'endpoint_distribution': endpoint_distribution,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/api/affinity_stats")
async def affinity_stats(request: Request):
    """
    Report live conversation-affinity pins, one entry per pinned conversation.

    Each entry exposes only the endpoint, the model, and the remaining TTL in
    seconds (rounded to 2 decimals). Fingerprints and content are never
    exposed. Pins found to be expired during the scan are evicted from the
    shared affinity map. When conversation_affinity is disabled, ``entries``
    is always an empty list.
    """
    cfg = get_config()
    if not cfg.conversation_affinity:
        return {"enabled": False, "ttl": cfg.conversation_affinity_ttl, "entries": []}

    current = time.monotonic()
    live_pins: list[dict] = []
    llama_endpoints = set(cfg.llama_server_endpoints)

    async with _affinity_lock:
        # Scan a snapshot, because expired pins are popped from the map mid-scan.
        snapshot = list(_affinity_map.items())
        for fingerprint, (endpoint, model_name, expiry) in snapshot:
            seconds_left = expiry - current
            if seconds_left <= 0:
                _affinity_map.pop(fingerprint, None)
                continue
            # Same normalisation as /api/ps_details, so the dashboard can join
            # affinity pins to PS rows on (endpoint, model).
            if endpoint in llama_endpoints:
                shown_model = _normalize_llama_model_name(model_name)
            else:
                shown_model = model_name
            live_pins.append(
                {
                    "endpoint": endpoint,
                    "model": shown_model,
                    "remaining": round(seconds_left, 2),
                }
            )

    return {
        "enabled": True,
        "ttl": cfg.conversation_affinity_ttl,
        "entries": live_pins,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/api/usage")
async def usage_proxy(request: Request):
    """
    Expose the in-memory usage counters for debugging / monitoring.

    Both shared counter objects from the ``state`` module are placed in the
    response as-is: the per-endpoint usage counts and the token usage counts.
    """
    snapshot = dict(
        usage_counts=usage_counts,
        token_usage_counts=token_usage_counts,
    )
    return snapshot
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/api/config")
async def config_proxy(request: Request):
    """
    Describe the configured backends together with their current health.

    The result lists the configured Ollama endpoints and llama_server_endpoints,
    each as ``{"url": ..., **health}``, plus whether a router API key is
    required. The front-end uses this to show which endpoints are being
    proxied and whether they are healthy. Status is "error" when either
    liveness (/api/version) or routing health (/api/ps) fails — see issue #83.
    """
    cfg = get_config()

    async def probe(url: str) -> dict:
        # Health fields are merged after "url", as in a plain dict literal.
        health = await _endpoint_health(url, timeout=5)
        return {"url": url, **health}

    # The two groups are probed one after the other; endpoints within a
    # group are probed concurrently.
    ollama_results = await asyncio.gather(*(probe(url) for url in cfg.endpoints))
    llama_results = (
        await asyncio.gather(*(probe(url) for url in cfg.llama_server_endpoints))
        if cfg.llama_server_endpoints
        else []
    )

    return {
        "endpoints": ollama_results,
        "llama_server_endpoints": llama_results,
        "require_router_api_key": bool(cfg.router_api_key),
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/api/cache/stats")
async def cache_stats():
    """Report whether the LLM response cache is enabled and, if so, its stats()."""
    llm_cache = get_llm_cache()
    if llm_cache is not None:
        return {"enabled": True, **llm_cache.stats()}
    return {"enabled": False}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/api/cache/invalidate")
async def cache_invalidate():
    """Empty the LLM response cache via its clear() method, when caching is on.

    The response reports whether a cache exists and whether it was cleared;
    both flags are False when caching is disabled.
    """
    llm_cache = get_llm_cache()
    enabled = llm_cache is not None
    if enabled:
        await llm_cache.clear()
    return {"enabled": enabled, "cleared": enabled}
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/health")
async def health_proxy(request: Request):
    """
    Health-check endpoint for monitoring the proxy.

    * Probes every configured endpoint concurrently via ``_endpoint_health``.
      Per the design notes, Ollama endpoints are checked at ``/api/version``
      AND ``/api/ps``, and OpenAI-compatible endpoints at ``/models``.
    * llama_server_endpoints that also appear in ``endpoints`` are probed once.
    * Returns a JSON object with:
        - ``status``: "ok" if every probe result has status "ok", else "error".
        - ``endpoints``: endpoint URL → that endpoint's probe result.
    * The HTTP status code is 200 when everything is healthy, 503 otherwise.
    """
    cfg = get_config()

    # Probing /api/version alone would miss an Ollama process that is up while
    # /api/ps (needed by `choose_endpoint`) is failing — see issue #83.
    targets = list(cfg.endpoints)
    targets.extend(ep for ep in cfg.llama_server_endpoints if ep not in cfg.endpoints)

    results = await asyncio.gather(*map(_endpoint_health, targets))
    healthy = all(result.get("status") == "ok" for result in results)

    payload = {
        "status": "ok" if healthy else "error",
        "endpoints": dict(zip(targets, results)),
    }
    return JSONResponse(content=payload, status_code=200 if healthy else 503)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/api/hostname")
async def get_hostname():
    """Report the host name of the machine serving this router."""
    name = socket.gethostname()
    return JSONResponse(content={"hostname": name})
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/api/usage-stream")
async def usage_stream(request: Request):
    """
    Server-Sent Events stream of usage updates.

    Every item published to this client's ``sse`` subscription queue is
    forwarded as one ``data:`` message. A ``None`` item, or a client
    disconnect detected before the next wait, ends the stream. The
    subscription is always released on exit.
    NOTE(review): payloads are interpolated verbatim, so each one is assumed
    to be a single line (e.g. compact JSON) — confirm in sse.py.
    """
    async def event_generator():
        queue = await subscribe()
        try:
            while not await request.is_disconnected():
                data = await queue.get()
                if data is None:
                    break
                yield f"data: {data}\n\n"
        finally:
            await unsubscribe(queue)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|