"""
title: NOMYO Router - an (O)llama and OpenAI API v1 Proxy with Endpoint:Model aware routing
author: alpha-nerd-nomyo
author_url: https://github.com/nomyo-ai
version: 0.9
license: AGPL
"""
|
||
# -------------------------------------------------------------
|
||
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math, socket, httpx, hashlib
|
||
try:
|
||
import truststore; truststore.inject_into_ssl()
|
||
except ImportError:
|
||
pass
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
# Directory containing static files (relative to this script)
|
||
STATIC_DIR = Path(__file__).parent / "static"
|
||
from typing import Dict, Set, List, Optional
|
||
from urllib.parse import urlparse, parse_qsl, urlencode
|
||
from fastapi import FastAPI, Request, HTTPException
|
||
from fastapi_sse import sse_handler
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
|
||
from pydantic import Field
|
||
from pydantic_settings import BaseSettings
|
||
from collections import defaultdict
|
||
from PIL import Image
|
||
|
||
from security import _mask_secrets
|
||
from context_window import (
|
||
_count_message_tokens,
|
||
_trim_messages_for_context,
|
||
_calibrated_trim_target,
|
||
_endpoint_nctx,
|
||
_CTX_TRIM_SMALL_LIMIT,
|
||
)
|
||
from state import (
|
||
_models_cache,
|
||
_loaded_models_cache,
|
||
_available_error_cache,
|
||
_loaded_error_cache,
|
||
_completion_error_cache,
|
||
_COMPLETION_ERROR_TTL,
|
||
_models_cache_lock,
|
||
_loaded_models_cache_lock,
|
||
_available_error_cache_lock,
|
||
_loaded_error_cache_lock,
|
||
_completion_error_cache_lock,
|
||
_inflight_available_models,
|
||
_inflight_loaded_models,
|
||
_inflight_lock,
|
||
_bg_refresh_available,
|
||
_bg_refresh_loaded,
|
||
_bg_refresh_lock,
|
||
_subscribers,
|
||
_subscribers_lock,
|
||
token_queue,
|
||
app_state,
|
||
token_buffer,
|
||
time_series_buffer,
|
||
buffer_lock,
|
||
FLUSH_INTERVAL,
|
||
)
|
||
|
||
# Handles to the two background tasks. Both are None until startup_event()
# creates the tasks; shutdown_event() cancels and awaits them.
# Rebound on startup — must stay in router.py module namespace.
token_worker_task: asyncio.Task | None = None
flush_task: asyncio.Task | None = None
|
||
|
||
from config import Config, _config_path_from_env
|
||
|
||
from ollama._types import TokenLogprob, Logprob
|
||
from db import TokenDatabase
|
||
from cache import init_llm_cache, get_llm_cache, openai_nonstream_to_sse
|
||
|
||
|
||
# Create the global config object – it will be overwritten on startup.
# Submodules read it lazily via config.get_config().
# Loading once at import time means code that runs before startup_event()
# (e.g. the auth middleware reading config.router_api_key) already sees a
# populated object; startup_event() rebinds this name with a fresh load.
config = Config.from_yaml(_config_path_from_env())
|
||
|
||
# -------------------------------------------------------------
|
||
# 2. FastAPI application
|
||
# -------------------------------------------------------------
|
||
app = FastAPI()
# Give the fastapi_sse handler object a reference to this application.
sse_handler.app = app
# NOTE(review): allow_origins=["*"] together with allow_credentials=True permits
# credentialed cross-origin requests from any site — confirm this is intended,
# or restrict the origin list.
# NOTE(review): enforce_router_api_key advertises "GET, POST, PUT, DELETE, OPTIONS"
# in its hand-written CORS headers, while this list only allows GET/POST/DELETE —
# confirm which set is the intended one.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE"],
    allow_headers=["Authorization", "Content-Type"],
)
|
||
from state import default_headers
|
||
|
||
# -------------------------------------------------------------
|
||
# Router-level authentication (optional)
|
||
# -------------------------------------------------------------
|
||
def _extract_router_api_key(request: Request) -> Optional[str]:
    """
    Return the router API key supplied by the client, or ``None``.

    The ``Authorization: Bearer <key>`` header is consulted first. If the
    header is missing, uses a different scheme, or carries an empty token, the
    ``api_key`` query parameter is used instead. Empty values are treated as
    "not provided". The middleware uses the result to gate access to API
    routes when a router_api_key is configured.
    """
    authorization = request.headers.get("Authorization") or ""
    if authorization.lower().startswith("bearer "):
        # Everything after the first space is the token; surrounding blanks are dropped.
        token = authorization.split(" ", 1)[1].strip()
        if token:
            return token

    # Fall back to ?api_key=...; an empty string is normalised to None.
    return request.query_params.get("api_key") or None
|
||
|
||
|
||
def _strip_api_key_from_scope(request: Request) -> None:
    """
    Remove api_key from the ASGI scope query string to avoid leaking it in logs.

    ``request.scope["query_string"]`` is rewritten in place with every
    ``api_key`` pair dropped. When no ``api_key`` pair is present the scope is
    left untouched, so the client's original bytes are preserved exactly.

    The query string is decoded, unquoted and re-quoted as latin-1. That
    encoding maps each byte to exactly one character, so malformed or
    non-UTF-8 input can neither raise ``UnicodeDecodeError`` nor be replaced
    by U+FFFD while the remaining parameters are re-encoded.
    """
    scope = request.scope
    raw_qs = scope.get("query_string", b"")
    if not raw_qs:
        return

    pairs = parse_qsl(
        raw_qs.decode("latin-1"),
        keep_blank_values=True,
        encoding="latin-1",
    )
    kept = [(name, value) for name, value in pairs if name != "api_key"]
    if len(kept) == len(pairs):
        # Nothing to strip: do not re-encode what the client sent.
        return

    # urlencode() output is pure ASCII (everything else is percent-escaped).
    scope["query_string"] = urlencode(kept, encoding="latin-1").encode("ascii")
|
||
|
||
|
||
# Extensions that may be fetched from /static without a key. HTML is left out
# on purpose: serving /static/index.html unauthenticated would bypass auth.
_ROUTER_AUTH_STATIC_EXTS = frozenset(
    {".css", ".js", ".ico", ".png", ".jpg", ".jpeg", ".svg", ".woff", ".woff2", ".ttf", ".map"}
)

# CORS headers this middleware attaches by hand to API responses.
_ROUTER_AUTH_CORS_HEADERS = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Headers": "Authorization, Content-Type",
    "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
}


def _router_auth_cors_headers(path: str) -> Dict[str, str]:
    """Return the manual CORS headers for *path*, or an empty dict when none apply."""
    if "/api/" in path and path != "/api/usage-stream":
        return dict(_ROUTER_AUTH_CORS_HEADERS)
    return {}


@app.middleware("http")
async def enforce_router_api_key(request: Request, call_next):
    """
    Enforce the optional NOMYO Router API key for all non-static requests.

    When `config.router_api_key` is set, clients must supply the key either in
    the Authorization header (`Bearer <key>`) or as `api_key` query parameter.

    Requests pass through unchecked when no key is configured, for OPTIONS
    requests, for whitelisted static asset types under /static, and for "/"
    and "/favicon.ico". Otherwise a missing key yields 401 and a wrong key
    yields 403, both as JSON ``{"detail": ...}`` bodies.

    NOTE(review): responses returned directly from here presumably do not pass
    through CORSMiddleware, which is why CORS headers are added by hand for
    API paths — confirm against the middleware ordering.
    """
    expected_key = config.router_api_key
    if not expected_key or request.method == "OPTIONS":
        return await call_next(request)

    path = request.url.path
    is_static_asset = (
        path.startswith("/static")
        and Path(path).suffix.lower() in _ROUTER_AUTH_STATIC_EXTS
    )
    if is_static_asset or path in {"/", "/favicon.ico"}:
        return await call_next(request)

    provided_key = _extract_router_api_key(request)
    # Strip the api_key query param from scope so access logs do not leak it
    _strip_api_key_from_scope(request)

    # Same headers on every outcome, so browser clients can read 401 and 403
    # error bodies just like successful responses.
    cors_headers = _router_auth_cors_headers(path)

    if provided_key is None:
        # No key provided but authentication is required - return 401
        return JSONResponse(
            content={"detail": "Missing NOMYO Router API key"},
            status_code=401,
            headers=cors_headers,
        )

    # compare_digest() only accepts ASCII-only str; a client-supplied key with
    # any non-ASCII character would raise TypeError (a 500). Compare bytes.
    keys_match = secrets.compare_digest(
        str(provided_key).encode("utf-8"),
        str(expected_key).encode("utf-8"),
    )
    if not keys_match:
        return JSONResponse(
            content={"detail": "Invalid NOMYO Router API key"},
            status_code=403,
            headers=cors_headers,
        )

    response = await call_next(request)
    # Add CORS headers for authenticated API requests
    for header_name, header_value in cors_headers.items():
        response.headers[header_name] = header_value
    return response
|
||
|
||
|
||
@app.exception_handler(openai.APIStatusError)
async def _openai_api_status_error_handler(request: Request, exc: openai.APIStatusError):
    """
    Relay an upstream OpenAI-SDK status error to the client.

    The response keeps the upstream status code and body instead of letting
    the exception bubble up as a 500. If the exception carries no body, a
    minimal ``{"error": {"message", "code"}}`` payload is synthesised from it.
    """
    payload = exc.body
    if payload is None:
        payload = {"error": {"message": str(exc), "code": exc.status_code}}
    return JSONResponse(status_code=exc.status_code, content=payload)
|
||
|
||
|
||
from state import (
|
||
usage_counts,
|
||
token_usage_counts,
|
||
usage_lock,
|
||
token_usage_lock,
|
||
_affinity_map,
|
||
_affinity_lock,
|
||
_AFFINITY_MAX_ENTRIES,
|
||
)
|
||
|
||
from fingerprint import _conversation_fingerprint
|
||
|
||
# Database instance. None until startup_event() constructs the TokenDatabase;
# shutdown_event() closes it if it was created.
db: "TokenDatabase | None" = None
|
||
|
||
# -------------------------------------------------------------
|
||
# 4. Helper functions
|
||
# -------------------------------------------------------------
|
||
from backends.normalize import (
|
||
_normalize_llama_model_name,
|
||
_extract_llama_quant,
|
||
ep2base,
|
||
dedupe_on_keys,
|
||
)
|
||
from backends.sessions import (
|
||
_is_unix_socket_endpoint,
|
||
_get_socket_path,
|
||
get_session,
|
||
_make_openai_client,
|
||
)
|
||
from backends.health import (
|
||
_is_fresh,
|
||
_ensure_success,
|
||
_format_connection_issue,
|
||
_is_backend_connection_error,
|
||
_mark_backend_unhealthy,
|
||
_is_llama_model_loaded,
|
||
_is_llama_model_loaded_or_sleeping,
|
||
)
|
||
|
||
|
||
from backends.normalize import (
|
||
is_ext_openai_endpoint,
|
||
is_openai_compatible,
|
||
get_tracking_model,
|
||
)
|
||
|
||
from tokens import token_worker, flush_buffer, flush_remaining_buffers
|
||
|
||
from backends.probe import fetch
|
||
|
||
|
||
from routing import increment_usage, decrement_usage
|
||
|
||
|
||
from requests.chat import _make_chat_request, _make_moe_requests
|
||
|
||
from images import iso8601_ns, is_base64, resize_image_if_needed
|
||
|
||
from requests.messages import (
|
||
_strip_assistant_prefill,
|
||
transform_tool_calls_to_openai,
|
||
transform_images_to_data_urls,
|
||
_strip_images_from_messages,
|
||
_accumulate_openai_tc_delta,
|
||
_build_ollama_tool_calls,
|
||
_convert_openai_logprobs,
|
||
get_last_user_content,
|
||
)
|
||
from requests.rechunk import rechunk
|
||
|
||
from sse import (
|
||
_capture_snapshot,
|
||
_distribute_snapshot,
|
||
close_all_sse_queues,
|
||
subscribe,
|
||
unsubscribe,
|
||
get_usage_counts,
|
||
)
|
||
|
||
# -------------------------------------------------------------
|
||
# 5. Endpoint selection logic (respecting the configurable limit)
|
||
# -------------------------------------------------------------
|
||
from routing import get_max_connections, choose_endpoint
|
||
|
||
# (Ollama /api/* routes — moved to api/ollama.py)
|
||
# -------------------------------------------------------------
|
||
# 18b. Conversation-affinity stats – feeds the PS-table dot matrix
|
||
# -------------------------------------------------------------
|
||
# (affinity_stats, usage, config — moved to api/management.py)
|
||
# (v1/* routes — moved to api/openai.py)
|
||
# (cache routes — moved to api/management.py)
|
||
# -------------------------------------------------------------
|
||
# 26. Serve the static front‑end
|
||
# -------------------------------------------------------------
|
||
# Mount the assets from STATIC_DIR (the "static" folder next to this script)
# instead of a path relative to the current working directory, so the mount
# resolves no matter where the process was launched from.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Route modules are imported and registered only here, after `app`, the
# middleware and the shared state above exist.
# NOTE(review): the late imports presumably avoid circular imports with the
# api.* modules — confirm before moving them to the top of the file.
from api.static import router as static_router
app.include_router(static_router)
from api.management import router as management_router
app.include_router(management_router)
from api.openai import router as openai_router
app.include_router(openai_router)
from api.ollama import router as ollama_router
app.include_router(ollama_router)
|
||
|
||
# (health, hostname, usage-stream — moved to api/management.py)
|
||
# -------------------------------------------------------------
|
||
# 28. FastAPI startup/shutdown events
|
||
# -------------------------------------------------------------
|
||
@app.on_event("startup")
async def startup_event() -> None:
    """
    Initialise process-wide state when the application starts.

    Steps, in order:
      1. Reload the YAML configuration and rebind the module-level ``config``.
      2. Create the ``TokenDatabase``, call ``init_db()`` and seed
         ``token_usage_counts`` with the stored per-endpoint/model totals.
      3. Create the shared aiohttp connector/session and store both in
         ``app_state``.
      4. Create httpx clients for external OpenAI endpoints, and aiohttp plus
         httpx clients for Unix-socket llama-server endpoints.
      5. Start the ``token_worker`` and ``flush_buffer`` background tasks.
      6. Initialise the LLM cache.
    """
    global config, db, token_worker_task, flush_task

    # Load YAML config (or use defaults if not present)
    config_path = _config_path_from_env()
    config = Config.from_yaml(config_path)
    if config_path.exists():
        print(
            f"Loaded configuration from {config_path}:\n"
            f" endpoints={config.endpoints},\n"
            f" llama_server_endpoints={config.llama_server_endpoints},\n"
            f" max_concurrent_connections={config.max_concurrent_connections},\n"
            f" endpoint_config={config.endpoint_config},\n"
            f" priority_routing={config.priority_routing}"
        )
    else:
        print(
            f"No configuration file found at {config_path}. "
            "Falling back to default settings."
        )

    # Initialize database
    db = TokenDatabase(config.db_path)
    await db.init_db()

    # Load existing token counts from database. Only the total is kept in
    # memory; the per-direction counts in each row are not used here.
    async for count_entry in db.load_token_counts():
        endpoint = count_entry['endpoint']
        model = count_entry['model']
        token_usage_counts[endpoint][model] = count_entry['total_tokens']

    ssl_context = ssl.create_default_context()
    connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context)
    # NOTE(review): total=60 bounds the whole request, so sock_read=120 looks
    # unreachable — confirm which limit is intended for long responses.
    timeout = aiohttp.ClientTimeout(total=60, connect=15, sock_read=120, sock_connect=15)
    session = aiohttp.ClientSession(
        connector=connector,
        timeout=timeout,
        headers={"Referer": default_headers.get("HTTP-Referer", "https://nomyo.ai")},
    )

    app_state["connector"] = connector
    app_state["session"] = session

    # Create httpx clients for external OpenAI endpoints (Google, etc.)
    # aiohttp strips Referer headers for cross-origin requests, so we use httpx
    for ep in config.endpoints:
        if is_ext_openai_endpoint(ep):
            app_state["httpx_clients"][ep] = httpx.AsyncClient(timeout=30.0)

    # Create per-endpoint Unix socket sessions for .sock endpoints
    for ep in config.llama_server_endpoints:
        if _is_unix_socket_endpoint(ep):
            sock_path = _get_socket_path(ep)
            sock_connector = aiohttp.UnixConnector(path=sock_path)
            sock_timeout = aiohttp.ClientTimeout(total=300, connect=5, sock_read=300)
            sock_session = aiohttp.ClientSession(connector=sock_connector, timeout=sock_timeout)
            app_state["socket_sessions"][ep] = sock_session
            transport = httpx.AsyncHTTPTransport(uds=sock_path)
            app_state["httpx_clients"][ep] = httpx.AsyncClient(transport=transport, timeout=300.0)
            print(f"[startup] Unix socket session: {ep} -> {sock_path}")

    # Keep the task handles in module globals so shutdown_event() can cancel them.
    token_worker_task = asyncio.create_task(token_worker())
    flush_task = asyncio.create_task(flush_buffer())
    await init_llm_cache(config)
|
||
|
||
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """
    Release everything startup_event() created.

    Order: close SSE queues, cancel the background tasks, flush the remaining
    buffers, close the shared aiohttp session, the Unix-socket sessions and
    the httpx clients, and finally close the token database.

    Every step after the task cancellation is best-effort: a failure is
    printed and the remaining steps still run, so that the database
    connection is always closed.
    """
    await close_all_sse_queues()

    # Stop background tasks first so they stop touching the DB before we close it.
    for t in (token_worker_task, flush_task):
        if t is None:
            continue
        t.cancel()
        try:
            await t
        except asyncio.CancelledError:
            # Expected outcome of cancel().
            pass
        except Exception as e:
            print(f"[shutdown] Background task ended with error: {e}")

    try:
        await flush_remaining_buffers()
    except Exception as e:
        print(f"[shutdown] Error flushing remaining buffers: {e}")

    # The shared session is absent if startup failed before creating it.
    session = app_state.get("session")
    if session is not None:
        try:
            await session.close()
        except Exception as e:
            print(f"[shutdown] Error closing shared HTTP session: {e}")

    # Close Unix socket sessions
    for ep, sess in list(app_state.get("socket_sessions", {}).items()):
        try:
            await sess.close()
            print(f"[shutdown] Closed Unix socket session: {ep}")
        except Exception as e:
            print(f"[shutdown] Error closing Unix socket session {ep}: {e}")

    # Close httpx clients (external OpenAI endpoints and Unix-socket endpoints)
    for ep, client in list(app_state.get("httpx_clients", {}).items()):
        try:
            await client.aclose()
            print(f"[shutdown] Closed httpx client: {ep}")
        except Exception as e:
            print(f"[shutdown] Error closing httpx client {ep}: {e}")

    # Close the aiosqlite connection last — its worker thread is non-daemon
    # and would otherwise keep the interpreter alive after lifespan completes.
    if db is not None:
        try:
            await db.close()
            print("[shutdown] Closed token DB connection.")
        except Exception as e:
            print(f"[shutdown] Error closing DB: {e}")
|