nomyo-router/api/management.py

278 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Management / observability routes.
Endpoints used by the dashboard and external monitoring (mostly read-only):
* usage counters and token-counts breakdown,
* conversation-affinity introspection,
* endpoint health summary,
* LLM-response cache stats and invalidation,
* SSE live-stream of usage updates,
* hostname and ``/health`` probe.
"""
import asyncio
import socket
import time
from typing import Optional
import orjson
from fastapi import APIRouter, HTTPException, Request
from starlette.responses import JSONResponse, StreamingResponse
from cache import get_llm_cache
from config import get_config
from db import get_db
from state import (
usage_counts,
token_usage_counts,
_affinity_map,
_affinity_lock,
)
from sse import subscribe, unsubscribe
from backends.normalize import _normalize_llama_model_name
from backends.probe import _endpoint_health
# Router on which every management / observability route in this module is registered.
router = APIRouter()
@router.get("/api/token_counts")
async def token_counts_proxy():
    """Return the overall token total together with a per-row breakdown.

    Every row yielded by ``get_db().load_token_counts()`` becomes one entry of
    ``breakdown`` (endpoint, model, input/output/total tokens); the
    ``total_tokens`` values of all rows are summed into the top-level total.
    """
    rows = [
        {
            "endpoint": row["endpoint"],
            "model": row["model"],
            "input_tokens": row["input_tokens"],
            "output_tokens": row["output_tokens"],
            "total_tokens": row["total_tokens"],
        }
        async for row in get_db().load_token_counts()
    ]
    grand_total = sum(item["total_tokens"] for item in rows)
    return {"total_tokens": grand_total, "breakdown": rows}
@router.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
    """Roll up old time_series rows into daily aggregates per endpoint/model/date.

    The optional JSON body may carry ``days`` (age threshold, default 30) and
    ``trim_old`` (default False). An empty body, or any failure while reading
    or interpreting the body, falls back to both defaults instead of failing
    the request.

    Returns a dict with ``status``, the effective ``days`` and ``trim_old``
    values, and ``aggregated_groups`` — the value returned by the database's
    ``aggregate_time_series_older_than`` call.
    """
    try:
        raw = await request.body()
        # An empty body behaves exactly like an empty options object.
        options = orjson.loads(raw.decode("utf-8")) if raw else {}
        days = int(options.get("days", 30))
        trim_old = bool(options.get("trim_old", False))
    except Exception:
        # Best-effort parsing: anything unusable means "use the defaults".
        days, trim_old = 30, False

    group_count = await get_db().aggregate_time_series_older_than(
        days, trim_old=trim_old
    )
    return {
        "status": "ok",
        "days": days,
        "trim_old": trim_old,
        "aggregated_groups": group_count,
    }
@router.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
    """Return token usage statistics for a specific model.

    The model name is taken from the ``model`` query parameter when present;
    otherwise it is read from the ``model`` field of the JSON request body.

    Returns a dict with the model name, its input/output/total token counts,
    its time series and its endpoint distribution.

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
            or does not provide a usable ``model`` string.
        HTTPException(404): the database has no token data for the model.
    """
    if not model:
        # The body is only needed when the query string did not name a model.
        body_bytes = await request.body()
        try:
            # orjson accepts raw bytes and reports invalid UTF-8 as a
            # JSONDecodeError, so malformed encodings yield a 400 rather than
            # an unhandled UnicodeDecodeError.
            payload = orjson.loads(body_bytes)
        except orjson.JSONDecodeError as e:
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

        # Valid JSON that is not an object (list, string, number, null) has no
        # fields to read; reject it instead of failing on ``.get``.
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        model = payload.get("model")
        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(model, str):
            raise HTTPException(
                status_code=400, detail="Field 'model' must be a string"
            )

    db = get_db()
    token_data = await db.get_token_counts_for_model(model)
    if not token_data:
        raise HTTPException(
            status_code=404, detail="No token data found for this model"
        )

    time_series = [
        entry async for entry in db.get_time_series_for_model(model)
    ]
    endpoint_distribution = await db.get_endpoint_distribution_for_model(model)

    return {
        'model': model,
        'input_tokens': token_data['input_tokens'],
        'output_tokens': token_data['output_tokens'],
        'total_tokens': token_data['total_tokens'],
        'time_series': time_series,
        'endpoint_distribution': endpoint_distribution,
    }
@router.get("/api/affinity_stats")
async def affinity_stats(request: Request):
    """List the live conversation-affinity pins, one entry per pinned conversation.

    Only the endpoint, the model and the remaining TTL (seconds) are exposed —
    never fingerprints or content. With conversation_affinity disabled the
    `entries` list is always empty. Pins found to be expired while building
    the list are removed from the affinity map.
    """
    cfg = get_config()
    if not cfg.conversation_affinity:
        return {"enabled": False, "ttl": cfg.conversation_affinity_ttl, "entries": []}

    started = time.monotonic()
    pins: list[dict] = []
    llama_endpoints = set(cfg.llama_server_endpoints)

    async with _affinity_lock:
        # Iterate over a snapshot because expired pins are dropped on the fly.
        for fingerprint, pin in tuple(_affinity_map.items()):
            endpoint, model_name, deadline = pin
            seconds_left = deadline - started
            if seconds_left > 0:
                # Same normalisation as /api/ps_details so the dashboard can
                # join affinity entries to PS rows by (endpoint, model).
                if endpoint in llama_endpoints:
                    shown_model = _normalize_llama_model_name(model_name)
                else:
                    shown_model = model_name
                pins.append({
                    "endpoint": endpoint,
                    "model": shown_model,
                    "remaining": round(seconds_left, 2),
                })
            else:
                _affinity_map.pop(fingerprint, None)

    return {
        "enabled": True,
        "ttl": cfg.conversation_affinity_ttl,
        "entries": pins,
    }
@router.get("/api/usage")
async def usage_proxy(request: Request):
    """Expose the current usage counters for debugging / monitoring.

    The response carries the module-level ``usage_counts`` and
    ``token_usage_counts`` mappings as they are at call time (no copy is made).
    """
    snapshot = {
        "usage_counts": usage_counts,
        "token_usage_counts": token_usage_counts,
    }
    return snapshot
@router.get("/api/config")
async def config_proxy(request: Request):
    """Report the configured Ollama and llama_server endpoints with their health.

    The frontend uses this to show which endpoints are being proxied. Each
    entry is the endpoint URL merged with the result of ``_endpoint_health``
    (5 second timeout). Status is "error" when either liveness (/api/version)
    or routing health (/api/ps) fails — see issue #83.

    ``require_router_api_key`` tells whether a router API key is configured.
    """
    cfg = get_config()

    async def _probe(url: str) -> dict:
        # Put the URL first, then merge in whatever the health probe reports.
        health = await _endpoint_health(url, timeout=5)
        return {"url": url, **health}

    ollama_status = await asyncio.gather(*(_probe(ep) for ep in cfg.endpoints))

    if cfg.llama_server_endpoints:
        llama_status = await asyncio.gather(
            *(_probe(ep) for ep in cfg.llama_server_endpoints)
        )
    else:
        llama_status = []

    return {
        "endpoints": ollama_status,
        "llama_server_endpoints": llama_status,
        "require_router_api_key": bool(cfg.router_api_key),
    }
@router.get("/api/cache/stats")
async def cache_stats():
    """Report whether the LLM response cache is enabled, plus its stats.

    Without a cache instance only ``{"enabled": False}`` is returned;
    otherwise the fields of the cache's ``stats()`` are merged into the reply.
    """
    cache = get_llm_cache()
    if cache is not None:
        return {"enabled": True, **cache.stats()}
    return {"enabled": False}
@router.post("/api/cache/invalidate")
async def cache_invalidate():
    """Clear the LLM response cache when one is configured.

    ``cleared`` is True only when a cache instance exists and its ``clear()``
    coroutine has completed; without a cache nothing is done.
    """
    cache = get_llm_cache()
    if cache is not None:
        await cache.clear()
        return {"enabled": True, "cleared": True}
    return {"enabled": False, "cleared": False}
@router.get("/health")
async def health_proxy(request: Request):
    """
    Healthcheck endpoint for monitoring the proxy.

    * Probes every configured endpoint (Ollama endpoints first, followed by
      the llama_server endpoints that are not already listed among them)
      through `_endpoint_health`, all in parallel.
    * Returns a JSON object containing:
        - `status`: "ok" if every probe reported status "ok", otherwise "error".
        - `endpoints`: a mapping of endpoint URL → probe result.
    * The HTTP status code is 200 when everything is healthy, 503 otherwise.
    """
    cfg = get_config()

    # Ollama endpoints expose /api/version (liveness) and /api/ps (routing
    # health — required by `choose_endpoint`). OpenAI-compatible endpoints
    # (vLLM, llama-server, external) expose /models, which serves both
    # purposes. Probing /api/version alone would miss the case where the
    # Ollama process is up but /api/ps is failing — see issue #83.
    targets = [*cfg.endpoints]
    targets.extend(
        ep for ep in cfg.llama_server_endpoints if ep not in cfg.endpoints
    )

    # Fire all probes concurrently; results come back in `targets` order.
    results = await asyncio.gather(*map(_endpoint_health, targets))

    per_endpoint = {ep: result for ep, result in zip(targets, results)}
    healthy = not any(result.get("status") != "ok" for result in results)

    label, code = ("ok", 200) if healthy else ("error", 503)
    return JSONResponse(
        content={"status": label, "endpoints": per_endpoint},
        status_code=code,
    )
@router.get("/api/hostname")
async def get_hostname():
    """Report the hostname of the machine the router runs on."""
    hostname = socket.gethostname()
    return JSONResponse(content={"hostname": hostname})
@router.get("/api/usage-stream")
async def usage_stream(request: Request):
    """
    Server-Sent Events stream of usage updates.

    Each item received from the subscription queue is forwarded to the client
    as one SSE `data:` message. The stream ends when the client is seen to be
    disconnected or when `None` is received from the queue.
    """
    async def event_generator():
        # Queue obtained from the broadcast channel for this client.
        updates = await subscribe()
        try:
            # Disconnection is checked before waiting for each new item.
            while not await request.is_disconnected():
                item = await updates.get()
                if item is None:
                    # `None` ends the stream.
                    break
                # One queue item becomes exactly one SSE message.
                yield f"data: {item}\n\n"
        finally:
            # Always release the subscription, however the stream ended.
            await unsubscribe(updates)

    return StreamingResponse(event_generator(), media_type="text/event-stream")