2025-08-26 18:19:43 +02:00
|
|
|
|
"""
|
2026-03-05 11:09:20 +01:00
|
|
|
|
title: NOMYO Router - an (O)llama and OpenAI API v1 Proxy with Endpoint:Model aware routing
|
2025-08-26 18:19:43 +02:00
|
|
|
|
author: alpha-nerd-nomyo
|
|
|
|
|
|
author_url: https://github.com/nomyo-ai
|
2026-02-27 16:39:27 +01:00
|
|
|
|
version: 0.7
|
2025-08-26 18:19:43 +02:00
|
|
|
|
license: AGPL
|
|
|
|
|
|
"""
|
|
|
|
|
|
# -------------------------------------------------------------
|
2026-02-27 16:39:27 +01:00
|
|
|
|
import orjson, time, asyncio, yaml, ollama, openai, os, re, aiohttp, ssl, random, base64, io, enhance, secrets, math
|
2026-02-12 16:15:39 +01:00
|
|
|
|
try:
|
|
|
|
|
|
import truststore; truststore.inject_into_ssl()
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
pass
|
2025-11-18 11:16:21 +01:00
|
|
|
|
from datetime import datetime, timezone
|
2025-08-26 18:19:43 +02:00
|
|
|
|
from pathlib import Path
|
2026-01-05 17:16:31 +01:00
|
|
|
|
|
|
|
|
|
|
# Directory containing static files (relative to this script)
|
|
|
|
|
|
STATIC_DIR = Path(__file__).parent / "static"
|
2025-09-05 12:11:31 +02:00
|
|
|
|
from typing import Dict, Set, List, Optional
|
2026-01-14 09:28:02 +01:00
|
|
|
|
from urllib.parse import urlparse, parse_qsl, urlencode
|
2025-08-26 18:19:43 +02:00
|
|
|
|
from fastapi import FastAPI, Request, HTTPException
|
2025-09-05 12:11:31 +02:00
|
|
|
|
from fastapi_sse import sse_handler
|
2025-08-30 00:12:56 +02:00
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
2025-09-11 09:46:19 +02:00
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
2025-08-30 12:43:35 +02:00
|
|
|
|
from starlette.responses import StreamingResponse, JSONResponse, Response, HTMLResponse, RedirectResponse
|
2025-08-26 18:19:43 +02:00
|
|
|
|
from pydantic import Field
|
|
|
|
|
|
from pydantic_settings import BaseSettings
|
|
|
|
|
|
from collections import defaultdict
|
2025-09-24 11:46:38 +02:00
|
|
|
|
from PIL import Image
|
2025-09-01 13:38:49 +02:00
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# In‑memory caches
|
|
|
|
|
|
# ------------------------------------------------------------------
|
2025-09-05 12:11:31 +02:00
|
|
|
|
# Successful results are cached for 300s
|
2025-09-01 13:38:49 +02:00
|
|
|
|
_models_cache: dict[str, tuple[Set[str], float]] = {}
|
2025-11-17 14:40:24 +01:00
|
|
|
|
_loaded_models_cache: dict[str, tuple[Set[str], float]] = {}
|
2026-02-08 16:46:40 +01:00
|
|
|
|
# Transient errors are cached separately per concern so that a failure
|
|
|
|
|
|
# in one path does not poison the other.
|
|
|
|
|
|
_available_error_cache: dict[str, float] = {}
|
|
|
|
|
|
_loaded_error_cache: dict[str, float] = {}
|
2025-08-26 18:19:43 +02:00
|
|
|
|
|
2026-01-16 16:47:24 +01:00
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Cache locks
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
_models_cache_lock = asyncio.Lock()
|
|
|
|
|
|
_loaded_models_cache_lock = asyncio.Lock()
|
2026-02-08 16:46:40 +01:00
|
|
|
|
_available_error_cache_lock = asyncio.Lock()
|
|
|
|
|
|
_loaded_error_cache_lock = asyncio.Lock()
|
2026-01-16 16:47:24 +01:00
|
|
|
|
|
2026-01-29 18:00:33 +01:00
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# In-flight request tracking (prevents cache stampede)
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
_inflight_available_models: dict[str, asyncio.Task] = {}
|
|
|
|
|
|
_inflight_loaded_models: dict[str, asyncio.Task] = {}
|
|
|
|
|
|
_inflight_lock = asyncio.Lock()
|
2026-02-14 14:51:44 +01:00
|
|
|
|
_bg_refresh_available: dict[str, asyncio.Task] = {}
|
|
|
|
|
|
_bg_refresh_loaded: dict[str, asyncio.Task] = {}
|
|
|
|
|
|
_bg_refresh_lock = asyncio.Lock()
|
2026-01-29 18:00:33 +01:00
|
|
|
|
|
2025-09-05 12:11:31 +02:00
|
|
|
|
# ------------------------------------------------------------------
|
2025-11-10 15:37:46 +01:00
|
|
|
|
# Queues
|
2025-09-05 12:11:31 +02:00
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
_subscribers: Set[asyncio.Queue] = set()
|
|
|
|
|
|
_subscribers_lock = asyncio.Lock()
|
2025-11-10 15:37:46 +01:00
|
|
|
|
token_queue: asyncio.Queue[tuple[str, str, int, int]] = asyncio.Queue()
|
2025-09-05 12:11:31 +02:00
|
|
|
|
|
2026-01-14 09:28:02 +01:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# Secret handling
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
def _mask_secrets(text: str) -> str:
|
|
|
|
|
|
"""
|
|
|
|
|
|
Mask common API key patterns to avoid leaking secrets in logs or error payloads.
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not text:
|
|
|
|
|
|
return text
|
|
|
|
|
|
# OpenAI-style keys (sk-...) and generic "api key" mentions
|
|
|
|
|
|
text = re.sub(r"sk-[A-Za-z0-9]{4}[A-Za-z0-9_-]*", "sk-***redacted***", text)
|
2026-03-03 16:34:16 +01:00
|
|
|
|
text = re.sub(r"(?i)(api[-_ ]key\s*[:=]\s*)([^\s]+)", r"\1***redacted***", text)
|
2026-01-14 09:28:02 +01:00
|
|
|
|
return text
|
|
|
|
|
|
|
2025-09-10 10:21:49 +02:00
|
|
|
|
# ------------------------------------------------------------------
|
2025-11-13 10:13:10 +01:00
|
|
|
|
# Globals
|
2025-09-10 10:21:49 +02:00
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
app_state = {
|
|
|
|
|
|
"session": None,
|
|
|
|
|
|
"connector": None,
|
|
|
|
|
|
}
|
2025-11-13 10:13:10 +01:00
|
|
|
|
token_worker_task: asyncio.Task | None = None
|
2025-11-18 11:16:21 +01:00
|
|
|
|
flush_task: asyncio.Task | None = None
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Token Count Buffer (for write-behind pattern)
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Structure: {endpoint: {model: (input_tokens, output_tokens)}}
|
2025-12-02 12:18:23 +01:00
|
|
|
|
token_buffer: dict[str, dict[str, tuple[int, int]]] = defaultdict(lambda: defaultdict(lambda: (0, 0)))
|
2025-11-18 11:16:21 +01:00
|
|
|
|
# Time series buffer with timestamp
|
|
|
|
|
|
time_series_buffer: list[dict[str, int | str]] = []
|
2026-01-05 17:16:31 +01:00
|
|
|
|
# Lock to protect buffer access from race conditions
|
|
|
|
|
|
buffer_lock = asyncio.Lock()
|
2025-11-18 11:16:21 +01:00
|
|
|
|
|
|
|
|
|
|
# Configuration for periodic flushing
|
|
|
|
|
|
FLUSH_INTERVAL = 10 # seconds
|
2025-09-10 10:21:49 +02:00
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 1. Configuration loader
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
class Config(BaseSettings):
    """
    Runtime configuration for the router.

    Values are loaded from a YAML file (see :meth:`from_yaml`) with ``${VAR}``
    placeholders expanded from the environment, and may also be supplied via
    environment variables using the ``NOMYO_ROUTER_`` prefix.
    """

    # List of Ollama endpoints
    endpoints: list[str] = Field(
        default_factory=lambda: [
            "http://localhost:11434",
        ]
    )
    # List of llama-server endpoints (OpenAI-compatible with /v1/models status info)
    llama_server_endpoints: List[str] = Field(default_factory=list)
    # Max concurrent connections per endpoint‑model pair, see OLLAMA_NUM_PARALLEL
    max_concurrent_connections: int = 1

    # Per-endpoint API keys (endpoint identifier -> bearer token)
    api_keys: Dict[str, str] = Field(default_factory=dict)
    # Optional router-level API key used to gate access to this service and dashboard
    # NOTE(review): `env=` on Field is a pydantic-v1 idiom; from_yaml() also falls
    # back to NOMYO_ROUTER_API_KEY explicitly below — confirm which path applies
    # under the installed pydantic-settings version.
    router_api_key: Optional[str] = Field(default=None, env="NOMYO_ROUTER_API_KEY")

    # Database configuration
    db_path: str = Field(default=os.getenv("NOMYO_ROUTER_DB_PATH", "token_counts.db"))

    # Semantic LLM Cache configuration
    cache_enabled: bool = Field(default=False)
    # Backend: "memory" (default, in-process), "sqlite" (persistent), "redis" (distributed)
    cache_backend: str = Field(default="memory")
    # Cosine similarity threshold: 1.0 = exact match only, <1.0 = semantic (requires :semantic image)
    cache_similarity: float = Field(default=1.0)
    # TTL in seconds; None = cache forever
    cache_ttl: Optional[int] = Field(default=3600)
    # SQLite backend: path to cache database file
    cache_db_path: str = Field(default="llm_cache.db")
    # Redis backend: connection URL
    cache_redis_url: str = Field(default="redis://localhost:6379/0")
    # Weight of BM25-weighted chat-history embedding vs last-user-message embedding
    # 0.3 = 30% history context signal, 70% question signal
    cache_history_weight: float = Field(default=0.3)

    class Config:
        # Load from `config.yaml` first, then from env variables
        env_prefix = "NOMYO_ROUTER_"
        yaml_file = Path("config.yaml")  # relative to cwd

    @classmethod
    def _expand_env_refs(cls, obj):
        """Recursively replace `${VAR}` with os.getenv('VAR')."""
        if isinstance(obj, dict):
            return {k: cls._expand_env_refs(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [cls._expand_env_refs(v) for v in obj]
        if isinstance(obj, str):
            # Only expand if it is exactly ${VAR}
            m = re.fullmatch(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", obj)
            if m:
                # An unset variable expands to the empty string.
                return os.getenv(m.group(1), "")
        return obj

    @classmethod
    def from_yaml(cls, path: Path) -> "Config":
        """Load the YAML file and create the Config instance.

        Falls back to pure defaults/env when *path* does not exist.
        """
        if path.exists():
            with path.open("r", encoding="utf-8") as fp:
                data = yaml.safe_load(fp) or {}
            cleaned = cls._expand_env_refs(data)
            if isinstance(cleaned, dict):
                # Accept hyphenated config key and map it to the field name
                key_aliases = [
                    # canonical field name
                    "router_api_key",
                    # lowercase, hyphen/underscore variants
                    "nomyo-router-api-key",
                    "nomyo_router_api_key",
                    "nomyo-router_api_key",
                    "nomyo_router-api_key",
                    # uppercase env-style variants
                    "NOMYO-ROUTER_API_KEY",
                    "NOMYO_ROUTER_API_KEY",
                ]
                for alias in key_aliases:
                    if alias in cleaned:
                        # pop() is evaluated eagerly, so the alias key is always
                        # removed; an already-set canonical value wins over it.
                        cleaned["router_api_key"] = cleaned.get("router_api_key", cleaned.pop(alias))
                        break
                # If not present in YAML (or empty), fall back to env var explicitly
                if not cleaned.get("router_api_key"):
                    env_key = os.getenv("NOMYO_ROUTER_API_KEY")
                    if env_key:
                        cleaned["router_api_key"] = env_key
            return cls(**cleaned)
        return cls()
|
|
|
|
|
|
|
2025-11-07 13:59:16 +01:00
|
|
|
|
def _config_path_from_env() -> Path:
|
|
|
|
|
|
"""
|
|
|
|
|
|
Resolve the configuration file path. Defaults to `config.yaml`
|
|
|
|
|
|
in the current working directory unless NOMYO_ROUTER_CONFIG_PATH
|
|
|
|
|
|
is set.
|
|
|
|
|
|
"""
|
|
|
|
|
|
candidate = os.getenv("NOMYO_ROUTER_CONFIG_PATH")
|
|
|
|
|
|
if candidate:
|
|
|
|
|
|
return Path(candidate).expanduser()
|
|
|
|
|
|
return Path("config.yaml")
|
|
|
|
|
|
|
2026-02-13 13:29:45 +01:00
|
|
|
|
from ollama._types import TokenLogprob, Logprob
|
2025-11-18 11:16:21 +01:00
|
|
|
|
from db import TokenDatabase
|
2026-03-08 09:12:09 +01:00
|
|
|
|
from cache import init_llm_cache, get_llm_cache, openai_nonstream_to_sse
|
2025-11-18 11:16:21 +01:00
|
|
|
|
|
2025-11-07 13:59:16 +01:00
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# Create the global config object – it will be overwritten on startup
|
2025-11-07 13:59:16 +01:00
|
|
|
|
config = Config.from_yaml(_config_path_from_env())
|
2025-08-26 18:19:43 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 2. FastAPI application
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
app = FastAPI()
|
2025-09-05 12:11:31 +02:00
|
|
|
|
sse_handler.app = app
|
2025-09-11 09:46:19 +02:00
|
|
|
|
app.add_middleware(
|
|
|
|
|
|
CORSMiddleware,
|
|
|
|
|
|
allow_origins=["*"],
|
|
|
|
|
|
allow_credentials=True,
|
|
|
|
|
|
allow_methods=["GET", "POST", "DELETE"],
|
|
|
|
|
|
allow_headers=["Authorization", "Content-Type"],
|
|
|
|
|
|
)
|
2025-09-21 16:20:36 +02:00
|
|
|
|
# NOTE(review): this module-level dict is not passed to anything visible in
# this file; presumably it is (or was) intended as ``default_headers=`` for an
# OpenAI/OpenRouter client constructed elsewhere — confirm it is still used.
default_headers={
    "HTTP-Referer": "https://nomyo.ai",
    "X-Title": "NOMYO Router",
}
|
2025-09-21 16:20:36 +02:00
|
|
|
|
|
2026-01-14 09:28:02 +01:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# Router-level authentication (optional)
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
def _extract_router_api_key(request: Request) -> Optional[str]:
    """
    Return the router API key supplied by the client, if any.

    Looks first at the Authorization header (``Bearer <key>``), then at the
    ``api_key`` query parameter. Returns None when neither carries a
    non-empty value.
    """
    bearer = request.headers.get("Authorization")
    if bearer and bearer.lower().startswith("bearer "):
        candidate = bearer.split(" ", 1)[1].strip()
        if candidate:  # ignore an empty "Bearer " header
            return candidate
    # Fall back to the query-string parameter (empty string counts as absent).
    return request.query_params.get("api_key") or None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_api_key_from_scope(request: Request) -> None:
    """
    Drop the ``api_key`` parameter from the ASGI query string in place so
    access logs never record the secret.
    """
    raw = request.scope.get("query_string", b"")
    if not raw:
        return
    kept = [
        (name, value)
        for name, value in parse_qsl(raw.decode("utf-8"), keep_blank_values=True)
        if name != "api_key"
    ]
    request.scope["query_string"] = urlencode(kept).encode("utf-8")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.middleware("http")
async def enforce_router_api_key(request: Request, call_next):
    """
    Gate every non-static request behind the optional router API key.

    When ``config.router_api_key`` is unset the middleware is a no-op.
    Otherwise the client must present the key via the Authorization header
    (``Bearer <key>``) or the ``api_key`` query parameter; static assets,
    the root page, the favicon and CORS preflight requests stay open.
    """
    required = config.router_api_key
    if not required or request.method == "OPTIONS":
        return await call_next(request)

    path = request.url.path
    if path.startswith("/static") or path in {"/", "/favicon.ico"}:
        return await call_next(request)

    supplied = _extract_router_api_key(request)
    # Scrub the api_key query parameter before anything can log the URL.
    _strip_api_key_from_scope(request)

    # API routes (except the SSE usage stream) get explicit CORS headers on
    # both auth errors and successful responses.
    wants_cors = "/api/" in path and path != "/api/usage-stream"

    if supplied is None:
        # No key provided but authentication is required - return 401
        cors_headers = {}
        if wants_cors:
            cors_headers = {
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Authorization, Content-Type",
                "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            }
        return JSONResponse(
            content={"detail": "Missing NOMYO Router API key"},
            status_code=401,
            headers=cors_headers,
        )

    # Constant-time comparison to avoid timing side channels.
    if not secrets.compare_digest(str(supplied), str(required)):
        return JSONResponse(
            content={"detail": "Invalid NOMYO Router API key"},
            status_code=403,
        )

    response = await call_next(request)
    if wants_cors:
        response.headers["Access-Control-Allow-Origin"] = "*"
        response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type"
        response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
    return response
|
2026-01-14 09:28:02 +01:00
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 3. Global state: per‑endpoint per‑model active connection counters
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
2025-11-04 17:55:19 +01:00
|
|
|
|
token_usage_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
2025-08-26 18:19:43 +02:00
|
|
|
|
usage_lock = asyncio.Lock() # protects access to usage_counts
|
2025-11-10 15:37:46 +01:00
|
|
|
|
token_usage_lock = asyncio.Lock()
|
2025-08-26 18:19:43 +02:00
|
|
|
|
|
2025-11-18 11:16:21 +01:00
|
|
|
|
# Database instance
|
|
|
|
|
|
db: "TokenDatabase" = None
|
|
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 4. Helperfunctions
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-09-01 13:38:49 +02:00
|
|
|
|
def _is_fresh(cached_at: float, ttl: int) -> bool:
|
|
|
|
|
|
return (time.time() - cached_at) < ttl
|
|
|
|
|
|
|
2025-09-09 17:08:00 +02:00
|
|
|
|
async def _ensure_success(resp: aiohttp.ClientResponse) -> None:
    """
    Raise an HTTPException mirroring *resp*'s status for any 4xx/5xx reply.

    The upstream body is forwarded as the error detail after secret masking.
    """
    if resp.status < 400:
        return
    body = await resp.text()
    raise HTTPException(status_code=resp.status, detail=_mask_secrets(body))
|
2025-10-30 09:06:21 +01:00
|
|
|
|
|
2025-11-07 13:59:16 +01:00
|
|
|
|
def _format_connection_issue(url: str, error: Exception) -> str:
    """
    Provide a human-friendly error string for connection failures so operators
    know which endpoint and address failed from inside the container.

    Parameters:
        url:    The endpoint URL that was being contacted.
        error:  The exception raised while contacting it.

    Returns a single-line diagnostic message; never raises.
    """
    parsed = urlparse(url)
    host_hint = parsed.hostname or ""
    port_hint = parsed.port or ""

    if isinstance(error, aiohttp.ClientConnectorError):
        # Prefer the host/port aiohttp actually tried; fall back to URL parts.
        resolved_host = getattr(error, "host", host_hint) or host_hint or "?"
        resolved_port = getattr(error, "port", port_hint) or port_hint or "?"
        parts = [
            f"Failed to connect to {url} (resolved: {resolved_host}:{resolved_port}).",
            "Ensure the endpoint address is reachable from within the container.",
        ]
        if resolved_host in {"localhost", "127.0.0.1"}:
            # Common Docker pitfall: loopback inside a container is not the host.
            parts.append(
                "Inside Docker, 'localhost' refers to the container itself; use "
                "'host.docker.internal' or a Docker network alias if the service "
                "runs on the host machine."
            )
        # Surface the underlying OS-level error (errno/strerror) when present.
        os_error = getattr(error, "os_error", None)
        if isinstance(os_error, OSError):
            errno = getattr(os_error, "errno", None)
            strerror = os_error.strerror or str(os_error)
            if errno is not None or strerror:
                parts.append(f"OS error [{errno}]: {strerror}.")
        elif os_error:
            parts.append(f"OS error: {os_error}.")
        parts.append(f"Original error: {error}.")
        return " ".join(parts)

    if isinstance(error, asyncio.TimeoutError):
        return (
            f"Timed out waiting for {url}. "
            "The remote endpoint may be offline or slow to respond."
        )

    # Generic fallback for anything else (DNS, SSL, protocol errors, ...).
    return f"Error while contacting {url}: {error}"
|
2025-11-18 11:16:21 +01:00
|
|
|
|
|
2026-02-10 16:46:51 +01:00
|
|
|
|
def _normalize_llama_model_name(name: str) -> str:
|
|
|
|
|
|
"""Extract the model name from a huggingface-style identifier.
|
|
|
|
|
|
e.g. 'unsloth/gpt-oss-20b-GGUF:F16' -> 'gpt-oss-20b-GGUF'
|
|
|
|
|
|
"""
|
|
|
|
|
|
if "/" in name:
|
|
|
|
|
|
name = name.rsplit("/", 1)[1]
|
|
|
|
|
|
if ":" in name:
|
|
|
|
|
|
name = name.split(":")[0]
|
|
|
|
|
|
return name
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_llama_quant(name: str) -> str:
|
|
|
|
|
|
"""Extract the quantization level from a huggingface-style identifier.
|
|
|
|
|
|
e.g. 'unsloth/gpt-oss-20b-GGUF:Q8_0' -> 'Q8_0'
|
|
|
|
|
|
Returns empty string if no quant suffix is present.
|
|
|
|
|
|
"""
|
|
|
|
|
|
if ":" in name:
|
|
|
|
|
|
return name.rsplit(":", 1)[1]
|
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
|
|
def _is_llama_model_loaded(item: dict) -> bool:
|
|
|
|
|
|
"""Return True if a llama-server /v1/models item has status 'loaded'.
|
2026-03-02 08:54:46 +01:00
|
|
|
|
Handles both dict format ({"value": "loaded"}) and plain string ("loaded").
|
|
|
|
|
|
If no status field is present, the model is always-loaded (not dynamically managed)."""
|
2026-02-10 16:46:51 +01:00
|
|
|
|
status = item.get("status")
|
2026-03-02 08:54:46 +01:00
|
|
|
|
if status is None:
|
|
|
|
|
|
return True # No status field: model is always loaded (e.g. single-model servers)
|
2026-02-10 16:46:51 +01:00
|
|
|
|
if isinstance(status, dict):
|
|
|
|
|
|
return status.get("value") == "loaded"
|
|
|
|
|
|
if isinstance(status, str):
|
|
|
|
|
|
return status == "loaded"
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
2025-10-30 09:06:21 +01:00
|
|
|
|
def is_ext_openai_endpoint(endpoint: str) -> bool:
    """
    Determine if an endpoint is an external OpenAI-compatible endpoint (not Ollama or llama-server).

    Returns True for external services like OpenAI.com, Groq, etc.

    Returns False for:
    - Ollama endpoints (without /v1, or with /v1 but default port 11434)
    - llama-server endpoints (explicitly configured in llama_server_endpoints)
    """
    # Explicitly configured llama-server endpoints are never "external".
    if endpoint in config.llama_server_endpoints:
        return False
    # Without a /v1 path the endpoint speaks native Ollama.
    if "/v1" not in endpoint:
        return False
    # A /v1 URL whose base is a configured Ollama endpoint is Ollama's /v1 shim.
    if endpoint.replace('/v1', '') in config.endpoints:
        return False
    # The default Ollama port is a strong Ollama signal.
    if ':11434' in endpoint:
        return False
    return True
|
2025-09-01 11:07:07 +02:00
|
|
|
|
|
2026-02-10 16:46:51 +01:00
|
|
|
|
def is_openai_compatible(endpoint: str) -> bool:
    """
    Return True if the endpoint speaks the OpenAI API (not native Ollama).
    This includes external OpenAI endpoints AND llama-server endpoints.
    """
    if "/v1" in endpoint:
        return True
    return endpoint in config.llama_server_endpoints
|
|
|
|
|
|
|
2026-02-18 11:45:37 +01:00
|
|
|
|
def get_tracking_model(endpoint: str, model: str) -> str:
    """
    Normalize model name for tracking purposes so it matches the PS table key.

    - llama-server endpoints: HF prefix and quantization suffix are stripped
    - Ollama endpoints: ':latest' is appended when no tag is present
    - external OpenAI endpoints: returned unchanged (not shown in PS)

    This ensures consistent model naming across all routes for usage tracking.
    """
    if is_ext_openai_endpoint(endpoint):
        # External providers are not tracked in PS; keep the caller's name.
        return model
    if endpoint in config.llama_server_endpoints:
        return _normalize_llama_model_name(model)
    # Native Ollama: an untagged name implicitly means ':latest'.
    return model if ":" in model else f"{model}:latest"
|
|
|
|
|
|
|
2025-11-10 15:37:46 +01:00
|
|
|
|
async def _record_token_sample(endpoint: str, model: str, prompt: int, comp: int) -> None:
    """
    Record one token-usage sample: accumulate it in the write-behind buffers,
    bump the in-memory live counters and publish a dashboard snapshot.

    Parameters:
        endpoint: Upstream endpoint the tokens were consumed on.
        model:    Model name (already normalized by the caller).
        prompt:   Prompt (input) token count.
        comp:     Completion (output) token count.
    """
    # Bucket the sample to the current UTC minute (seconds truncated) so the
    # time series stays minute-granular. Computed before taking the lock.
    now = datetime.now(tz=timezone.utc)
    timestamp = int(datetime(now.year, now.month, now.day, now.hour, now.minute, tzinfo=timezone.utc).timestamp())

    # Accumulate counts in the memory buffers (protected by lock).
    async with buffer_lock:
        previous = token_buffer[endpoint].get(model, (0, 0))
        token_buffer[endpoint][model] = (previous[0] + prompt, previous[1] + comp)

        # Add to time series buffer with timestamp (UTC)
        time_series_buffer.append({
            'endpoint': endpoint,
            'model': model,
            'input_tokens': prompt,
            'output_tokens': comp,
            'total_tokens': prompt + comp,
            'timestamp': timestamp
        })

    # Update in-memory counts for immediate reporting.
    async with token_usage_lock:
        token_usage_counts[endpoint][model] += (prompt + comp)
    # NOTE(review): snapshot published outside token_usage_lock — confirm
    # publish_snapshot() does not itself require the lock to be held.
    await publish_snapshot()


async def token_worker() -> None:
    """
    Background task: drain `token_queue` and record each token-usage sample.

    On cancellation (shutdown) any samples still queued are drained before
    the CancelledError is re-raised, so no counts are lost.
    """
    try:
        while True:
            endpoint, model, prompt, comp = await token_queue.get()
            await _record_token_sample(endpoint, model, prompt, comp)
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown
        print("[token_worker] Task cancelled, processing remaining queue items...")
        # Process any remaining items in the queue before exiting
        while not token_queue.empty():
            try:
                endpoint, model, prompt, comp = token_queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            await _record_token_sample(endpoint, model, prompt, comp)
        print("[token_worker] Task cancelled, remaining items processed.")
        raise
|
2025-11-18 11:16:21 +01:00
|
|
|
|
|
|
|
|
|
|
async def _drain_buffers_to_db() -> None:
    """
    Atomically snapshot-and-clear both write-behind buffers, then persist the
    snapshots. DB writes happen outside `buffer_lock` so producers are never
    blocked on database I/O.
    """
    async with buffer_lock:
        if token_buffer:
            # Copy buffer before releasing lock for the DB operation
            buffer_copy = {ep: dict(models) for ep, models in token_buffer.items()}
            token_buffer.clear()
        else:
            buffer_copy = None

        if time_series_buffer:
            ts_copy = list(time_series_buffer)
            time_series_buffer.clear()
        else:
            ts_copy = None

    # Perform DB operations outside the lock to avoid blocking
    if buffer_copy:
        await db.update_batched_counts(buffer_copy)
    if ts_copy:
        await db.add_batched_time_series(ts_copy)


async def flush_buffer() -> None:
    """Periodically flush accumulated token counts to the database."""
    try:
        while True:
            await asyncio.sleep(FLUSH_INTERVAL)
            await _drain_buffers_to_db()
    except asyncio.CancelledError:
        # Gracefully handle task cancellation during shutdown: perform one
        # final best-effort flush, never letting a flush error mask the
        # cancellation itself.
        print("[flush_buffer] Task cancelled, flushing remaining buffers...")
        try:
            await _drain_buffers_to_db()
            print("[flush_buffer] Task cancelled, remaining buffers flushed.")
        except Exception as e:
            print(f"[flush_buffer] Error during shutdown flush: {e}")
        raise
|
2025-11-04 17:55:19 +01:00
|
|
|
|
|
2025-12-02 12:18:23 +01:00
|
|
|
|
async def flush_remaining_buffers() -> None:
    """
    Flush any in-memory buffers to the database on shutdown.

    This is designed to be safely invoked during shutdown and should not raise.
    """
    try:
        counts_snapshot = None
        series_snapshot = None
        entry_count = 0
        async with buffer_lock:
            if token_buffer:
                counts_snapshot = {ep: dict(models) for ep, models in token_buffer.items()}
                entry_count += sum(len(v) for v in token_buffer.values())
                token_buffer.clear()
            if time_series_buffer:
                series_snapshot = list(time_series_buffer)
                entry_count += len(time_series_buffer)
                time_series_buffer.clear()
        # Perform DB operations outside the lock
        if counts_snapshot:
            await db.update_batched_counts(counts_snapshot)
        if series_snapshot:
            await db.add_batched_time_series(series_snapshot)
        if entry_count:
            print(f"[shutdown] Flushed {entry_count} in-memory entries to DB on shutdown.")
        else:
            print("[shutdown] No in-memory entries to flush on shutdown.")
    except Exception as e:
        # Do not raise during shutdown – log and continue teardown
        print(f"[shutdown] Error flushing remaining buffers: {e}")
|
|
|
|
|
|
|
2025-09-13 16:57:09 +02:00
|
|
|
|
class fetch:
    """
    Namespace grouping the HTTP helpers that discover models on the
    configured backends (Ollama, llama-server, OpenAI-compatible).

    All helpers share the module-level caches/locks (``_models_cache``,
    ``_loaded_models_cache``, error caches, in-flight maps) and implement:
    - per-endpoint result caching with TTL windows,
    - request coalescing (one HTTP fetch per endpoint at a time),
    - stale-while-revalidate background refreshes.

    NOTE(review): the methods are written without ``self`` and are invoked
    through the class (e.g. ``fetch.available_models(...)``) — they behave
    as plain namespaced functions.
    """

    async def _fetch_available_models_internal(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
        """
        Perform the actual HTTP request to fetch the models an endpoint advertises.

        Called by available_models() after cache / in-flight checks. On success
        the result is stored in ``_models_cache``; on any failure the endpoint
        is recorded in ``_available_error_cache`` and an empty set is returned.
        """
        headers = None
        if api_key is not None:
            headers = {"Authorization": "Bearer " + api_key}

        # Pick the route and the JSON key holding the model list:
        # - llama-server endpoint without an explicit /v1 suffix -> /v1/models
        # - anything with /v1 in the URL (OpenAI-style)          -> /models
        # - plain Ollama                                         -> /api/tags
        # NOTE(review): the second half of the elif condition is redundant
        # (a llama-server endpoint always matches one of the two branches above).
        if endpoint in config.llama_server_endpoints and "/v1" not in endpoint:
            endpoint_url = f"{endpoint}/v1/models"
            key = "data"
        elif "/v1" in endpoint or endpoint in config.llama_server_endpoints:
            endpoint_url = f"{endpoint}/models"
            key = "data"
        else:
            endpoint_url = f"{endpoint}/api/tags"
            key = "models"

        client: aiohttp.ClientSession = app_state["session"]
        try:
            async with client.get(endpoint_url, headers=headers) as resp:
                await _ensure_success(resp)
                data = await resp.json()

            # OpenAI-style entries carry "id", Ollama entries carry "name".
            items = data.get(key, [])
            models = {item.get("id") or item.get("name") for item in items if item.get("id") or item.get("name")}

            # Update cache with lock protection
            async with _models_cache_lock:
                _models_cache[endpoint] = (models, time.time())
            return models
        except Exception as e:
            # Treat any error as if the endpoint offers no models
            message = _format_connection_issue(endpoint_url, e)
            print(f"[fetch.available_models] {message}")
            # Update error cache with lock protection (short negative TTL)
            async with _available_error_cache_lock:
                _available_error_cache[endpoint] = time.time()
            return set()

    async def _refresh_available_models(endpoint: str, api_key: Optional[str] = None) -> None:
        """
        Background task refreshing the available-models cache without blocking
        the caller (stale-while-revalidate pattern).

        Deduplicates: only one background refresh runs per endpoint at a time,
        tracked in ``_bg_refresh_available`` under ``_bg_refresh_lock``.
        """
        async with _bg_refresh_lock:
            if endpoint in _bg_refresh_available and not _bg_refresh_available[endpoint].done():
                return  # A refresh is already running for this endpoint
            task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
            _bg_refresh_available[endpoint] = task

        try:
            await task
        except Exception as e:
            # Silently fail - cache will remain stale but functional
            print(f"[fetch._refresh_available_models] Background refresh failed for {endpoint}: {e}")
        finally:
            # Drop our tracking entry only if it still points at *our* task.
            async with _bg_refresh_lock:
                if _bg_refresh_available.get(endpoint) is task:
                    _bg_refresh_available.pop(endpoint, None)

    async def available_models(endpoint: str, api_key: Optional[str] = None) -> Set[str]:
        """
        Return the set of all model names the endpoint *advertises* (i.e. is
        capable of serving), regardless of whether they are loaded in memory.

        Uses request coalescing to prevent cache stampede: if multiple
        requests arrive while the cache is expired, only one HTTP request
        is made.

        Uses stale-while-revalidate: when the cache is between 300-600s old,
        the stale data is returned immediately while a background refresh
        runs. This prevents model blackouts caused by transient timeouts.

        If the request fails (timeout, 5xx, malformed response), an empty
        set is returned.
        """
        # Check models cache with lock protection
        async with _models_cache_lock:
            if endpoint in _models_cache:
                models, cached_at = _models_cache[endpoint]

                # FRESH: <= 300s old - return immediately
                if _is_fresh(cached_at, 300):
                    return models

                # STALE: 300-600s old - return stale data and refresh in background
                if _is_fresh(cached_at, 600):
                    asyncio.create_task(fetch._refresh_available_models(endpoint, api_key))
                    return models  # Return stale data immediately

                # EXPIRED: > 600s old - too stale, must refresh synchronously
                del _models_cache[endpoint]

        # Check error cache with lock protection
        async with _available_error_cache_lock:
            if endpoint in _available_error_cache:
                if _is_fresh(_available_error_cache[endpoint], 300):
                    # Still within the short error TTL – pretend nothing is available
                    return set()
                # Error expired – remove it
                del _available_error_cache[endpoint]

        # Request coalescing: check if another request is already fetching this endpoint
        async with _inflight_lock:
            if endpoint in _inflight_available_models:
                # Another request is already fetching - wait for it
                task = _inflight_available_models[endpoint]
            else:
                # Create new fetch task
                task = asyncio.create_task(fetch._fetch_available_models_internal(endpoint, api_key))
                _inflight_available_models[endpoint] = task

        try:
            # Wait for the fetch to complete (either ours or another request's)
            result = await task
            return result
        finally:
            # Clean up in-flight tracking (only if we created it)
            async with _inflight_lock:
                if _inflight_available_models.get(endpoint) == task:
                    _inflight_available_models.pop(endpoint, None)

    async def _fetch_loaded_models_internal(endpoint: str) -> Set[str]:
        """
        Perform the actual HTTP request to fetch the *loaded* models.

        Called by loaded_models() after cache / in-flight checks.

        For llama-server endpoints: queries /v1/models and keeps entries that
        ``_is_llama_model_loaded`` reports as loaded.
        For Ollama endpoints: queries /api/ps and returns the model names.

        NOTE(review): unlike the available-models fetch, failures here are
        NOT recorded in ``_loaded_error_cache`` — verify that is intentional.
        """
        client: aiohttp.ClientSession = app_state["session"]

        # Check if this is a llama-server endpoint
        if endpoint in config.llama_server_endpoints:
            # Query /v1/models for llama-server
            try:
                async with client.get(f"{endpoint}/models") as resp:
                    await _ensure_success(resp)
                    data = await resp.json()

                # Filter for loaded models only
                items = data.get("data", [])
                models = {
                    item.get("id")
                    for item in items
                    if item.get("id") and _is_llama_model_loaded(item)
                }

                # Update cache with lock protection
                async with _loaded_models_cache_lock:
                    _loaded_models_cache[endpoint] = (models, time.time())
                return models
            except Exception as e:
                # If anything goes wrong we simply assume the endpoint has no models
                message = _format_connection_issue(f"{endpoint}/models", e)
                print(f"[fetch.loaded_models] {message}")
                return set()
        else:
            # Original Ollama /api/ps logic
            try:
                async with client.get(f"{endpoint}/api/ps") as resp:
                    await _ensure_success(resp)
                    data = await resp.json()
                    # The response format is:
                    # {"models": [{"name": "model1"}, {"name": "model2"}]}
                    models = {m.get("name") for m in data.get("models", []) if m.get("name")}

                # Update cache with lock protection
                async with _loaded_models_cache_lock:
                    _loaded_models_cache[endpoint] = (models, time.time())
                return models
            except Exception as e:
                # If anything goes wrong we simply assume the endpoint has no models
                message = _format_connection_issue(f"{endpoint}/api/ps", e)
                print(f"[fetch.loaded_models] {message}")
                return set()

    async def _refresh_loaded_models(endpoint: str) -> None:
        """
        Background task refreshing the loaded-models cache without blocking
        the caller (stale-while-revalidate pattern).

        Deduplicates: only one background refresh runs per endpoint at a time,
        tracked in ``_bg_refresh_loaded`` under ``_bg_refresh_lock``.
        """
        async with _bg_refresh_lock:
            if endpoint in _bg_refresh_loaded and not _bg_refresh_loaded[endpoint].done():
                return  # A refresh is already running for this endpoint
            task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
            _bg_refresh_loaded[endpoint] = task

        try:
            await task
        except Exception as e:
            # Silently fail - cache will remain stale but functional
            print(f"[fetch._refresh_loaded_models] Background refresh failed for {endpoint}: {e}")
        finally:
            # Drop our tracking entry only if it still points at *our* task.
            async with _bg_refresh_lock:
                if _bg_refresh_loaded.get(endpoint) is task:
                    _bg_refresh_loaded.pop(endpoint, None)

    async def loaded_models(endpoint: str) -> Set[str]:
        """
        Return the set of model names currently loaded on the endpoint.
        If the request fails (e.g. timeout, 5xx), an empty set is returned.

        External OpenAI endpoints have no notion of "loaded" models and
        always yield an empty set.

        Uses request coalescing to prevent cache stampede and
        stale-while-revalidate (fresh <= 10s, stale 10-60s, expired > 60s)
        to serve requests immediately while refreshing in the background.
        """
        if is_ext_openai_endpoint(endpoint):
            return set()

        # Check loaded models cache with lock protection
        async with _loaded_models_cache_lock:
            if endpoint in _loaded_models_cache:
                models, cached_at = _loaded_models_cache[endpoint]

                # FRESH: < 10s old - return immediately
                if _is_fresh(cached_at, 10):
                    return models

                # STALE: 10-60s old - return stale data and refresh in background
                if _is_fresh(cached_at, 60):
                    # Kick off background refresh (fire-and-forget)
                    asyncio.create_task(fetch._refresh_loaded_models(endpoint))
                    return models  # Return stale data immediately

                # EXPIRED: > 60s old - too stale, must refresh synchronously
                del _loaded_models_cache[endpoint]

        # Check error cache with lock protection
        async with _loaded_error_cache_lock:
            if endpoint in _loaded_error_cache:
                if _is_fresh(_loaded_error_cache[endpoint], 300):
                    return set()
                # Error expired - remove it
                del _loaded_error_cache[endpoint]

        # Request coalescing: check if another request is already fetching this endpoint
        async with _inflight_lock:
            if endpoint in _inflight_loaded_models:
                # Another request is already fetching - wait for it
                task = _inflight_loaded_models[endpoint]
            else:
                # Create new fetch task
                task = asyncio.create_task(fetch._fetch_loaded_models_internal(endpoint))
                _inflight_loaded_models[endpoint] = task

        try:
            # Wait for the fetch to complete (either ours or another request's)
            result = await task
            return result
        finally:
            # Clean up in-flight tracking (only if we created it)
            async with _inflight_lock:
                if _inflight_loaded_models.get(endpoint) == task:
                    _inflight_loaded_models.pop(endpoint, None)

    async def endpoint_details(endpoint: str, route: str, detail: str, api_key: Optional[str] = None, skip_error_cache: bool = False) -> List[dict]:
        """
        Query <endpoint>/<route> and return the JSON field named *detail*
        as a list of dicts. On failure an empty list is returned.

        When ``skip_error_cache`` is False (the default), the call is
        short-circuited if the endpoint recently failed (recorded in
        ``_available_error_cache``). Pass ``skip_error_cache=True`` from
        health-check routes that must always probe.
        """
        # Fast-fail if the endpoint is known to be down (unless caller opts out)
        if not skip_error_cache:
            async with _available_error_cache_lock:
                if endpoint in _available_error_cache:
                    if _is_fresh(_available_error_cache[endpoint], 300):
                        return []

        client: aiohttp.ClientSession = app_state["session"]
        headers = None
        if api_key is not None:
            headers = {"Authorization": "Bearer " + api_key}

        request_url = f"{endpoint}{route}"
        try:
            async with client.get(request_url, headers=headers) as resp:
                await _ensure_success(resp)
                data = await resp.json()
                # Reuses the parameter name for the extracted payload.
                detail = data.get(detail, [])
                return detail
        except Exception as e:
            # If anything goes wrong we cannot reply details
            message = _format_connection_issue(request_url, e)
            print(f"[fetch.endpoint_details] {message}")
            # Record failure so subsequent calls skip this endpoint briefly
            async with _available_error_cache_lock:
                _available_error_cache[endpoint] = time.time()
            return []
|
2025-09-01 11:07:07 +02:00
|
|
|
|
|
2025-09-13 18:11:05 +02:00
|
|
|
|
def ep2base(ep):
    """Return an OpenAI-style base URL for *ep*, appending "/v1" unless the URL already contains it."""
    return ep if "/v1" in ep else ep + "/v1"
|
2025-08-30 00:12:56 +02:00
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
def dedupe_on_keys(dicts, key_fields):
    """
    Deduplicate *dicts* on the values of *key_fields*.

    Two dicts are considered duplicates when they agree on every key in
    *key_fields*. Input order is preserved and the first occurrence wins.
    """
    unique = {}
    for entry in dicts:
        fingerprint = tuple(entry.get(field) for field in key_fields)
        if fingerprint not in unique:
            unique[fingerprint] = entry
    return list(unique.values())
|
|
|
|
|
|
|
2025-08-29 13:13:25 +02:00
|
|
|
|
async def increment_usage(endpoint: str, model: str) -> None:
    """Record one more in-flight request for *model* on *endpoint*.

    Mutates the shared ``usage_counts`` table under ``usage_lock`` and
    publishes a fresh snapshot for live subscribers.
    """
    async with usage_lock:
        usage_counts[endpoint][model] += 1
        await publish_snapshot()
|
2025-08-29 13:13:25 +02:00
|
|
|
|
|
|
|
|
|
|
async def decrement_usage(endpoint: str, model: str) -> None:
    """Release one in-flight slot for *model* on *endpoint*.

    Never lets a count go negative, removes exhausted model entries so the
    table only lists active work, and publishes a fresh snapshot.
    """
    async with usage_lock:
        remaining = usage_counts[endpoint].get(model, 0)
        if remaining > 0:
            usage_counts[endpoint][model] = remaining - 1
        # Clean up zero entries
        if usage_counts[endpoint].get(model, 0) == 0:
            usage_counts[endpoint].pop(model, None)
        await publish_snapshot()
|
2025-09-05 12:11:31 +02:00
|
|
|
|
|
2026-03-03 14:57:37 +01:00
|
|
|
|
async def _make_chat_request(model: str, messages: list, tools=None, stream: bool = False, think: bool = False, format=None, options=None, keep_alive: Optional[str] = None) -> ollama.ChatResponse:
    """
    Make a single chat request against the best endpoint for *model*.

    Handles endpoint selection (``choose_endpoint`` atomically reserves a
    usage slot), protocol translation (Ollama vs. OpenAI-compatible),
    token-usage accounting via ``token_queue``, and slot release in the
    ``finally`` block.

    Note: even with ``stream=True`` the stream is fully consumed here and
    only the final chunk (converted to Ollama format for OpenAI backends)
    is returned.

    NOTE(review): ``think`` and ``keep_alive`` are only forwarded on the
    Ollama path; the OpenAI path ignores them — confirm that is intended.
    """
    endpoint, tracking_model = await choose_endpoint(model)  # selects and atomically reserves
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI backends do not know Ollama's ":latest" tag suffix.
        if ":latest" in model:
            model = model.split(":latest")[0]
        if messages:
            # Convert Ollama-style images / tool calls to the OpenAI wire format.
            messages = transform_images_to_data_urls(messages)
            messages = transform_tool_calls_to_openai(messages)
        params = {
            "messages": messages,
            "model": model,
        }
        # Map Ollama option names onto their OpenAI equivalents; entries that
        # resolve to None are dropped below so defaults apply server-side.
        optional_params = {
            "tools": tools,
            "stream": stream,
            "stream_options": {"include_usage": True} if stream else None,
            "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
            "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
            "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
            "seed": options.get("seed") if options and "seed" in options else None,
            "stop": options.get("stop") if options and "stop" in options else None,
            "top_p": options.get("top_p") if options and "top_p" in options else None,
            "temperature": options.get("temperature") if options and "temperature" in options else None,
            "response_format": {"type": "json_schema", "json_schema": format} if format is not None else None
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    try:
        if use_openai:
            start_ts = time.perf_counter()
            response = await oclient.chat.completions.create(**params)
            if stream:
                # For streaming, we need to collect all chunks
                chunks = []
                tc_acc = {}  # accumulate tool-call deltas
                async for chunk in response:
                    chunks.append(chunk)
                    _accumulate_openai_tc_delta(chunk, tc_acc)
                    prompt_tok = 0
                    comp_tok = 0
                    if chunk.usage is not None:
                        prompt_tok = chunk.usage.prompt_tokens or 0
                        comp_tok = chunk.usage.completion_tokens or 0
                    else:
                        # llama-server reports usage via its "timings" field instead.
                        llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                        if llama_usage:
                            prompt_tok, comp_tok = llama_usage
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                # Convert to Ollama format
                if chunks:
                    response = rechunk.openai_chat_completion2ollama(chunks[-1], stream, start_ts)
                    # Inject fully-accumulated tool calls into the final response
                    if tc_acc and response.message:
                        response.message.tool_calls = _build_ollama_tool_calls(tc_acc)
            else:
                prompt_tok = 0
                comp_tok = 0
                if response.usage is not None:
                    prompt_tok = response.usage.prompt_tokens or 0
                    comp_tok = response.usage.completion_tokens or 0
                else:
                    # llama-server reports usage via its "timings" field instead.
                    llama_usage = rechunk.extract_usage_from_llama_timings(response)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                response = rechunk.openai_chat_completion2ollama(response, stream, start_ts)
        else:
            response = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=format, options=options, keep_alive=keep_alive)
            if stream:
                # For streaming, collect all chunks
                chunks = []
                async for chunk in response:
                    chunks.append(chunk)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                if chunks:
                    response = chunks[-1]
            else:
                prompt_tok = response.prompt_eval_count or 0
                comp_tok = response.eval_count or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))

        return response
    finally:
        # Always release the slot reserved by choose_endpoint.
        await decrement_usage(endpoint, tracking_model)
|
2025-12-15 10:35:56 +01:00
|
|
|
|
|
|
|
|
|
|
def get_last_user_content(messages):
    """
    Return the 'content' of the most recent message whose 'role' is 'user',
    or None when the list contains no user message.
    """
    user_contents = (m.get("content") for m in reversed(messages) if m.get("role") == "user")
    return next(user_contents, None)
|
|
|
|
|
|
|
|
|
|
|
|
async def _make_moe_requests(model: str, messages: list, tools=None, think: bool = False, format=None, options=None, keep_alive: str = None) -> ollama.ChatResponse:
    """
    Run a Multiple Opinions Ensemble (MOE) round.

    Drafts three candidate answers in parallel, critiques each one, selects
    the best critique, and generates the final answer from it.

    Raises ValueError when *messages* contains no user query.
    """
    query = get_last_user_content(messages)
    if not query:
        raise ValueError("No user query found in messages")

    # Force high-temperature sampling so the three drafts actually differ.
    if options is None:
        options = {}
    options["temperature"] = 1

    def _spawn(msgs):
        # choose_endpoint is called inside _make_chat_request and atomically
        # reserves a slot, so concurrent tasks see each other's load immediately.
        return asyncio.create_task(
            _make_chat_request(model, msgs, tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
        )

    # Stage 1: three independent drafts.
    responses = await asyncio.gather(*(_spawn(messages) for _ in range(3)))

    # Stage 2: critique each draft.
    moe_reqs = [enhance.moe(query, n, r.message.content) for n, r in enumerate(responses)]
    critiques = await asyncio.gather(*(_spawn([{"role": "user", "content": req}]) for req in moe_reqs))

    # Stage 3: select the winning candidate and generate the final answer.
    m = enhance.moe_select_candidate(query, critiques)
    return await _make_chat_request(model, [{"role": "user", "content": m}], tools, stream=False, think=think, format=format, options=options, keep_alive=keep_alive)
|
2025-12-15 10:35:56 +01:00
|
|
|
|
|
2025-09-13 11:24:28 +02:00
|
|
|
|
def iso8601_ns():
    """Return the current UTC time as an ISO-8601 string with nanosecond precision (trailing 'Z')."""
    total_ns = time.time_ns()
    whole_seconds, frac_ns = divmod(total_ns, 1_000_000_000)
    moment = datetime.fromtimestamp(whole_seconds, tz=timezone.utc)
    # Nanosecond fraction is appended manually — datetime only supports microseconds.
    return f"{moment.strftime('%Y-%m-%dT%H:%M:%S')}.{frac_ns:09d}Z"
|
|
|
|
|
|
|
2025-09-23 17:33:15 +02:00
|
|
|
|
def is_base64(image_string):
    """
    Return True when *image_string* is a str containing canonical base64 data.

    The check decodes and re-encodes the string and compares the round trip,
    so strings with embedded whitespace or non-canonical padding are rejected —
    the strict behavior the image pipeline relies on.

    Fix: always returns a bool. The previous version implicitly returned
    None when the round trip did not match, and relied on an exception to
    reject non-str input.
    """
    if not isinstance(image_string, str):
        return False
    try:
        return base64.b64encode(base64.b64decode(image_string)) == image_string.encode()
    except Exception:
        return False
|
|
|
|
|
|
|
2025-09-24 11:46:38 +02:00
|
|
|
|
def resize_image_if_needed(image_data):
    """
    Decode a base64 (or data-URL) image, normalize it to RGB/L, shrink it so
    neither side exceeds 512 px (keeping the aspect ratio), and return the
    result as base64-encoded PNG. Returns None when the payload cannot be
    processed.
    """
    try:
        # Strip a "data:<mime>;base64," prefix when present
        if image_data.startswith("data:"):
            head, sep, payload = image_data.partition(",")
            if sep:
                image_data = payload

        # Decode the base64 payload into a Pillow image
        raw_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(raw_bytes))
        if img.mode not in ("RGB", "L"):
            img = img.convert("RGB")

        width, height = img.size

        # Downscale only when one side exceeds 512 px, preserving aspect ratio
        if width > 512 or height > 512:
            ratio = width / height
            if ratio > 1:  # landscape
                target = (512, int(512 / ratio))
            else:  # portrait or square
                target = (int(512 * ratio), 512)
            img = img.resize(target, Image.Resampling.LANCZOS)

        # Re-encode as PNG and hand back base64 text
        out = io.BytesIO()
        img.save(out, format="PNG")
        return base64.b64encode(out.getvalue()).decode("utf-8")

    except Exception as e:
        print(f"Error processing image: {e}")
        return None
|
|
|
|
|
|
|
2026-02-10 20:21:46 +01:00
|
|
|
|
def transform_tool_calls_to_openai(message_list):
    """
    Normalize assistant tool_calls to the OpenAI wire format, in place.

    - every tool call gets ``"type": "function"`` and a generated ``id``
    - dict-valued ``arguments`` are serialized to a JSON string
    - tool-role messages missing ``tool_call_id`` are matched by function
      name against the most recent assistant tool call of the same name
    """
    # function name -> id of the latest assistant tool call with that name
    pending_ids = {}
    for msg in message_list:
        role = msg.get("role")
        if role == "assistant" and "tool_calls" in msg:
            for call in msg["tool_calls"]:
                call.setdefault("type", "function")
                if "id" not in call:
                    call["id"] = f"call_{secrets.token_hex(16)}"
                fn = call.get("function", {})
                if isinstance(fn.get("arguments"), dict):
                    fn["arguments"] = orjson.dumps(fn["arguments"]).decode("utf-8")
                # Remember the id for a following tool-role message
                fn_name = fn.get("name")
                if fn_name:
                    pending_ids[fn_name] = call["id"]
        elif role == "tool":
            if "tool_call_id" not in msg:
                # Try to match by name from a preceding assistant tool_call
                fn_name = msg.get("name") or msg.get("tool_name")
                if fn_name and fn_name in pending_ids:
                    msg["tool_call_id"] = pending_ids.pop(fn_name)
    return message_list
|
|
|
|
|
|
|
2025-09-23 17:33:15 +02:00
|
|
|
|
def transform_images_to_data_urls(message_list):
    """
    Rewrite Ollama-style ``images`` fields into OpenAI ``image_url`` content
    parts, in place.

    Each base64 image is resized when needed and embedded as a PNG data URL;
    the original ``images`` key is removed. Messages without a list-valued
    ``images`` field are left untouched (the key is still removed).

    Raises ValueError when an entry is not a valid base64 string.
    """
    for message in message_list:
        if "images" not in message:
            continue
        images = message.pop("images")
        if not isinstance(images, list):
            continue
        parts = []
        for image in images:  # TODO: quality downsize if images are too big to fit into model context window size
            if not is_base64(image):
                raise ValueError(f"Image string is not a valid base64 encoded string.")
            resized = resize_image_if_needed(image)
            if resized:
                parts.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{resized}"
                    }
                })
        message["content"] = parts

    return message_list
|
|
|
|
|
|
|
2026-02-10 20:21:46 +01:00
|
|
|
|
def _accumulate_openai_tc_delta(chunk, accumulator: dict) -> None:
|
|
|
|
|
|
"""Accumulate tool_call deltas from a single OpenAI streaming chunk.
|
|
|
|
|
|
|
|
|
|
|
|
``accumulator`` is a dict mapping tool-call *index* to
|
|
|
|
|
|
``{"id": str, "name": str, "arguments": str}`` where ``arguments``
|
|
|
|
|
|
is the concatenation of all JSON fragments seen so far.
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not chunk.choices:
|
|
|
|
|
|
return
|
|
|
|
|
|
delta = chunk.choices[0].delta
|
|
|
|
|
|
tc_deltas = getattr(delta, "tool_calls", None)
|
|
|
|
|
|
if not tc_deltas:
|
|
|
|
|
|
return
|
|
|
|
|
|
for tc in tc_deltas:
|
|
|
|
|
|
idx = tc.index
|
|
|
|
|
|
if idx not in accumulator:
|
|
|
|
|
|
accumulator[idx] = {
|
|
|
|
|
|
"id": getattr(tc, "id", None) or f"call_{secrets.token_hex(16)}",
|
|
|
|
|
|
"name": tc.function.name if tc.function else None,
|
|
|
|
|
|
"arguments": "",
|
|
|
|
|
|
}
|
|
|
|
|
|
else:
|
|
|
|
|
|
if getattr(tc, "id", None):
|
|
|
|
|
|
accumulator[idx]["id"] = tc.id
|
|
|
|
|
|
if tc.function and tc.function.name:
|
|
|
|
|
|
accumulator[idx]["name"] = tc.function.name
|
|
|
|
|
|
if tc.function and tc.function.arguments:
|
|
|
|
|
|
accumulator[idx]["arguments"] += tc.function.arguments
|
|
|
|
|
|
|
|
|
|
|
|
def _build_ollama_tool_calls(accumulator: dict) -> list | None:
    """Turn accumulated streaming tool-call fragments into Ollama tool_calls.

    Returns ``None`` for an empty accumulator, otherwise a list ordered by
    tool-call index. Argument buffers that fail to parse as JSON collapse
    to ``{}`` rather than raising.
    """
    if not accumulator:
        return None
    calls = []
    # Iterating sorted items() yields entries in ascending index order.
    for _, tc in sorted(accumulator.items()):
        raw_args = tc["arguments"]
        try:
            parsed = orjson.loads(raw_args) if raw_args else {}
        except (orjson.JSONDecodeError, TypeError):
            parsed = {}
        calls.append(
            ollama.Message.ToolCall(
                function=ollama.Message.ToolCall.Function(name=tc["name"], arguments=parsed)
            )
        )
    return calls
|
|
|
|
|
|
|
2026-02-13 13:29:45 +01:00
|
|
|
|
def _convert_openai_logprobs(choice) -> list | None:
    """Map an OpenAI choice's ``logprobs.content`` entries onto Ollama Logprob objects.

    Returns ``None`` when the choice carries no logprobs content at all.
    Each entry's ``top_logprobs`` alternatives become ``TokenLogprob``
    objects (``None`` when the list is empty).
    """
    lp_container = getattr(choice, "logprobs", None)
    entries = getattr(lp_container, "content", None) if lp_container is not None else None
    if not entries:
        return None
    converted = []
    for entry in entries:
        alternatives = [
            TokenLogprob(token=alt.token, logprob=alt.logprob)
            for alt in (entry.top_logprobs or [])
        ]
        converted.append(
            Logprob(
                token=entry.token,
                logprob=entry.logprob,
                top_logprobs=alternatives or None,
            )
        )
    return converted
|
|
|
|
|
|
|
2025-09-13 11:24:28 +02:00
|
|
|
|
class rechunk:
    """Converters from OpenAI API response objects to Ollama response models.

    Methods intentionally take no ``self``/``cls`` — they are invoked directly
    on the class (e.g. ``rechunk.openai_completion2ollama(...)``), acting as
    implicit static methods.
    """

    def openai_chat_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.ChatResponse:
        """Convert an OpenAI chat completion (streaming chunk or full response)
        into an ``ollama.ChatResponse``.

        ``start_ts`` is a ``time.perf_counter()`` timestamp taken at request
        start; it is used to synthesize Ollama's nanosecond duration fields.
        ``load_duration`` and the ``/ 100`` factor in ``prompt_eval_duration``
        are rough placeholders, not measured values.
        """
        now = time.perf_counter()
        # Usage-only terminal chunk (stream_options.include_usage): choices is
        # empty and only token usage is reported — emit the final "done" chunk.
        if chunk.choices == [] and chunk.usage is not None:
            return ollama.ChatResponse(
                model=chunk.model,
                created_at=iso8601_ns(),
                done=True,
                done_reason='stop',
                total_duration=int((now - start_ts) * 1_000_000_000),
                load_duration=100000,
                prompt_eval_count=int(chunk.usage.prompt_tokens),
                # BUGFIX: guard completion_tokens == 0 to avoid a
                # ZeroDivisionError, mirroring the guard already used in the
                # general path below.
                prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage.completion_tokens != 0 else 0,
                eval_count=int(chunk.usage.completion_tokens),
                eval_duration=int((now - start_ts) * 1_000_000_000),
                message=ollama.Message(role="assistant", content=""),
            )
        with_thinking = chunk.choices[0] if chunk.choices[0] else None
        # Reasoning text may arrive under "reasoning_content" or "reasoning"
        # depending on the backend.
        if stream == True:
            thinking = (getattr(with_thinking.delta, "reasoning_content", None) or getattr(with_thinking.delta, "reasoning", None)) if with_thinking else None
            role = chunk.choices[0].delta.role or "assistant"
            content = chunk.choices[0].delta.content or ''
        else:
            thinking = (getattr(with_thinking.message, "reasoning_content", None) or getattr(with_thinking.message, "reasoning", None)) if with_thinking else None
            role = chunk.choices[0].message.role or "assistant"
            content = chunk.choices[0].message.content or ''
        # Convert OpenAI tool_calls to Ollama format.
        # In streaming mode, tool_calls arrive as partial deltas across multiple
        # chunks (name only in first delta, arguments as incremental JSON
        # fragments). Callers must accumulate deltas and inject the final
        # result; skip here.
        ollama_tool_calls = None
        if not stream:
            raw_tool_calls = getattr(with_thinking.message, "tool_calls", None) if with_thinking else None
            if raw_tool_calls:
                ollama_tool_calls = []
                for tc in raw_tool_calls:
                    try:
                        args = orjson.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else (tc.function.arguments or {})
                    except (orjson.JSONDecodeError, TypeError):
                        # Unparseable arguments degrade to an empty dict.
                        args = {}
                    ollama_tool_calls.append(ollama.Message.ToolCall(
                        function=ollama.Message.ToolCall.Function(name=tc.function.name, arguments=args)
                    ))
        # Convert OpenAI logprobs to Ollama format
        ollama_logprobs = _convert_openai_logprobs(with_thinking) if with_thinking else None
        assistant_msg = ollama.Message(
            role=role,
            content=content,
            thinking=thinking,
            images=None,
            tool_name=None,
            tool_calls=ollama_tool_calls)
        # Renamed local (was "rechunk") to stop shadowing the class name.
        converted = ollama.ChatResponse(
            model=chunk.model,
            created_at=iso8601_ns(),
            done=True if chunk.usage is not None else False,
            done_reason=chunk.choices[0].finish_reason,
            total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
            load_duration=100000,
            prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
            prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
            eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
            eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
            message=assistant_msg,
            logprobs=ollama_logprobs)
        return converted

    def openai_completion2ollama(chunk: dict, stream: bool, start_ts: float) -> ollama.GenerateResponse:
        """Convert an OpenAI text completion (chunk or full response) into an
        ``ollama.GenerateResponse``. Duration fields are synthesized from
        ``start_ts`` the same way as in ``openai_chat_completion2ollama``.
        """
        now = time.perf_counter()
        with_thinking = chunk.choices[0] if chunk.choices[0] else None
        thinking = getattr(with_thinking, "reasoning", None) if with_thinking else None
        converted = ollama.GenerateResponse(
            model=chunk.model,
            created_at=iso8601_ns(),
            done=True if chunk.usage is not None else False,
            done_reason=chunk.choices[0].finish_reason,
            total_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
            load_duration=10000,
            prompt_eval_count=int(chunk.usage.prompt_tokens) if chunk.usage is not None else 0,
            prompt_eval_duration=int((now - start_ts) * 1_000_000_000 * (chunk.usage.prompt_tokens / chunk.usage.completion_tokens / 100)) if chunk.usage is not None and chunk.usage.completion_tokens != 0 else 0,
            eval_count=int(chunk.usage.completion_tokens) if chunk.usage is not None else 0,
            eval_duration=int((now - start_ts) * 1_000_000_000) if chunk.usage is not None else 0,
            response=chunk.choices[0].text or '',
            thinking=thinking)
        return converted

    def openai_embeddings2ollama(chunk: dict) -> ollama.EmbeddingsResponse:
        """Wrap the first OpenAI embedding vector as an Ollama EmbeddingsResponse."""
        return ollama.EmbeddingsResponse(embedding=chunk.data[0].embedding)

    def openai_embed2ollama(chunk: dict, model: str) -> ollama.EmbedResponse:
        """Wrap the first OpenAI embedding vector as an Ollama EmbedResponse.

        Timing/usage fields are set to ``None`` — the OpenAI embeddings API
        does not report them.
        """
        return ollama.EmbedResponse(
            model=model,
            created_at=iso8601_ns(),
            done=None,
            done_reason=None,
            total_duration=None,
            load_duration=None,
            prompt_eval_count=None,
            prompt_eval_duration=None,
            eval_count=None,
            eval_duration=None,
            embeddings=[chunk.data[0].embedding])
|
2026-02-14 14:51:44 +01:00
|
|
|
|
|
|
|
|
|
|
def extract_usage_from_llama_timings(obj) -> tuple[int, int] | None:
|
|
|
|
|
|
"""Extract (prompt_tokens, completion_tokens) from llama-server's timings object.
|
|
|
|
|
|
|
|
|
|
|
|
llama-server returns a ``timings`` dict instead of the standard OpenAI
|
|
|
|
|
|
``usage`` field::
|
|
|
|
|
|
|
|
|
|
|
|
"timings": {
|
|
|
|
|
|
"cache_n": 236, // prompt tokens reused from cache
|
|
|
|
|
|
"prompt_n": 1, // prompt tokens processed
|
|
|
|
|
|
"predicted_n": 35 // predicted (completion) tokens
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
prompt_tokens = prompt_n + cache_n
|
|
|
|
|
|
completion_tokens = predicted_n
|
|
|
|
|
|
|
|
|
|
|
|
Returns ``(prompt_tokens, completion_tokens)`` or ``None`` when no
|
|
|
|
|
|
timings are found.
|
|
|
|
|
|
"""
|
|
|
|
|
|
timings = getattr(obj, "timings", None)
|
|
|
|
|
|
if timings is None:
|
|
|
|
|
|
return None
|
|
|
|
|
|
if isinstance(timings, dict):
|
|
|
|
|
|
prompt_n = timings.get("prompt_n", 0) or 0
|
|
|
|
|
|
cache_n = timings.get("cache_n", 0) or 0
|
|
|
|
|
|
predicted_n = timings.get("predicted_n", 0) or 0
|
|
|
|
|
|
return (prompt_n + cache_n, predicted_n)
|
|
|
|
|
|
return None
|
2025-09-22 14:04:19 +02:00
|
|
|
|
|
2025-09-05 12:11:31 +02:00
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# SSE Helpers
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
async def publish_snapshot():
    """Serialize the current usage/token counters and fan the JSON snapshot
    out to every SSE subscriber queue.

    NOTE: assumes ``usage_lock`` OR ``token_usage_lock`` is already held by
    the caller — the counters are read here without acquiring them.
    """
    # Snapshot under the caller's lock; shallow copies decouple the JSON from
    # later counter mutations.
    snapshot = orjson.dumps({
        "usage_counts": dict(usage_counts),
        "token_usage_counts": dict(token_usage_counts)
    }, option=orjson.OPT_SORT_KEYS).decode("utf-8")

    # Distribute the snapshot (counters no longer touched — we hold a copy)
    async with _subscribers_lock:
        for q in _subscribers:
            # If the queue is full, drop the oldest message to avoid back‑pressure.
            if q.full():
                try:
                    # BUGFIX: use get_nowait() — `await q.get()` never raises
                    # QueueEmpty (that exception belongs to get_nowait) and
                    # would block forever if a consumer drained the queue
                    # between the full() check and the get.
                    q.get_nowait()
                except asyncio.QueueEmpty:
                    pass
            await q.put(snapshot)
|
|
|
|
|
|
|
2025-09-12 09:44:56 +02:00
|
|
|
|
async def close_all_sse_queues():
    """Push a ``None`` sentinel into every subscriber queue so the SSE
    generators consuming them know to shut down."""
    # Iterate over a snapshot so concurrent (un)subscribes don't break the loop.
    for queue in list(_subscribers):
        await queue.put(None)
|
|
|
|
|
|
|
2025-09-05 12:11:31 +02:00
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Subscriber helpers
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
async def subscribe() -> asyncio.Queue:
    """Register a new SSE subscriber.

    Returns a fresh queue (bounded to 10 pending snapshots) that will
    receive every published snapshot.
    """
    queue: asyncio.Queue = asyncio.Queue(maxsize=10)
    async with _subscribers_lock:
        _subscribers.add(queue)
    return queue
|
|
|
|
|
|
|
|
|
|
|
|
async def unsubscribe(q: asyncio.Queue):
    """Remove a subscriber queue; a no-op if it was already removed."""
    async with _subscribers_lock:
        # discard() tolerates queues that are no longer registered.
        _subscribers.discard(q)
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
# Convenience wrapper – returns the current snapshot (for the proxy)
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
async def get_usage_counts() -> Dict:
    """Return a shallow copy of the current per-endpoint usage counters."""
    return {**usage_counts}
|
2025-08-29 13:13:25 +02:00
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 5. Endpoint selection logic (respecting the configurable limit)
|
|
|
|
|
|
# -------------------------------------------------------------
|
2026-03-03 14:57:37 +01:00
|
|
|
|
async def choose_endpoint(model: str, reserve: bool = True) -> tuple[str, str]:
    """
    Determine which endpoint to use for the given model while respecting
    the `max_concurrent_connections` per endpoint‑model pair **and**
    ensuring that the chosen endpoint actually *advertises* the model.

    The selection algorithm:

    1️⃣ Query every endpoint for its advertised models (`/api/tags`).
    2️⃣ Build a list of endpoints that contain the requested model.
    3️⃣ For those endpoints, find those that have the model loaded
       (`/api/ps`) *and* still have a free slot.
    4️⃣ If none are both loaded and free, fall back to any endpoint
       from the filtered list that simply has a free slot and randomly
       select one.
    5️⃣ If all are saturated, pick any endpoint from the filtered list
       (the request will queue on that endpoint).
    6️⃣ If no endpoint advertises the model at all, raise an error.

    When ``reserve`` is True a slot is reserved (usage counter incremented)
    before returning; the caller must later call ``decrement_usage``.

    Returns:
        ``(endpoint, tracking_model)`` — the chosen endpoint URL and the
        normalized model name used for usage tracking.

    Raises:
        RuntimeError: if no configured endpoint advertises the model.
    """
    # 1️⃣ Gather advertised‑model sets for all endpoints concurrently.
    # Include both config.endpoints and config.llama_server_endpoints.
    llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    all_endpoints = config.endpoints + llama_eps_extra

    # BUGFIX: build the task list in the SAME order as all_endpoints so the
    # zip() calls below pair each endpoint with its own advertised set. The
    # previous grouping (all non-OpenAI endpoints first, then all OpenAI ones)
    # scrambled the pairing whenever config.endpoints mixed both kinds.
    tag_tasks = [
        fetch.available_models(ep)
        if not is_openai_compatible(ep)
        else fetch.available_models(ep, config.api_keys.get(ep))
        for ep in config.endpoints
    ]
    tag_tasks += [fetch.available_models(ep, config.api_keys.get(ep)) for ep in llama_eps_extra]
    advertised_sets = await asyncio.gather(*tag_tasks)

    # 2️⃣ Filter endpoints that advertise the requested model
    candidate_endpoints = [
        ep for ep, models in zip(all_endpoints, advertised_sets)
        if model in models
    ]

    # 6️⃣ No direct match — try the ollama ":latest" naming quirks both ways.
    if not candidate_endpoints:
        if ":latest" in model: #ollama naming convention not applicable to openai/llama-server
            model_without_latest = model.split(":latest")[0]
            candidate_endpoints = [
                ep for ep, models in zip(all_endpoints, advertised_sets)
                if model_without_latest in models and (is_ext_openai_endpoint(ep) or ep in config.llama_server_endpoints)
            ]
        if not candidate_endpoints:
            # Only add :latest suffix if model doesn't already have a version suffix
            if ":" not in model:
                model = model + ":latest"
            candidate_endpoints = [
                ep for ep, models in zip(all_endpoints, advertised_sets)
                if model in models
            ]
            if not candidate_endpoints:
                raise RuntimeError(
                    f"None of the configured endpoints ({', '.join(all_endpoints)}) "
                    f"advertise the model '{model}'."
                )

    # 3️⃣ Among the candidates, find those that have the model *loaded*
    # (concurrently, but only for the filtered list)
    load_tasks = [fetch.loaded_models(ep) for ep in candidate_endpoints]
    loaded_sets = await asyncio.gather(*load_tasks)

    # Protect all reads/writes of usage_counts with the lock so that selection
    # and reservation are atomic — concurrent callers see each other's pending load.
    async with usage_lock:
        # Helper: current usage for (endpoint, model) using the same normalized key
        # that increment_usage/decrement_usage store — raw model names differ from
        # tracking names for llama-server (HF prefix / quant suffix stripped).
        def tracking_usage(ep: str) -> int:
            return usage_counts.get(ep, {}).get(get_tracking_model(ep, model), 0)

        # 3️⃣ Endpoints that have the model loaded *and* a free slot
        loaded_and_free = [
            ep for ep, models in zip(candidate_endpoints, loaded_sets)
            if model in models and tracking_usage(ep) < config.max_concurrent_connections
        ]

        if loaded_and_free:
            # Sort ascending for load balancing — all endpoints here already have
            # the model loaded, so there is no model-switching cost to optimise for.
            loaded_and_free.sort(key=tracking_usage)
            # When all candidates are equally idle, randomise to avoid always
            # picking the first entry in a stable sort.
            if all(tracking_usage(ep) == 0 for ep in loaded_and_free):
                selected = random.choice(loaded_and_free)
            else:
                selected = loaded_and_free[0]
        else:
            # 4️⃣ Endpoints among the candidates that simply have a free slot
            endpoints_with_free_slot = [
                ep for ep in candidate_endpoints
                if tracking_usage(ep) < config.max_concurrent_connections
            ]

            if endpoints_with_free_slot:
                # Sort by total endpoint load (ascending) to prefer idle endpoints.
                endpoints_with_free_slot.sort(
                    key=lambda ep: sum(usage_counts.get(ep, {}).values())
                )
                if all(tracking_usage(ep) == 0 for ep in endpoints_with_free_slot):
                    selected = random.choice(endpoints_with_free_slot)
                else:
                    selected = endpoints_with_free_slot[0]
            else:
                # 5️⃣ All candidate endpoints are saturated – pick the least-busy one (will queue)
                selected = min(candidate_endpoints, key=tracking_usage)

        tracking_model = get_tracking_model(selected, model)
        if reserve:
            usage_counts[selected][tracking_model] += 1
            await publish_snapshot()
        return selected, tracking_model
|
2025-08-26 18:19:43 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 6. API route – Generate
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/generate")
async def proxy(request: Request):
    """
    Proxy an Ollama /api/generate request to the selected backend and stream
    the response back to the client.

    Backends may be native Ollama or OpenAI-compatible; OpenAI completions
    are re-chunked into Ollama's response format. When the client opts in via
    ``payload["nomyo"]["cache"]``, results are served from / stored in the
    LLM cache.

    Raises:
        HTTPException(400): on invalid JSON or missing 'model'/'prompt'.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        suffix = payload.get("suffix")
        system = payload.get("system")
        template = payload.get("template")
        context = payload.get("context")
        stream = payload.get("stream")
        think = payload.get("think")
        raw = payload.get("raw")
        _format = payload.get("format")
        images = payload.get("images")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")
        _cache_enabled = payload.get("nomyo", {}).get("cache", False)

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except orjson.JSONDecodeError as e:
        error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted."
        raise HTTPException(status_code=400, detail=error_msg) from e

    # Cache lookup — before endpoint selection so no slot is wasted on a hit
    _cache = get_llm_cache()
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_generate(model, prompt, system or "")
        if _cached is not None:
            async def _serve_cached_generate():
                yield _cached
            return StreamingResponse(_serve_cached_generate(), media_type="application/json")

    # 2. Choose an endpoint (reserves a usage slot) and build the client
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # ollama ":latest" naming convention not applicable to openai/llama-server
        if ":latest" in model:
            model = model.split(":latest")
            model = model[0]
        params = {
            "prompt": prompt,
            "model": model,
        }

        # Map Ollama options onto OpenAI completion parameters, skipping unset ones
        optional_params = {
            "stream": stream,
            "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
            "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
            "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
            "seed": options.get("seed") if options and "seed" in options else None,
            "stop": options.get("stop") if options and "stop" in options else None,
            "top_p": options.get("top_p") if options and "top_p" in options else None,
            "temperature": options.get("temperature") if options and "temperature" in options else None,
            "suffix": suffix,
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 4. Async generator that streams data and decrements the counter
    async def stream_generate_response():
        try:
            if use_openai:
                start_ts = time.perf_counter()
                async_gen = await oclient.completions.create(**params)
            else:
                async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive)
            if stream == True:
                content_parts: list[str] = []
                async for chunk in async_gen:
                    if use_openai:
                        chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts)
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    if hasattr(chunk, "model_dump_json"):
                        json_line = chunk.model_dump_json()
                    else:
                        # BUGFIX: orjson.dumps returns bytes — decode so the
                        # json_line.encode() at yield time below does not raise
                        # AttributeError on a bytes object.
                        json_line = orjson.dumps(chunk).decode("utf-8")
                    # Accumulate and store cache on done chunk — before yield so it always runs
                    if _cache is not None and _cache_enabled:
                        if getattr(chunk, "response", None):
                            content_parts.append(chunk.response)
                        if getattr(chunk, "done", False):
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "response": "".join(content_parts),
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_generate(model, prompt, system or "", assembled)
                            except Exception as _ce:
                                print(f"[cache] set_generate (streaming) failed: {_ce}")
                    yield json_line.encode("utf-8") + b"\n"
            else:
                if use_openai:
                    response = rechunk.openai_completion2ollama(async_gen, stream, start_ts)
                    response = response.model_dump_json()
                else:
                    response = async_gen.model_dump_json()
                # NOTE(review): for OpenAI backends async_gen is the raw OpenAI
                # Completion object — presumably it lacks prompt_eval_count /
                # eval_count; confirm token accounting works on that path.
                prompt_tok = async_gen.prompt_eval_count or 0
                comp_tok = async_gen.eval_count or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                json_line = (
                    response
                    if hasattr(async_gen, "model_dump_json")
                    # BUGFIX: decode orjson's bytes here too (same reason as above).
                    else orjson.dumps(async_gen).decode("utf-8")
                )
                cache_bytes = json_line.encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_generate(model, prompt, system or "", cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_generate (non-streaming) failed: {_ce}")

        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 5. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_generate_response(),
        media_type="application/json",
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 7. API route – Chat
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/chat")
async def chat_proxy(request: Request):
    """
    Proxy a chat request to Ollama and stream the endpoint reply.

    Accepts an Ollama-style /api/chat JSON body. Depending on the selected
    endpoint the request is forwarded either to a native Ollama backend or
    translated into an OpenAI-compatible /v1/chat/completions call whose
    reply is re-chunked back into Ollama format. Responses may be served
    from / stored into the LLM cache when the client opts in via
    payload["nomyo"]["cache"].
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        messages = payload.get("messages")
        tools = payload.get("tools")
        stream = payload.get("stream")
        think = payload.get("think")
        _format = payload.get("format")
        keep_alive = payload.get("keep_alive")
        options = payload.get("options")
        logprobs = payload.get("logprobs")
        top_logprobs = payload.get("top_logprobs")
        # Per-request cache opt-in: {"nomyo": {"cache": true}}
        _cache_enabled = payload.get("nomyo", {}).get("cache", False)

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not isinstance(messages, list):
            raise HTTPException(
                status_code=400, detail="Missing or invalid 'messages' field (must be a list)"
            )
        if options is not None and not isinstance(options, dict):
            raise HTTPException(
                status_code=400, detail="`options` must be a JSON object"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Cache lookup — before endpoint selection, always bypassed for MOE
    _is_moe = model.startswith("moe-")
    _cache = get_llm_cache()
    # Normalise model name for cache key: strip ":latest" suffix here so that
    # get_chat and set_chat use the same model string regardless of when the
    # strip happens further down (line ~1793 strips it for OpenAI endpoints).
    _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model
    # Snapshot original messages before any OpenAI-format transformation so that
    # get_chat and set_chat always use the same key regardless of backend type.
    _cache_messages = messages
    if _cache is not None and not _is_moe and _cache_enabled:
        _cached = await _cache.get_chat("ollama_chat", _cache_model, messages)
        if _cached is not None:
            # Serve the cached NDJSON blob as a one-shot stream.
            async def _serve_cached_chat():
                yield _cached
            return StreamingResponse(
                _serve_cached_chat(),
                media_type="application/x-ndjson" if stream else "application/json",
            )

    # 2. Endpoint logic
    if model.startswith("moe-"):
        # "moe-<name>" selects the mixture-of-experts request path below.
        model = model.split("moe-")[1]
        opt = True
    else:
        opt = False
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI backends do not know Ollama's ":latest" tag — strip it.
        if ":latest" in model:
            model = model.split(":latest")
            model = model[0]
        if messages:
            messages = transform_images_to_data_urls(messages)
            messages = transform_tool_calls_to_openai(messages)
        params = {
            "messages": messages,
            "model": model,
        }
        # Only forward parameters the client actually supplied (None = omit).
        optional_params = {
            "tools": tools,
            "stream": stream,
            "stream_options": {"include_usage": True} if stream else None,
            "max_tokens": options.get("num_predict") if options and "num_predict" in options else None,
            "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None,
            "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None,
            "seed": options.get("seed") if options and "seed" in options else None,
            "stop": options.get("stop") if options and "stop" in options else None,
            "top_p": options.get("top_p") if options and "top_p" in options else None,
            "temperature": options.get("temperature") if options and "temperature" in options else None,
            "logprobs": logprobs if logprobs is not None else (options.get("logprobs") if options and "logprobs" in options else None),
            "top_logprobs": top_logprobs if top_logprobs is not None else (options.get("top_logprobs") if options and "top_logprobs" in options else None),
            "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None
        }
        params.update({k: v for k, v in optional_params.items() if v is not None})
        oclient = openai.AsyncOpenAI(base_url=ep2base(endpoint), default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that streams chat data and decrements the counter
    async def stream_chat_response():
        try:
            # The chat method returns a generator of dicts (or GenerateResponse)
            if use_openai:
                start_ts = time.perf_counter()
                async_gen = await oclient.chat.completions.create(**params)
            else:
                if opt == True:
                    # Use the dedicated MOE helper function
                    async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive)
                else:
                    async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs)
            if stream == True:
                tc_acc = {}  # accumulate OpenAI tool-call deltas across chunks
                content_parts: list[str] = []
                async for chunk in async_gen:
                    if use_openai:
                        _accumulate_openai_tc_delta(chunk, tc_acc)
                        chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts)
                        # Inject fully-accumulated tool calls only into the final chunk
                        if chunk.done and tc_acc and chunk.message:
                            chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc)
                    # `chunk` can be a dict or a pydantic model – dump to JSON safely
                    prompt_tok = chunk.prompt_eval_count or 0
                    comp_tok = chunk.eval_count or 0
                    if prompt_tok != 0 or comp_tok != 0:
                        # Queue token usage for asynchronous accounting.
                        await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                    if hasattr(chunk, "model_dump_json"):
                        json_line = chunk.model_dump_json()
                    else:
                        # NOTE(review): orjson.dumps returns bytes, so the
                        # json_line.encode(...) below would raise if this
                        # fallback branch were ever taken — confirm chunks are
                        # always pydantic models at this point.
                        json_line = orjson.dumps(chunk)
                    # Accumulate and store cache on done chunk — before yield so it always runs
                    # Works for both Ollama-native and OpenAI-compatible backends; chunks are
                    # already converted to Ollama format by rechunk before this point.
                    if _cache is not None and not _is_moe and _cache_enabled:
                        if chunk.message and getattr(chunk.message, "content", None):
                            content_parts.append(chunk.message.content)
                        if getattr(chunk, "done", False):
                            # Re-assemble the full assistant message as a single
                            # Ollama-format NDJSON line for the cache; drop
                            # None-valued fields.
                            assembled = orjson.dumps({
                                k: v for k, v in {
                                    "model": getattr(chunk, "model", model),
                                    "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)),
                                    "message": {"role": "assistant", "content": "".join(content_parts)},
                                    "done": True,
                                    "done_reason": getattr(chunk, "done_reason", "stop") or "stop",
                                    "prompt_eval_count": getattr(chunk, "prompt_eval_count", None),
                                    "eval_count": getattr(chunk, "eval_count", None),
                                    "total_duration": getattr(chunk, "total_duration", None),
                                    "eval_duration": getattr(chunk, "eval_duration", None),
                                }.items() if v is not None
                            }) + b"\n"
                            try:
                                await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled)
                            except Exception as _ce:
                                print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}")
                    yield json_line.encode("utf-8") + b"\n"
            else:
                if use_openai:
                    response = rechunk.openai_chat_completion2ollama(async_gen, stream, start_ts)
                    response = response.model_dump_json()
                else:
                    response = async_gen.model_dump_json()
                # NOTE(review): for OpenAI backends async_gen is a
                # ChatCompletion — verify it actually exposes
                # prompt_eval_count / eval_count (or that rechunk adds them),
                # otherwise these attribute reads could raise.
                prompt_tok = async_gen.prompt_eval_count or 0
                comp_tok = async_gen.eval_count or 0
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                json_line = (
                    response
                    if hasattr(async_gen, "model_dump_json")
                    else orjson.dumps(async_gen)
                )
                cache_bytes = json_line.encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends)
                if _cache is not None and not _is_moe and _cache_enabled:
                    try:
                        await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    media_type = "application/x-ndjson" if stream else "application/json"
    return StreamingResponse(
        stream_chat_response(),
        media_type=media_type,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-28 09:40:33 +02:00
|
|
|
|
# 8. API route – Embedding - deprecated
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/embeddings")
async def embedding_proxy(request: Request):
    """
    Proxy an embedding request to Ollama and reply with embeddings.

    Handles the deprecated prompt-based /api/embeddings route. The request
    is forwarded either to a native Ollama backend or, for an
    OpenAI-compatible endpoint, translated to /v1/embeddings and converted
    back to Ollama format by `rechunk`.

    Raises:
        HTTPException(400): on invalid JSON or missing 'model'/'prompt'.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI backends do not know Ollama's ":latest" tag — strip it.
        if ":latest" in model:
            model = model.split(":latest")[0]
        client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that streams embedding data and decrements the counter
    async def stream_embedding_response():
        try:
            if use_openai:
                response = await client.embeddings.create(input=prompt, model=model)
                # Convert the OpenAI embeddings reply to Ollama format.
                response = rechunk.openai_embeddings2ollama(response)
            else:
                response = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive)
            # Serialise to bytes. BUGFIX: orjson.dumps() already returns
            # bytes, so the previous code — which called .encode("utf-8") on
            # the result — raised AttributeError whenever the fallback branch
            # was taken.
            if hasattr(response, "model_dump_json"):
                json_bytes = response.model_dump_json().encode("utf-8")
            else:
                json_bytes = orjson.dumps(response)
            yield json_bytes + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 9. API route – Embed
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/embed")
async def embed_proxy(request: Request):
    """
    Proxy an embed request to Ollama and reply with embeddings.

    Handles the current /api/embed route (input-based). The request is
    forwarded either to a native Ollama backend or, for an
    OpenAI-compatible endpoint, translated to /v1/embeddings and converted
    back to Ollama format by `rechunk`.

    Raises:
        HTTPException(400): on invalid JSON or missing 'model'/'input'.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        _input = payload.get("input")
        truncate = payload.get("truncate")
        options = payload.get("options")
        keep_alive = payload.get("keep_alive")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not _input:
            raise HTTPException(
                status_code=400, detail="Missing required field 'input'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    use_openai = is_openai_compatible(endpoint)
    if use_openai:
        # OpenAI backends do not know Ollama's ":latest" tag — strip it.
        if ":latest" in model:
            model = model.split(":latest")[0]
        client = openai.AsyncOpenAI(base_url=ep2base(endpoint), api_key=config.api_keys.get(endpoint, "no-key"))
    else:
        client = ollama.AsyncClient(host=endpoint)

    # 3. Async generator that streams embed data and decrements the counter
    async def stream_embedding_response():
        try:
            if use_openai:
                response = await client.embeddings.create(input=_input, model=model)
                # Convert the OpenAI embeddings reply to Ollama embed format.
                response = rechunk.openai_embed2ollama(response, model)
            else:
                response = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive)
            # Serialise to bytes. BUGFIX: orjson.dumps() already returns
            # bytes, so the previous code — which called .encode("utf-8") on
            # the result — raised AttributeError whenever the fallback branch
            # was taken.
            if hasattr(response, "model_dump_json"):
                json_bytes = response.model_dump_json().encode("utf-8")
            else:
                json_bytes = orjson.dumps(response)
            yield json_bytes + b"\n"
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_embedding_response(),
        media_type="application/json",
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 10. API route – Create
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with deduplicated status.

    Fans the create out to every native Ollama endpoint and merges the
    per-endpoint status replies into one deduplicated dict.

    Raises:
        HTTPException(400): on invalid JSON, missing 'model', or when
            neither 'from' nor 'files' is supplied.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        quantize = payload.get("quantize")
        from_ = payload.get("from")
        files = payload.get("files")
        adapters = payload.get("adapters")
        template = payload.get("template")
        # Renamed local: avoid shadowing the `license` builtin.
        license_ = payload.get("license")
        system = payload.get("system")
        parameters = payload.get("parameters")
        messages = payload.get("messages")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not from_ and not files:
            raise HTTPException(
                status_code=400, detail="You need to provide either from_ or files parameter!"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Fan the create out to the endpoints.
    status_lists = []

    for endpoint in config.endpoints:
        # BUGFIX/consistency: only native Ollama endpoints support
        # /api/create — skip OpenAI-compatible "/v1" endpoints, exactly as
        # the copy/delete/pull routes already do.
        if "/v1" not in endpoint:
            client = ollama.AsyncClient(host=endpoint)
            create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license_, system=system, parameters=parameters, messages=messages, stream=False)
            status_lists.append(create)

    # 3. Merge the per-endpoint replies and deduplicate while keeping order.
    combined_status = []
    for status_list in status_lists:
        combined_status += status_list

    final_status = list(dict.fromkeys(combined_status))

    return dict(final_status)
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 11. API route – Show
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/show")
async def show_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model show request to Ollama and reply with ShowResponse.

    The model may arrive either as a query parameter or inside the JSON body.
    """
    try:
        raw_body = await request.body()

        # Fall back to the JSON body when no query parameter was supplied.
        if not model:
            data = orjson.loads(raw_body.decode("utf-8"))
            model = data.get("model")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Pick an endpoint serving the model without reserving a usage slot.
    endpoint, _ = await choose_endpoint(model, reserve=False)

    # Forward the show request and hand the ShowResponse straight back.
    return await ollama.AsyncClient(host=endpoint).show(model=model)
|
|
|
|
|
|
|
2025-11-18 19:02:36 +01:00
|
|
|
|
# -------------------------------------------------------------
|
2025-11-28 14:59:29 +01:00
|
|
|
|
@app.get("/api/token_counts")
async def token_counts_proxy():
    """Return overall token usage plus a per-endpoint/per-model breakdown."""
    grand_total = 0
    rows = []
    async for record in db.load_token_counts():
        grand_total += record["total_tokens"]
        rows.append(
            {
                "endpoint": record["endpoint"],
                "model": record["model"],
                "input_tokens": record["input_tokens"],
                "output_tokens": record["output_tokens"],
                "total_tokens": record["total_tokens"],
            }
        )
    return {"total_tokens": grand_total, "breakdown": rows}
|
|
|
|
|
|
|
2025-12-02 12:18:23 +01:00
|
|
|
|
@app.post("/api/aggregate_time_series_days")
async def aggregate_time_series_days_proxy(request: Request):
    """
    Aggregate time_series entries older than days into daily aggregates by endpoint/model/date.

    The body may optionally carry {"days": int, "trim_old": bool}; any parse
    problem deliberately falls back to the defaults (30, False).
    """
    days, trim_old = 30, False
    try:
        raw = await request.body()
        if raw:
            parsed = orjson.loads(raw.decode("utf-8"))
            days = int(parsed.get("days", 30))
            trim_old = bool(parsed.get("trim_old", False))
    except Exception:
        # Best-effort parsing: fall back to defaults on any malformed input.
        days, trim_old = 30, False

    aggregated = await db.aggregate_time_series_older_than(days, trim_old=trim_old)
    return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
|
|
|
|
|
|
|
2025-11-18 19:02:36 +01:00
|
|
|
|
# 12. API route – Stats
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/stats")
async def stats_proxy(request: Request, model: Optional[str] = None):
    """
    Return token usage statistics for a specific model.

    The model may arrive either as a query parameter or inside the JSON body.
    Raises 404 when no token data exists for the model.
    """
    try:
        raw_body = await request.body()

        # Fall back to the JSON body when no query parameter was supplied.
        if not model:
            data = orjson.loads(raw_body.decode("utf-8"))
            model = data.get("model")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Aggregate totals from the database.
    token_data = await db.get_token_counts_for_model(model)
    if not token_data:
        raise HTTPException(
            status_code=404, detail="No token data found for this model"
        )

    # Per-interval usage plus how the load was spread over endpoints.
    time_series = [entry async for entry in db.get_time_series_for_model(model)]
    endpoint_distribution = await db.get_endpoint_distribution_for_model(model)

    return {
        "model": model,
        "input_tokens": token_data["input_tokens"],
        "output_tokens": token_data["output_tokens"],
        "total_tokens": token_data["total_tokens"],
        "time_series": time_series,
        "endpoint_distribution": endpoint_distribution,
    }
|
|
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 12. API route – Copy
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """
    Proxy a model copy request to each Ollama endpoint and reply with Status Code.

    Source/destination may arrive as query parameters or inside the JSON body.
    """
    # 1. Parse and validate request
    try:
        raw_body = await request.body()

        # Prefer query parameters; otherwise read both names from the body.
        if not source and not destination:
            data = orjson.loads(raw_body.decode("utf-8"))
            src = data.get("source")
            dst = data.get("destination")
        else:
            src, dst = source, destination

        if not src:
            raise HTTPException(
                status_code=400, detail="Missing required field 'source'"
            )
        if not dst:
            raise HTTPException(
                status_code=400, detail="Missing required field 'destination'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Fan the copy out to every native Ollama endpoint (skip "/v1" ones).
    outcomes = []
    for endpoint in config.endpoints:
        if "/v1" not in endpoint:
            node = ollama.AsyncClient(host=endpoint)
            result = await node.copy(source=src, destination=dst)
            outcomes.append(result.status)

    # 3. 200 OK when every endpoint succeeded; 404 if any single one failed.
    return Response(status_code=404 if 404 in outcomes else 200)
|
2025-08-26 18:19:43 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 13. API route – Delete
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.delete("/api/delete")
async def delete_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a model delete request to each Ollama endpoint and reply with Status Code.

    The model may arrive either as a query parameter or inside the JSON body.
    """
    # 1. Parse and validate request
    try:
        raw_body = await request.body()

        # Fall back to the JSON body when no query parameter was supplied.
        if not model:
            data = orjson.loads(raw_body.decode("utf-8"))
            model = data.get("model")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Fan the delete out to every native Ollama endpoint (skip "/v1" ones).
    outcomes = []
    for endpoint in config.endpoints:
        if "/v1" not in endpoint:
            node = ollama.AsyncClient(host=endpoint)
            result = await node.delete(model=model)
            outcomes.append(result.status)

    # 3. 200 OK when every endpoint succeeded; 404 if any single one failed.
    return Response(status_code=404 if 404 in outcomes else 200)
|
2025-08-26 18:19:43 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 14. API route – Pull
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/pull")
async def pull_proxy(request: Request, model: Optional[str] = None):
    """
    Proxy a pull request to all Ollama endpoint and report status back.

    The model may arrive either as a query parameter or inside the JSON
    body; 'insecure' is only honoured when supplied in the body.
    """
    # 1. Parse and validate request
    try:
        raw_body = await request.body()

        if not model:
            data = orjson.loads(raw_body.decode("utf-8"))
            model = data.get("model")
            insecure = data.get("insecure")
        else:
            insecure = None

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Fan the pull out to every native Ollama endpoint (skip "/v1" ones).
    responses = []
    for endpoint in config.endpoints:
        if "/v1" not in endpoint:
            node = ollama.AsyncClient(host=endpoint)
            responses.append(await node.pull(model=model, insecure=insecure, stream=False))

    # 3. Merge the per-endpoint replies into one flat sequence of items.
    merged = []
    for reply in responses:
        merged += reply

    # 4. Report back a deduplicated status message (order-preserving).
    final_status = list(dict.fromkeys(merged))

    return dict(final_status)
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 15. API route – Push
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/api/push")
async def push_proxy(request: Request):
    """
    Proxy a push request to Ollama and respond with the deduplicated
    Ollama endpoint replies.

    Expects a JSON body with {"model": ..., "insecure": ...}.  Each Ollama
    endpoint pushes the model (non-streaming); the per-endpoint responses are
    merged and deduplicated into a single dict.

    Raises:
        HTTPException(400): missing 'model' field or invalid JSON body.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        insecure = payload.get("insecure")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Iterate over all endpoints
    status_list = []
    for endpoint in config.endpoints:
        # Skip OpenAI-compatible ("/v1") endpoints: they do not implement the
        # Ollama push API.  This matches the filtering already done by the
        # pull and delete proxies (the original push path was missing it and
        # would target OpenAI endpoints with an ollama client).
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple push request
        push = await client.push(model=model, insecure=insecure, stream=False)
        status_list.append(push)

    # Merge per-endpoint responses (response objects iterate as pairs).
    combined_status = []
    for status in status_list:
        combined_status += status

    # 4. Report a deduplicated status (dict.fromkeys keeps first-seen order)
    final_status = list(dict.fromkeys(combined_status))

    return dict(final_status)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 16. API route – Version
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply with the lowest version of
    all endpoints.

    Only Ollama endpoints (no "/v1" in the URL) are queried.  The minimum
    version is reported so clients only rely on features every endpoint
    supports.

    Raises:
        HTTPException(503): no endpoint returned a usable version string.
    """
    # 1. Query all endpoints for version
    tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep]
    all_versions_raw = await asyncio.gather(*tasks)

    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]

    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v):
        # Compare dotted versions numerically.  Tolerate non-numeric suffixes
        # such as "0.5.7-rc1" by taking the leading digits of each component
        # (the previous plain int() cast raised ValueError on such strings,
        # turning the whole request into a 500).
        parts = []
        for comp in v.split('.'):
            m = re.match(r"\d+", comp)
            parts.append(int(m.group()) if m else 0)
        return tuple(parts)

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 17. API route – tags
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/tags")
async def tags_proxy(request: Request):
    """
    Proxy a tags request to Ollama endpoints and reply with a unique list of
    all models.

    Ollama endpoints are queried via /api/tags, OpenAI-compatible and
    llama-server endpoints via /models.  OpenAI-style entries (which only
    carry 'id') are normalized to also carry Ollama-style 'model' and 'name'
    keys so the merged list has a uniform shape; the result is deduplicated
    on digest/name/id.
    """
    # 1. Query all endpoints for models
    tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep]) for ep in config.endpoints if "/v1" in ep]
    # Also query llama-server endpoints not already covered by config.endpoints
    llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
    tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep)) for ep in llama_eps_for_tags]

    all_models = await asyncio.gather(*tasks)

    models = {'models': []}
    for modellist in all_models:
        for model in modellist:
            if "model" not in model:  # Relabel OpenAI models with Ollama Model.model from Model.id
                model['model'] = model['id'] + ":latest"
            else:
                model['id'] = model['model']
            if "name" not in model:  # Relabel OpenAI models with Ollama Model.name from Model.model to have model,name keys
                model['name'] = model['model']
            else:
                model['id'] = model['model']
        models['models'] += modellist

    # 2. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])},
        status_code=200,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 18. API route – ps
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply with
    a unique list of all running models.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models and keeps only models
    reported as loaded; those entries are converted to an Ollama-like shape.
    """
    # 1. Query Ollama endpoints for running models via /api/ps
    ollama_tasks = [fetch.endpoint_details(ep, "/api/ps", "models") for ep in config.endpoints if "/v1" not in ep]
    # 2. Query llama-server endpoints for loaded models via /v1/models.
    # The previous union with "config.endpoints entries that are also in
    # llama_server_endpoints" was a no-op (that set is by construction a
    # subset of llama_server_endpoints), so use the plain set directly.
    all_llama_endpoints = set(config.llama_server_endpoints)
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep))
        for ep in all_llama_endpoints
    ]

    ollama_loaded = await asyncio.gather(*ollama_tasks) if ollama_tasks else []
    llama_loaded = await asyncio.gather(*llama_tasks) if llama_tasks else []

    models = {'models': []}
    # Add Ollama models (if any)
    if ollama_loaded:
        for modellist in ollama_loaded:
            models['models'] += modellist
    # Add llama-server models (filter for loaded only, if any)
    if llama_loaded:
        for modellist in llama_loaded:
            loaded_models = [item for item in modellist if _is_llama_model_loaded(item)]
            # Convert llama-server format to Ollama-like format for consistency
            for item in loaded_models:
                raw_id = item.get("id", "")
                normalized = _normalize_llama_model_name(raw_id)
                quant = _extract_llama_quant(raw_id)
                models['models'].append({
                    "name": normalized,
                    "id": normalized,
                    "digest": "",
                    "status": item.get("status"),
                    "details": {"quantization_level": quant} if quant else {}
                })

    # 3. Return a JSONResponse with deduplicated currently deployed models
    # Deduplicate on 'name' rather than 'digest': llama-server models always
    # have digest="" so deduping on digest collapses all of them to one entry.
    return JSONResponse(
        content={"models": dedupe_on_keys(models['models'], ['name'])},
        status_code=200,
    )
|
|
|
|
|
|
|
2026-01-27 13:29:54 +01:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 18b. API route – ps details (backwards compatible)
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/ps_details")
async def ps_details_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply with per-endpoint instances.
    This keeps /api/ps backward compatible while providing richer data.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models with status info, then
    queries /props per loaded model for its context length, auto-unloading
    models the server reports as sleeping (they are also omitted from the
    reply).
    """
    # 1. Query Ollama endpoints via /api/ps
    # Keep (endpoint, coroutine) pairs so results can be re-associated with
    # their endpoint after gather().
    ollama_tasks = [(ep, fetch.endpoint_details(ep, "/api/ps", "models")) for ep in config.endpoints if "/v1" not in ep]
    # 2. Query llama-server endpoints via /v1/models
    # Also query endpoints from llama_server_endpoints that may not be in config.endpoints
    all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints)
    llama_tasks = [
        (ep, fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep)))
        for ep in all_llama_endpoints
    ]

    # gather() preserves input order, so zip() below pairs results with the
    # correct endpoints.
    ollama_loaded = await asyncio.gather(*[task for _, task in ollama_tasks]) if ollama_tasks else []
    llama_loaded = await asyncio.gather(*[task for _, task in llama_tasks]) if llama_tasks else []

    models: list[dict] = []

    # Add Ollama models with endpoint info (if any)
    if ollama_loaded:
        for (endpoint, modellist) in zip([ep for ep, _ in ollama_tasks], ollama_loaded):
            for model in modellist:
                if isinstance(model, dict):
                    # Copy before annotating so the fetched payload is not mutated.
                    model_with_endpoint = dict(model)
                    model_with_endpoint["endpoint"] = endpoint
                    models.append(model_with_endpoint)

    # Add llama-server models with endpoint info and full status metadata (if any)
    if llama_loaded:
        # Collect (endpoint, raw_id) pairs to fetch /props in parallel
        props_requests: list[tuple[str, str]] = []
        # Built dicts awaiting their /props result; kept index-aligned with
        # props_requests so the final zip() matches them up.
        llama_models_pending: list[dict] = []

        for (endpoint, modellist) in zip([ep for ep, _ in llama_tasks], llama_loaded):
            # Filter for loaded models only
            loaded_models = [item for item in modellist if _is_llama_model_loaded(item)]
            for item in loaded_models:
                if isinstance(item, dict) and item.get("id"):
                    raw_id = item["id"]
                    normalized = _normalize_llama_model_name(raw_id)
                    quant = _extract_llama_quant(raw_id)
                    # Ollama-like shape; digest is always "" for llama-server models.
                    model_with_endpoint = {
                        "name": normalized,
                        "id": normalized,
                        "original_name": raw_id,
                        "digest": "",
                        "details": {"quantization_level": quant} if quant else {},
                        "endpoint": endpoint,
                        "status": item.get("status"),
                        "created": item.get("created"),
                        "owned_by": item.get("owned_by")
                    }
                    # Include full llama-server status details (args, preset)
                    status_info = item.get("status", {})
                    if isinstance(status_info, dict):
                        model_with_endpoint["llama_status_args"] = status_info.get("args")
                        model_with_endpoint["llama_status_preset"] = status_info.get("preset")
                    llama_models_pending.append(model_with_endpoint)
                    props_requests.append((endpoint, raw_id))

        # Fetch /props for each llama-server model to get context length (n_ctx)
        # and unload sleeping models automatically
        async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool]:
            # Returns (n_ctx, is_sleeping); (None, False) on any failure or
            # non-200 response — callers treat that as "no context info".
            client: aiohttp.ClientSession = app_state["session"]
            # Strip a trailing "/v1" so /props is hit on the server root.
            base_url = endpoint.rstrip("/").removesuffix("/v1")
            props_url = f"{base_url}/props?model={model_id}"
            headers = None
            api_key = config.api_keys.get(endpoint)
            if api_key:
                headers = {"Authorization": f"Bearer {api_key}"}
            try:
                async with client.get(props_url, headers=headers) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        dgs = data.get("default_generation_settings", {})
                        n_ctx = dgs.get("n_ctx")
                        is_sleeping = data.get("is_sleeping", False)

                        # Side effect: proactively unload models the server
                        # reports as sleeping; failures are logged, not raised.
                        if is_sleeping:
                            unload_url = f"{base_url}/models/unload"
                            try:
                                async with client.post(
                                    unload_url,
                                    json={"model": model_id},
                                    headers=headers,
                                ) as unload_resp:
                                    print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}")
                            except Exception as ue:
                                print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}")

                        return n_ctx, is_sleeping
            except Exception as e:
                print(f"[ps_details] Failed to fetch props from {props_url}: {e}")
            return None, False

        props_results = await asyncio.gather(
            *[_fetch_llama_props(ep, mid) for ep, mid in props_requests]
        )

        # Pair each pending model dict with its /props result (index-aligned).
        for model_dict, (n_ctx, is_sleeping) in zip(llama_models_pending, props_results):
            if n_ctx is not None:
                model_dict["context_length"] = n_ctx
            # Sleeping models were just unloaded above, so drop them from the reply.
            if not is_sleeping:
                models.append(model_dict)

    return JSONResponse(content={"models": models}, status_code=200)
|
|
|
|
|
|
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 19. Proxy usage route – for monitoring
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/usage")
async def usage_proxy(request: Request):
    """
    Return a snapshot of the usage counter for each endpoint.

    Useful for debugging / monitoring.
    """
    # Expose both the per-endpoint request counters and the token counters
    # in a single JSON object.
    snapshot = {
        "usage_counts": usage_counts,
        "token_usage_counts": token_usage_counts,
    }
    return snapshot
|
2025-08-30 00:12:56 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 20. Proxy config route – for monitoring and frontend usage
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/config")
async def config_proxy(request: Request):
    """
    Return a simple JSON object that contains the configured
    Ollama endpoints and llama_server_endpoints. The front-end uses this to
    display which endpoints are being proxied.

    Each endpoint is probed concurrently (OpenAI-style via /models with a
    bearer token, Ollama via /api/version) and reported as ok/error with a
    version string or error detail.
    """
    async def check_endpoint(url: str):
        # Probe a single endpoint and classify it as ok/error.
        client: aiohttp.ClientSession = app_state["session"]
        headers = None
        if "/v1" in url:
            # OpenAI-compatible endpoints need a bearer token and expose /models.
            headers = {"Authorization": "Bearer " + config.api_keys.get(url, "no-key")}
            target_url = f"{url}/models"
        else:
            target_url = f"{url}/api/version"

        try:
            async with client.get(target_url, headers=headers) as resp:
                await _ensure_success(resp)
                data = await resp.json()
                if "/v1" in url:
                    # /models carries no version info, report a placeholder.
                    return {"url": url, "status": "ok", "version": "latest"}
                else:
                    return {"url": url, "status": "ok", "version": data.get("version")}
        except Exception as e:
            detail = _format_connection_issue(target_url, e)
            return {"url": url, "status": "error", "detail": detail}

    # Check all endpoints in a single concurrent gather instead of two
    # sequential rounds (Ollama first, then llama-server) — same results,
    # lower wall-clock latency.  gather() preserves order, so slicing splits
    # the combined result back into the two groups.
    ollama_eps = list(config.endpoints)
    llama_eps = list(config.llama_server_endpoints) if config.llama_server_endpoints else []
    all_results = await asyncio.gather(*[check_endpoint(ep) for ep in ollama_eps + llama_eps])
    ollama_results = list(all_results[:len(ollama_eps)])
    llama_results = list(all_results[len(ollama_eps):])

    return {
        "endpoints": ollama_results,
        "llama_server_endpoints": llama_results,
        "require_router_api_key": bool(config.router_api_key),
    }
|
2025-08-30 00:12:56 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 21. API route – OpenAI compatible Embedding
|
2025-08-28 09:40:33 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to Ollama and reply with
    embeddings.

    Validates the JSON body ('model' and 'input' are required), routes the
    request to a chosen endpoint, sanitizes non-finite floats in the returned
    vectors, and always releases the usage counter for the endpoint.
    """
    # 1. Parse and validate request
    try:
        raw_body = await request.body()
        payload = orjson.loads(raw_body.decode("utf-8"))

        model = payload.get("model")
        doc = payload.get("input")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not doc:
            raise HTTPException(
                status_code=400, detail="Missing required field 'input'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    if is_openai_compatible(endpoint):
        api_key = config.api_keys.get(endpoint, "no-key")
    else:
        api_key = "ollama"
    base_url = ep2base(endpoint)

    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=api_key)

    try:
        response = await oclient.embeddings.create(input=doc, model=model)
        result = response.model_dump()
        # Replace NaN/inf components with 0.0 so the payload stays
        # JSON-serializable for the client.
        for entry in result.get("data", []):
            vector = entry.get("embedding")
            if vector:
                entry["embedding"] = [
                    0.0 if isinstance(component, float) and not math.isfinite(component) else component
                    for component in vector
                ]
        return JSONResponse(content=result)
    finally:
        # Always release the usage slot, even when the upstream call fails.
        await decrement_usage(endpoint, tracking_model)
|
2025-08-28 09:40:33 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 22. API route – OpenAI compatible Chat Completions
|
2025-08-27 09:23:59 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
2025-08-28 09:40:33 +02:00
|
|
|
|
async def openai_chat_completions_proxy(request: Request):
|
|
|
|
|
|
"""
|
|
|
|
|
|
Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response.
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 1. Parse and validate request
|
|
|
|
|
|
try:
|
|
|
|
|
|
body_bytes = await request.body()
|
2025-11-10 15:37:46 +01:00
|
|
|
|
payload = orjson.loads(body_bytes.decode("utf-8"))
|
2025-08-28 09:40:33 +02:00
|
|
|
|
|
|
|
|
|
|
model = payload.get("model")
|
|
|
|
|
|
messages = payload.get("messages")
|
|
|
|
|
|
frequency_penalty = payload.get("frequency_penalty")
|
|
|
|
|
|
presence_penalty = payload.get("presence_penalty")
|
|
|
|
|
|
response_format = payload.get("response_format")
|
|
|
|
|
|
seed = payload.get("seed")
|
|
|
|
|
|
stop = payload.get("stop")
|
|
|
|
|
|
stream = payload.get("stream")
|
|
|
|
|
|
stream_options = payload.get("stream_options")
|
|
|
|
|
|
temperature = payload.get("temperature")
|
|
|
|
|
|
top_p = payload.get("top_p")
|
|
|
|
|
|
max_tokens = payload.get("max_tokens")
|
2025-09-05 12:11:31 +02:00
|
|
|
|
max_completion_tokens = payload.get("max_completion_tokens")
|
2025-08-30 00:12:56 +02:00
|
|
|
|
tools = payload.get("tools")
|
2026-02-13 14:43:10 +01:00
|
|
|
|
logprobs = payload.get("logprobs")
|
|
|
|
|
|
top_logprobs = payload.get("top_logprobs")
|
2026-03-10 15:19:37 +01:00
|
|
|
|
_cache_enabled = payload.get("nomyo", {}).get("cache", False)
|
2025-08-30 00:12:56 +02:00
|
|
|
|
|
2026-03-03 16:34:16 +01:00
|
|
|
|
if not model:
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
status_code=400, detail="Missing required field 'model'"
|
|
|
|
|
|
)
|
|
|
|
|
|
if not isinstance(messages, list):
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
status_code=400, detail="Missing required field 'messages' (must be a list)"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-09-17 11:40:48 +02:00
|
|
|
|
if ":latest" in model:
|
|
|
|
|
|
model = model.split(":latest")
|
|
|
|
|
|
model = model[0]
|
|
|
|
|
|
|
2025-08-30 00:12:56 +02:00
|
|
|
|
params = {
|
2026-03-03 16:34:16 +01:00
|
|
|
|
"messages": messages,
|
2025-08-30 00:12:56 +02:00
|
|
|
|
"model": model,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-09-11 13:56:51 +02:00
|
|
|
|
optional_params = {
|
|
|
|
|
|
"tools": tools,
|
|
|
|
|
|
"response_format": response_format,
|
2025-11-21 09:56:42 +01:00
|
|
|
|
"stream_options": stream_options or {"include_usage": True },
|
2025-09-11 13:56:51 +02:00
|
|
|
|
"max_completion_tokens": max_completion_tokens,
|
|
|
|
|
|
"max_tokens": max_tokens,
|
|
|
|
|
|
"temperature": temperature,
|
|
|
|
|
|
"top_p": top_p,
|
|
|
|
|
|
"seed": seed,
|
|
|
|
|
|
"presence_penalty": presence_penalty,
|
|
|
|
|
|
"frequency_penalty": frequency_penalty,
|
|
|
|
|
|
"stop": stop,
|
|
|
|
|
|
"stream": stream,
|
2026-02-13 14:43:10 +01:00
|
|
|
|
"logprobs": logprobs,
|
|
|
|
|
|
"top_logprobs": top_logprobs,
|
2025-09-11 13:56:51 +02:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
params.update({k: v for k, v in optional_params.items() if v is not None})
|
2025-11-10 15:37:46 +01:00
|
|
|
|
except orjson.JSONDecodeError as e:
|
2025-08-28 09:40:33 +02:00
|
|
|
|
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
|
|
|
|
|
|
|
2026-03-10 15:19:37 +01:00
|
|
|
|
# Reject unsupported image formats (SVG) before doing any work
|
|
|
|
|
|
for _msg in messages:
|
|
|
|
|
|
for _item in (_msg.get("content") or []) if isinstance(_msg.get("content"), list) else []:
|
|
|
|
|
|
if _item.get("type") == "image_url":
|
|
|
|
|
|
_url = (_item.get("image_url") or {}).get("url", "")
|
|
|
|
|
|
if _url.startswith("data:image/svg") or _url.lower().endswith(".svg"):
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
status_code=400,
|
|
|
|
|
|
detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.",
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-03-08 09:12:09 +01:00
|
|
|
|
# Cache lookup — before endpoint selection
|
|
|
|
|
|
_cache = get_llm_cache()
|
2026-03-10 15:19:37 +01:00
|
|
|
|
if _cache is not None and _cache_enabled:
|
2026-03-08 09:12:09 +01:00
|
|
|
|
_cached = await _cache.get_chat("openai_chat", model, messages)
|
|
|
|
|
|
if _cached is not None:
|
|
|
|
|
|
if stream:
|
|
|
|
|
|
_sse = openai_nonstream_to_sse(_cached, model)
|
|
|
|
|
|
async def _serve_cached_ochat_stream():
|
|
|
|
|
|
yield _sse
|
|
|
|
|
|
return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream")
|
|
|
|
|
|
else:
|
|
|
|
|
|
async def _serve_cached_ochat_json():
|
|
|
|
|
|
yield _cached
|
|
|
|
|
|
return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json")
|
|
|
|
|
|
|
2025-08-28 09:40:33 +02:00
|
|
|
|
# 2. Endpoint logic
|
2026-03-03 14:57:37 +01:00
|
|
|
|
endpoint, tracking_model = await choose_endpoint(model)
|
2025-08-30 00:12:56 +02:00
|
|
|
|
base_url = ep2base(endpoint)
|
2026-02-10 16:46:51 +01:00
|
|
|
|
oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))
|
2025-08-28 09:40:33 +02:00
|
|
|
|
# 3. Async generator that streams completions data and decrements the counter
|
2026-03-10 15:19:37 +01:00
|
|
|
|
async def _normalize_images_in_messages(msgs: list) -> list:
|
|
|
|
|
|
"""Fetch remote image URLs and convert them to base64 data URLs so
|
|
|
|
|
|
Ollama/llama-server can handle them without making outbound HTTP requests."""
|
|
|
|
|
|
resolved = []
|
|
|
|
|
|
for msg in msgs:
|
|
|
|
|
|
content = msg.get("content")
|
|
|
|
|
|
if not isinstance(content, list):
|
|
|
|
|
|
resolved.append(msg)
|
|
|
|
|
|
continue
|
|
|
|
|
|
new_content = []
|
|
|
|
|
|
for item in content:
|
|
|
|
|
|
if item.get("type") == "image_url":
|
|
|
|
|
|
url = (item.get("image_url") or {}).get("url", "")
|
|
|
|
|
|
if url and not url.startswith("data:"):
|
|
|
|
|
|
try:
|
|
|
|
|
|
http: aiohttp.ClientSession = app_state["session"]
|
|
|
|
|
|
async with http.get(url) as resp:
|
|
|
|
|
|
ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
|
|
|
|
|
|
img_bytes = await resp.read()
|
|
|
|
|
|
b64 = base64.b64encode(img_bytes).decode("utf-8")
|
|
|
|
|
|
new_content.append({
|
|
|
|
|
|
"type": "image_url",
|
|
|
|
|
|
"image_url": {"url": f"data:{ctype};base64,{b64}"}
|
|
|
|
|
|
})
|
|
|
|
|
|
except Exception as _ie:
|
|
|
|
|
|
print(f"[image] Failed to fetch image URL: {_ie}")
|
|
|
|
|
|
new_content.append(item)
|
|
|
|
|
|
else:
|
|
|
|
|
|
new_content.append(item)
|
|
|
|
|
|
else:
|
|
|
|
|
|
new_content.append(item)
|
|
|
|
|
|
resolved.append({**msg, "content": new_content})
|
|
|
|
|
|
return resolved
|
|
|
|
|
|
|
2025-08-28 09:40:33 +02:00
|
|
|
|
async def stream_ochat_response():
    """
    Async generator that proxies one chat-completions call upstream and yields
    the response bytes (SSE frames when streaming, one JSON line otherwise).

    Closure variables (from the enclosing request handler — not visible here,
    assumed by usage): `params` (request params dict), `endpoint` /
    `tracking_model` (routing result), `oclient` (openai.AsyncOpenAI),
    `model`, `stream`, `messages`, `_cache` / `_cache_enabled` (response
    cache), `token_queue` (usage accounting), `rechunk` (llama-server
    timings helper).  Token usage is recorded via `token_queue`, and the
    endpoint's in-flight counter is always released in `finally`.
    """
    try:
        # The chat method returns a generator of dicts (or GenerateResponse)
        try:
            # For non-external endpoints (Ollama, llama-server), resolve remote
            # image URLs to base64 data URLs so the server can handle them locally.
            send_params = params
            if not is_ext_openai_endpoint(endpoint):
                resolved_msgs = await _normalize_images_in_messages(params.get("messages", []))
                send_params = {**params, "messages": resolved_msgs}
            async_gen = await oclient.chat.completions.create(**send_params)
        except openai.BadRequestError as e:
            # If tools are not supported by the model, retry without tools
            if "does not support tools" in str(e):
                print(f"[openai_chat_completions_proxy] Model {model} doesn't support tools, retrying without tools")
                params_without_tools = {k: v for k, v in send_params.items() if k != "tools"}
                async_gen = await oclient.chat.completions.create(**params_without_tools)
            else:
                raise
        if stream == True:
            # Accumulate content for the cache while forwarding chunks verbatim.
            content_parts: list[str] = []
            usage_snapshot: dict = {}
            async for chunk in async_gen:
                # Serialize each chunk back to JSON exactly as received.
                data = (
                    chunk.model_dump_json()
                    if hasattr(chunk, "model_dump_json")
                    else orjson.dumps(chunk)
                )
                if chunk.choices:
                    delta = chunk.choices[0].delta
                    has_content = delta.content is not None
                    has_reasoning = (
                        getattr(delta, "reasoning_content", None) is not None
                        or getattr(delta, "reasoning", None) is not None
                    )
                    has_tool_calls = getattr(delta, "tool_calls", None) is not None
                    # Only forward chunks that actually carry payload; empty
                    # keep-alive deltas are dropped.
                    if has_content or has_reasoning or has_tool_calls:
                        yield f"data: {data}\n\n".encode("utf-8")
                    if has_content and delta.content:
                        content_parts.append(delta.content)
                elif chunk.usage is not None:
                    # Forward the usage-only final chunk (e.g. from llama-server)
                    yield f"data: {data}\n\n".encode("utf-8")
            # After the loop, `chunk` is the final chunk of the stream.
            # NOTE(review): if the upstream stream yields no chunks at all,
            # `chunk` is unbound here and this raises NameError — confirm
            # upstream always emits at least one chunk.
            prompt_tok = 0
            comp_tok = 0
            if chunk.usage is not None:
                prompt_tok = chunk.usage.prompt_tokens or 0
                comp_tok = chunk.usage.completion_tokens or 0
                usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
            else:
                # Fall back to llama-server "timings" payloads for usage info.
                llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                if llama_usage:
                    prompt_tok, comp_tok = llama_usage
            if prompt_tok != 0 or comp_tok != 0:
                await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
            # Cache assembled streaming response — before [DONE] so it always runs
            if _cache is not None and _cache_enabled and content_parts:
                assembled = orjson.dumps({
                    "model": model,
                    "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}],
                    **({"usage": usage_snapshot} if usage_snapshot else {}),
                }) + b"\n"
                try:
                    await _cache.set_chat("openai_chat", model, messages, assembled)
                except Exception as _ce:
                    # Cache failures are best-effort: log and keep streaming.
                    print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}")
            yield b"data: [DONE]\n\n"
        else:
            # Non-streaming: `async_gen` is a single completion object.
            prompt_tok = 0
            comp_tok = 0
            if async_gen.usage is not None:
                prompt_tok = async_gen.usage.prompt_tokens or 0
                comp_tok = async_gen.usage.completion_tokens or 0
            else:
                llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                if llama_usage:
                    prompt_tok, comp_tok = llama_usage
            if prompt_tok != 0 or comp_tok != 0:
                await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
            json_line = (
                async_gen.model_dump_json()
                if hasattr(async_gen, "model_dump_json")
                else orjson.dumps(async_gen)
            )
            cache_bytes = json_line.encode("utf-8") + b"\n"
            yield cache_bytes
            # Cache non-streaming response
            if _cache is not None and _cache_enabled:
                try:
                    await _cache.set_chat("openai_chat", model, messages, cache_bytes)
                except Exception as _ce:
                    print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}")
    finally:
        # Ensure counter is decremented even if an exception occurs
        await decrement_usage(endpoint, tracking_model)

# 4. Return a StreamingResponse backed by the generator
return StreamingResponse(
    stream_ochat_response(),
    media_type="text/event-stream" if stream else "application/json",
)
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 23. API route – OpenAI compatible Completions
|
2025-08-28 09:40:33 +02:00
|
|
|
|
# -------------------------------------------------------------
|
2025-08-27 09:23:59 +02:00
|
|
|
|
@app.post("/v1/completions")
async def openai_completions_proxy(request: Request):
    """
    Proxy an OpenAI API compatible (legacy) completions request to the routed
    endpoint and reply with a streaming response.

    Flow: parse/validate the JSON body, try the response cache (the prompt is
    mapped to a single-turn messages list for cache keying), pick an endpoint
    via `choose_endpoint`, then stream the upstream reply while recording
    token usage and caching the result.
    """
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        prompt = payload.get("prompt")
        frequency_penalty = payload.get("frequency_penalty")
        presence_penalty = payload.get("presence_penalty")
        seed = payload.get("seed")
        stop = payload.get("stop")
        stream = payload.get("stream")
        stream_options = payload.get("stream_options")
        temperature = payload.get("temperature")
        top_p = payload.get("top_p")
        max_tokens = payload.get("max_tokens")
        max_completion_tokens = payload.get("max_completion_tokens")
        suffix = payload.get("suffix")
        # Per-request cache opt-in via the proxy's own "nomyo" extension key.
        _cache_enabled = payload.get("nomyo", {}).get("cache", False)

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not prompt:
            raise HTTPException(
                status_code=400, detail="Missing required field 'prompt'"
            )

        # Strip the Ollama-style ":latest" tag so routing sees the bare name.
        if ":latest" in model:
            model = model.split(":latest")
            model = model[0]

        params = {
            "prompt": prompt,
            "model": model,
        }

        # Optional OpenAI parameters; only non-None values are forwarded.
        # NOTE(review): stream_options defaults to {"include_usage": True} and
        # is therefore always sent, even for non-streaming requests — confirm
        # all upstream servers tolerate stream_options without stream=true.
        optional_params = {
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "seed": seed,
            "stop": stop,
            "stream": stream,
            "stream_options": stream_options or {"include_usage": True },
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "max_completion_tokens": max_completion_tokens,
            "suffix": suffix
        }

        params.update({k: v for k, v in optional_params.items() if v is not None})
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Cache lookup — completions prompt mapped to a single-turn messages list
    _cache = get_llm_cache()
    _compl_messages = [{"role": "user", "content": prompt}]
    if _cache is not None and _cache_enabled:
        _cached = await _cache.get_chat("openai_completions", model, _compl_messages)
        if _cached is not None:
            if stream:
                # Replay the cached non-streaming body as a single SSE event.
                _sse = openai_nonstream_to_sse(_cached, model)
                async def _serve_cached_ocompl_stream():
                    yield _sse
                return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream")
            else:
                async def _serve_cached_ocompl_json():
                    yield _cached
                return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json")

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    base_url = ep2base(endpoint)
    oclient = openai.AsyncOpenAI(base_url=base_url, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key"))

    # 3. Async generator that streams completions data and decrements the counter
    async def stream_ocompletions_response(model=model):
        """Stream (or emit once) the upstream completions reply; always
        releases the endpoint usage counter in `finally`."""
        try:
            # The chat method returns a generator of dicts (or GenerateResponse)
            async_gen = await oclient.completions.create(**params)
            if stream == True:
                text_parts: list[str] = []
                usage_snapshot: dict = {}
                async for chunk in async_gen:
                    data = (
                        chunk.model_dump_json()
                        if hasattr(chunk, "model_dump_json")
                        else orjson.dumps(chunk)
                    )
                    if chunk.choices:
                        choice = chunk.choices[0]
                        has_text = getattr(choice, "text", None) is not None
                        has_reasoning = (
                            getattr(choice, "reasoning_content", None) is not None
                            or getattr(choice, "reasoning", None) is not None
                        )
                        # Forward chunks carrying text/reasoning or a finish
                        # reason; drop empty keep-alives.
                        if has_text or has_reasoning or choice.finish_reason is not None:
                            yield f"data: {data}\n\n".encode("utf-8")
                        if has_text and choice.text:
                            text_parts.append(choice.text)
                    elif chunk.usage is not None:
                        # Forward the usage-only final chunk (e.g. from llama-server)
                        yield f"data: {data}\n\n".encode("utf-8")
                # `chunk` is the last chunk seen in the loop.
                # NOTE(review): unbound (NameError) if the stream was empty —
                # confirm upstream always yields at least one chunk.
                prompt_tok = 0
                comp_tok = 0
                if chunk.usage is not None:
                    prompt_tok = chunk.usage.prompt_tokens or 0
                    comp_tok = chunk.usage.completion_tokens or 0
                    usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok}
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(chunk)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                # Cache assembled streaming response — before [DONE] so it always runs
                if _cache is not None and _cache_enabled and text_parts:
                    # NOTE(review): the assembled payload uses the chat-style
                    # {"message": ...} choice shape rather than the completions
                    # {"text": ...} shape — verify openai_nonstream_to_sse and
                    # cache consumers expect this format.
                    assembled = orjson.dumps({
                        "model": model,
                        "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}],
                        **({"usage": usage_snapshot} if usage_snapshot else {}),
                    }) + b"\n"
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, assembled)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}")
                # Final DONE event
                yield b"data: [DONE]\n\n"
            else:
                # Non-streaming: `async_gen` is a single completion object.
                prompt_tok = 0
                comp_tok = 0
                if async_gen.usage is not None:
                    prompt_tok = async_gen.usage.prompt_tokens or 0
                    comp_tok = async_gen.usage.completion_tokens or 0
                else:
                    llama_usage = rechunk.extract_usage_from_llama_timings(async_gen)
                    if llama_usage:
                        prompt_tok, comp_tok = llama_usage
                if prompt_tok != 0 or comp_tok != 0:
                    await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok))
                json_line = (
                    async_gen.model_dump_json()
                    if hasattr(async_gen, "model_dump_json")
                    else orjson.dumps(async_gen)
                )
                cache_bytes = json_line.encode("utf-8") + b"\n"
                yield cache_bytes
                # Cache non-streaming response
                if _cache is not None and _cache_enabled:
                    try:
                        await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes)
                    except Exception as _ce:
                        print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}")
        finally:
            # Ensure counter is decremented even if an exception occurs
            await decrement_usage(endpoint, tracking_model)

    # 4. Return a StreamingResponse backed by the generator
    return StreamingResponse(
        stream_ocompletions_response(),
        media_type="text/event-stream" if stream else "application/json",
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# 24. OpenAI API compatible models endpoint
|
2025-08-28 09:40:33 +02:00
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 00:12:56 +02:00
|
|
|
|
@app.get("/v1/models")
async def openai_models_proxy(request: Request):
    """
    Proxy an OpenAI API models request to Ollama and llama-server endpoints and reply with a unique list of models.

    For Ollama endpoints: queries /api/tags (all models)
    For external OpenAI endpoints (Groq, OpenAI, etc.): queries /models
    For llama-server endpoints: queries /v1/models

    Returns:
        JSONResponse with {"data": [...]}, deduplicated on the 'name' key.
    """
    def _append_normalized(model_lists, out):
        # Normalize every entry so both 'id' (OpenAI convention) and
        # 'name' (Ollama convention) are populated, then collect it.
        for modellist in model_lists:
            for model in modellist:
                if "id" not in model:
                    # Relabel Ollama models with OpenAI Model.id from Model.name
                    model['id'] = model.get('name', model.get('id', ''))
                else:
                    model['name'] = model['id']
                out.append(model)

    # 1. Query Ollama endpoints for all models via /api/tags
    ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models") for ep in config.endpoints if "/v1" not in ep]
    # 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models
    ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep)) for ep in config.endpoints if is_ext_openai_endpoint(ep)]
    # 3. Query llama-server endpoints for models via /v1/models.
    # The previous union with "endpoints also present in llama_server_endpoints"
    # was a strict subset of llama_server_endpoints, so one set suffices.
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep))
        for ep in set(config.llama_server_endpoints)
    ]

    ollama_models = await asyncio.gather(*ollama_tasks) if ollama_tasks else []
    ext_openai_models = await asyncio.gather(*ext_openai_tasks) if ext_openai_tasks else []
    llama_models = await asyncio.gather(*llama_tasks) if llama_tasks else []

    # Merge all sources through the shared normalizer (empty lists are no-ops).
    models = {'data': []}
    _append_normalized(ollama_models, models['data'])
    _append_normalized(ext_openai_models, models['data'])
    _append_normalized(llama_models, models['data'])

    # 4. Return a JSONResponse with a deduplicated list of unique models for inference
    return JSONResponse(
        content={"data": dedupe_on_keys(models['data'], ['name'])},
        status_code=200,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2026-02-28 09:31:25 +01:00
|
|
|
|
# 25. API route – OpenAI/Jina/Cohere compatible Rerank
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.post("/v1/rerank")
@app.post("/rerank")
async def rerank_proxy(request: Request):
    """
    Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint.

    Compatible with the Jina/Cohere rerank API convention used by llama-server,
    vLLM, and services such as Cohere and Jina AI.

    Ollama does not natively support reranking; requests routed to a plain Ollama
    endpoint will receive a 501 Not Implemented response.

    Request body:
        model (str, required) – reranker model name
        query (str, required) – search query
        documents (list[str], required) – candidate documents to rank
        top_n (int, optional) – limit returned results (default: all)
        return_documents (bool, optional) – include document text in results
        max_tokens_per_doc (int, optional) – truncation limit per document

    Response (Jina/Cohere-compatible):
        {
          "id": "...",
          "model": "...",
          "usage": {"prompt_tokens": N, "total_tokens": N},
          "results": [{"index": 0, "relevance_score": 0.95}, ...]
        }
    """
    # Parse and validate the request body.
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        query = payload.get("query")
        documents = payload.get("documents")

        if not model:
            raise HTTPException(status_code=400, detail="Missing required field 'model'")
        if not query:
            raise HTTPException(status_code=400, detail="Missing required field 'query'")
        if not isinstance(documents, list) or not documents:
            raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)")
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # Determine which endpoint serves this model
    try:
        endpoint, tracking_model = await choose_endpoint(model)
    except RuntimeError as e:
        # choose_endpoint raises RuntimeError when no endpoint serves the model.
        raise HTTPException(status_code=404, detail=str(e))

    # Ollama endpoints have no native rerank support
    if not is_openai_compatible(endpoint):
        # choose_endpoint incremented the usage counter; release it before bailing.
        await decrement_usage(endpoint, tracking_model)
        raise HTTPException(
            status_code=501,
            detail=(
                f"Endpoint '{endpoint}' is a plain Ollama instance which does not support "
                "reranking. Use a llama-server or OpenAI-compatible endpoint with a "
                "dedicated reranker model."
            ),
        )

    # Strip the Ollama-style ":latest" tag before forwarding upstream.
    if ":latest" in model:
        model = model.split(":latest")[0]

    # Build upstream rerank request body – forward only recognised fields
    upstream_payload: dict = {"model": model, "query": query, "documents": documents}
    for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"):
        if optional_key in payload:
            upstream_payload[optional_key] = payload[optional_key]

    # Determine upstream URL:
    # llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints)
    # External OpenAI endpoints expose /rerank under their /v1 base
    if endpoint in config.llama_server_endpoints:
        # llama-server: endpoint may or may not already contain /v1
        if "/v1" in endpoint:
            rerank_url = f"{endpoint}/rerank"
        else:
            rerank_url = f"{endpoint}/v1/rerank"
    else:
        # External OpenAI-compatible: ep2base gives us the /v1 base
        rerank_url = f"{ep2base(endpoint)}/rerank"

    api_key = config.api_keys.get(endpoint, "no-key")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    # Reuse the application-wide aiohttp session (created at startup).
    client: aiohttp.ClientSession = app_state["session"]
    try:
        async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp:
            response_bytes = await resp.read()
            if resp.status >= 400:
                # Propagate upstream errors, with secrets masked out of the body.
                raise HTTPException(
                    status_code=resp.status,
                    detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")),
                )
            data = orjson.loads(response_bytes)

            # Record token usage if the upstream returned a usage object
            usage = data.get("usage") or {}
            prompt_tok = usage.get("prompt_tokens") or 0
            total_tok = usage.get("total_tokens") or 0
            # For reranking there are no completion tokens; we record prompt tokens only
            if prompt_tok or total_tok:
                await token_queue.put((endpoint, tracking_model, prompt_tok, 0))

            return JSONResponse(content=data)
    finally:
        # Always release the endpoint usage counter taken by choose_endpoint.
        await decrement_usage(endpoint, tracking_model)
|
|
|
|
|
|
|
2026-03-08 09:12:09 +01:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 25b. Cache management endpoints
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/cache/stats")
async def cache_stats():
    """Report the LLM response cache's hit/miss counters and configuration."""
    cache = get_llm_cache()
    if cache is None:
        # Caching is disabled / not initialized.
        return {"enabled": False}
    payload = {"enabled": True}
    payload.update(cache.stats())
    return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/cache/invalidate")
async def cache_invalidate():
    """Drop every entry from the LLM response cache and reset its counters."""
    cache = get_llm_cache()
    if cache is None:
        # Nothing to clear when caching is disabled.
        return {"enabled": False, "cleared": False}
    await cache.clear()
    return {"enabled": True, "cleared": True}
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-28 09:31:25 +01:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# 26. Serve the static front‑end
|
2025-08-30 00:12:56 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
# Serve static assets from the directory next to this file (STATIC_DIR),
# independent of the process working directory — keeps this consistent with
# the index() handler, which reads STATIC_DIR / "index.html".
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
|
|
|
|
|
|
|
2025-08-30 12:43:35 +02:00
|
|
|
|
@app.get("/favicon.ico")
async def redirect_favicon():
    """Send the browser's default favicon request to the static asset."""
    favicon_location = "/static/favicon.ico"
    return RedirectResponse(url=favicon_location)
|
|
|
|
|
|
|
2025-08-30 00:12:56 +02:00
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """
    Render the dynamic NOMYO Router dashboard listing the configured endpoints
    and the models details, availability & task status.
    """
    dashboard_file = STATIC_DIR / "index.html"
    try:
        page = dashboard_file.read_text(encoding="utf-8")
        return HTMLResponse(content=page, status_code=200)
    except FileNotFoundError:
        # Dashboard file is missing from the static directory.
        raise HTTPException(status_code=404, detail="Page not found")
    except Exception:
        # Any other read failure is reported as a generic server error.
        raise HTTPException(status_code=500, detail="Internal server error")
|
2025-08-30 00:12:56 +02:00
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-08-30 12:43:35 +02:00
|
|
|
|
# 26b. Health endpoint
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/health")
async def health_proxy(request: Request):
    """
    Health‑check endpoint for monitoring the proxy.

    * Queries each configured endpoint for its `/api/version` response.
    * Returns a JSON object containing:
        - `status`: "ok" if every endpoint replied, otherwise "error".
        - `endpoints`: a mapping of endpoint URL → `{status, version|detail}`.
    * The HTTP status code is 200 when everything is healthy, 503 otherwise.
    """
    # Probe every configured endpoint concurrently; exceptions are captured
    # per-task instead of aborting the gather.
    probes = [fetch.endpoint_details(ep, "/api/version", "version", skip_error_cache=True) for ep in config.endpoints]  # if not is_ext_openai_endpoint(ep)]
    outcomes = await asyncio.gather(*probes, return_exceptions=True)

    health_summary = {}
    overall_ok = True
    for ep, outcome in zip(config.endpoints, outcomes):
        if isinstance(outcome, Exception):
            # Endpoint unreachable or errored — record the failure detail.
            health_summary[ep] = {"status": "error", "detail": str(outcome)}
            overall_ok = False
        else:
            # Endpoint answered — record its reported version.
            health_summary[ep] = {"status": "ok", "version": outcome}

    return JSONResponse(
        content={
            "status": "ok" if overall_ok else "error",
            "endpoints": health_summary,
        },
        status_code=200 if overall_ok else 503,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-09-05 12:11:31 +02:00
|
|
|
|
# 27. SSE route for usage broadcasts
|
|
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.get("/api/usage-stream")
async def usage_stream(request: Request):
    """
    Server‑Sent‑Events endpoint that emits a JSON payload every time the
    global `usage_counts` dictionary changes.
    """
    async def event_generator():
        # Subscribe to the broadcast channel that carries every new snapshot.
        queue = await subscribe()
        try:
            # Stop as soon as the client goes away.
            while not await request.is_disconnected():
                snapshot = await queue.get()
                if snapshot is None:
                    # None is the shutdown sentinel from the broadcaster.
                    break
                # Emit the snapshot as one SSE message.
                yield f"data: {snapshot}\n\n"
        finally:
            # Always detach from the broadcast channel.
            await unsubscribe(queue)

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
|
|
|
|
|
|
|
|
|
|
# -------------------------------------------------------------
|
2025-09-10 10:21:49 +02:00
|
|
|
|
# 28. FastAPI startup/shutdown events
|
2025-08-26 18:19:43 +02:00
|
|
|
|
# -------------------------------------------------------------
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event() -> None:
    """
    Application startup: load configuration, open the token database, restore
    persisted token counts, create the shared aiohttp session, and launch the
    background token/flush workers and the LLM cache.
    """
    global config, db, token_worker_task, flush_task
    # Load YAML config (or use defaults if not present)
    config_path = _config_path_from_env()
    config = Config.from_yaml(config_path)
    if config_path.exists():
        print(
            f"Loaded configuration from {config_path}:\n"
            f" endpoints={config.endpoints},\n"
            f" llama_server_endpoints={config.llama_server_endpoints},\n"
            f" max_concurrent_connections={config.max_concurrent_connections}"
        )
    else:
        print(
            f"No configuration file found at {config_path}. "
            "Falling back to default settings."
        )

    # Initialize database
    db = TokenDatabase(config.db_path)
    await db.init_db()

    # Load existing token counts from database
    async for count_entry in db.load_token_counts():
        endpoint = count_entry['endpoint']
        model = count_entry['model']
        # NOTE(review): input_tokens/output_tokens are read but only
        # total_tokens is used below — confirm the split counts are needed.
        input_tokens = count_entry['input_tokens']
        output_tokens = count_entry['output_tokens']
        total_tokens = count_entry['total_tokens']

        token_usage_counts[endpoint][model] = total_tokens

    # Shared aiohttp session: unlimited total connections, capped per host,
    # with separate connect/read timeouts for slow model servers.
    ssl_context = ssl.create_default_context()
    connector = aiohttp.TCPConnector(limit=0, limit_per_host=512, ssl=ssl_context)
    timeout = aiohttp.ClientTimeout(total=60, connect=15, sock_read=120, sock_connect=15)
    session = aiohttp.ClientSession(connector=connector, timeout=timeout)

    app_state["connector"] = connector
    app_state["session"] = session
    # Background workers: token accounting queue consumer and periodic flush.
    token_worker_task = asyncio.create_task(token_worker())
    flush_task = asyncio.create_task(flush_buffer())
    await init_llm_cache(config)
|
2025-09-10 10:21:49 +02:00
|
|
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """
    Graceful teardown: close SSE subscriber queues, flush pending token
    buffers, close the shared aiohttp session, then cancel background workers.
    """
    await close_all_sse_queues()
    await flush_remaining_buffers()
    await app_state["session"].close()
    # Cancel background workers that were started on startup (if any).
    for worker in (token_worker_task, flush_task):
        if worker is not None:
            worker.cancel()
|