diff --git a/api/__init__.py b/api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/api/management.py b/api/management.py
new file mode 100644
index 0000000..ac1f356
--- /dev/null
+++ b/api/management.py
@@ -0,0 +1,293 @@
+"""Management / observability routes.
+
+Endpoints used by the dashboard and external monitoring. Most are read-only;
+time-series aggregation and cache invalidation modify stored state:
+  * usage counters and token-counts breakdown,
+  * per-model stats and time-series aggregation,
+  * conversation-affinity introspection,
+  * endpoint health summary,
+  * LLM-response cache stats and invalidation,
+  * SSE live-stream of usage updates,
+  * hostname and ``/health`` probe.
+"""
+import asyncio
+import socket
+import time
+from typing import Optional
+
+import orjson
+from fastapi import APIRouter, HTTPException, Request
+from starlette.responses import JSONResponse, StreamingResponse
+
+from cache import get_llm_cache
+from config import get_config
+from db import get_db
+from state import (
+    usage_counts,
+    token_usage_counts,
+    _affinity_map,
+    _affinity_lock,
+)
+from sse import subscribe, unsubscribe
+from backends.normalize import _normalize_llama_model_name
+from backends.probe import _endpoint_health
+
+
+router = APIRouter()
+
+
+@router.get("/api/token_counts")
+async def token_counts_proxy():
+    """Return the summed ``total_tokens`` plus one breakdown row per stored entry."""
+    breakdown = []
+    total = 0
+    async for entry in get_db().load_token_counts():
+        total += entry['total_tokens']
+        breakdown.append({
+            "endpoint": entry["endpoint"],
+            "model": entry["model"],
+            "input_tokens": entry["input_tokens"],
+            "output_tokens": entry["output_tokens"],
+            "total_tokens": entry["total_tokens"],
+        })
+    return {"total_tokens": total, "breakdown": breakdown}
+
+
+@router.post("/api/aggregate_time_series_days")
+async def aggregate_time_series_days_proxy(request: Request):
+    """
+    Aggregate time_series entries older than days into daily aggregates by endpoint/model/date.
+
+    The optional JSON body may carry ``days`` (default 30) and ``trim_old``
+    (default False). An empty, malformed or non-object body is not an error:
+    both values silently fall back to their defaults.
+    """
+    try:
+        body_bytes = await request.body()
+        if not body_bytes:
+            days = 30
+            trim_old = False
+        else:
+            payload = orjson.loads(body_bytes.decode("utf-8"))
+            days = int(payload.get("days", 30))
+            trim_old = bool(payload.get("trim_old", False))
+    except Exception:
+        days = 30
+        trim_old = False
+    aggregated = await get_db().aggregate_time_series_older_than(days, trim_old=trim_old)
+    return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated}
+
+
+@router.post("/api/stats")
+async def stats_proxy(request: Request, model: Optional[str] = None):
+    """
+    Return token usage statistics for a specific model.
+
+    The model is taken from the ``model`` query parameter or, failing that,
+    from the ``model`` field of the JSON body. Responds 400 when it is missing
+    or the body is not valid JSON, and 404 when no token data exists for it.
+    """
+    try:
+        body_bytes = await request.body()
+
+        if not model:
+            # orjson parses bytes directly and reports invalid UTF-8 as a
+            # JSONDecodeError, so a bad body gives a 400 rather than a 500.
+            payload = orjson.loads(body_bytes)
+            # A JSON array or scalar has no .get(); treat it as "no model".
+            model = payload.get("model") if isinstance(payload, dict) else None
+
+        if not model:
+            raise HTTPException(
+                status_code=400, detail="Missing required field 'model'"
+            )
+    except orjson.JSONDecodeError as e:
+        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e
+
+    db = get_db()
+    token_data = await db.get_token_counts_for_model(model)
+
+    if not token_data:
+        raise HTTPException(
+            status_code=404, detail="No token data found for this model"
+        )
+
+    time_series = [
+        entry async for entry in db.get_time_series_for_model(model)
+    ]
+    endpoint_distribution = await db.get_endpoint_distribution_for_model(model)
+
+    return {
+        'model': model,
+        'input_tokens': token_data['input_tokens'],
+        'output_tokens': token_data['output_tokens'],
+        'total_tokens': token_data['total_tokens'],
+        'time_series': time_series,
+        'endpoint_distribution': endpoint_distribution,
+    }
+
+
+@router.get("/api/affinity_stats")
+async def affinity_stats(request: Request):
+    """
+    Aggregate live conversation-affinity pins, one entry per pinned conversation.
+    Each entry exposes only the endpoint, model, and remaining TTL in seconds —
+    no fingerprints or content. When conversation_affinity is disabled the
+    `entries` list is always empty.
+    """
+    config = get_config()
+    if not config.conversation_affinity:
+        return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []}
+
+    now = time.monotonic()
+    entries: list[dict] = []
+    llama_eps = set(config.llama_server_endpoints)
+    async with _affinity_lock:
+        for fp, (ep, mdl, expires_at) in list(_affinity_map.items()):
+            remaining = expires_at - now
+            if remaining <= 0:
+                # Expired pins are evicted as a side effect of listing them.
+                _affinity_map.pop(fp, None)
+                continue
+            # Mirror the normalisation used by /api/ps_details so the dashboard
+            # can join affinity entries to PS rows by (endpoint, model).
+            display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl
+            entries.append({
+                "endpoint": ep,
+                "model": display_model,
+                "remaining": round(remaining, 2),
+            })
+    return {
+        "enabled": True,
+        "ttl": config.conversation_affinity_ttl,
+        "entries": entries,
+    }
+
+
+@router.get("/api/usage")
+async def usage_proxy(request: Request):
+    """
+    Return a snapshot of the usage counter for each endpoint.
+    Useful for debugging / monitoring.
+    """
+    return {"usage_counts": usage_counts,
+            "token_usage_counts": token_usage_counts}
+
+
+@router.get("/api/config")
+async def config_proxy(request: Request):
+    """
+    Return a simple JSON object that contains the configured
+    Ollama endpoints and llama_server_endpoints. The front‑end uses this
+    to display which endpoints are being proxied and their health.
+    Status is "error" when either liveness (/api/version) or routing
+    health (/api/ps) fails — see issue #83.
+    """
+    config = get_config()
+
+    async def check(url: str) -> dict:
+        return {"url": url, **(await _endpoint_health(url, timeout=5))}
+
+    ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints])
+    llama_results = []
+    if config.llama_server_endpoints:
+        llama_results = await asyncio.gather(
+            *[check(ep) for ep in config.llama_server_endpoints]
+        )
+
+    return {
+        "endpoints": ollama_results,
+        "llama_server_endpoints": llama_results,
+        "require_router_api_key": bool(config.router_api_key),
+    }
+
+
+@router.get("/api/cache/stats")
+async def cache_stats():
+    """Return hit/miss counters and configuration for the LLM response cache."""
+    c = get_llm_cache()
+    if c is None:
+        return {"enabled": False}
+    return {"enabled": True, **c.stats()}
+
+
+@router.post("/api/cache/invalidate")
+async def cache_invalidate():
+    """Clear all entries from the LLM response cache and reset counters."""
+    c = get_llm_cache()
+    if c is None:
+        return {"enabled": False, "cleared": False}
+    await c.clear()
+    return {"enabled": True, "cleared": True}
+
+
+@router.get("/health")
+async def health_proxy(request: Request):
+    """
+    Health‑check endpoint for monitoring the proxy.
+
+    * Queries each configured endpoint for both liveness and routing health:
+      Ollama endpoints are probed at `/api/version` AND `/api/ps`,
+      OpenAI-compatible endpoints at `/models`.
+    * Returns a JSON object containing:
+        - `status`: "ok" if every endpoint replied to every probe, otherwise "error".
+        - `endpoints`: a mapping of endpoint URL → `{status, version|detail}`.
+    * The HTTP status code is 200 when everything is healthy, 503 otherwise.
+    """
+    config = get_config()
+    # Run all health checks in parallel.
+    # Ollama endpoints expose /api/version (liveness) and /api/ps (routing
+    # health — required by `choose_endpoint`). OpenAI-compatible endpoints
+    # (vLLM, llama-server, external) expose /models, which serves both
+    # purposes. Probing /api/version alone would miss the case where the
+    # Ollama process is up but /api/ps is failing — see issue #83.
+    all_endpoints = list(config.endpoints)
+    llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints]
+    all_endpoints += llama_eps_extra
+
+    probe_results = await asyncio.gather(
+        *(_endpoint_health(ep) for ep in all_endpoints),
+    )
+
+    health_summary = dict(zip(all_endpoints, probe_results))
+    overall_ok = all(entry.get("status") == "ok" for entry in probe_results)
+
+    response_payload = {
+        "status": "ok" if overall_ok else "error",
+        "endpoints": health_summary,
+    }
+
+    http_status = 200 if overall_ok else 503
+    return JSONResponse(content=response_payload, status_code=http_status)
+
+
+@router.get("/api/hostname")
+async def get_hostname():
+    """Return the hostname of the machine running the router."""
+    return JSONResponse(content={"hostname": socket.gethostname()})
+
+
+@router.get("/api/usage-stream")
+async def usage_stream(request: Request):
+    """
+    Server‑Sent‑Events that emits a JSON payload every time the
+    global `usage_counts` dictionary changes.
+    """
+    async def event_generator():
+        # The queue that receives *every* new snapshot
+        queue = await subscribe()
+        try:
+            while True:
+                # If the client disconnects, cancel the loop
+                if await request.is_disconnected():
+                    break
+                data = await queue.get()
+                if data is None:
+                    break
+                # Send the data as a single SSE message
+                yield f"data: {data}\n\n"
+        finally:
+            # Clean‑up: unsubscribe from the broadcast channel
+            await unsubscribe(queue)
+
+    return StreamingResponse(event_generator(), media_type="text/event-stream")
diff --git a/api/ollama.py b/api/ollama.py
new file mode 100644
index 0000000..e673663
--- /dev/null
+++ b/api/ollama.py
@@ -0,0 +1,1106 @@
+"""Ollama-native API routes (``/api/*``).
+ +These are the ``/api/generate``, ``/api/chat``, ``/api/embed(dings)`` and the +model-management routes (``/api/create``, ``/api/show``, ``/api/copy``, +``/api/delete``, ``/api/pull``, ``/api/push``, ``/api/version``, +``/api/tags``, ``/api/ps``, ``/api/ps_details``) that the Ollama clients +expect. The chat/generate handlers also serve OpenAI-compatible endpoints +when ``is_openai_compatible(endpoint)`` is true — in that case they +translate the request to the OpenAI Chat Completions / Completions API and +``rechunk`` the response back into Ollama wire format. +""" +import asyncio +import re +import time +from typing import Optional + +import aiohttp +import ollama +import orjson +from fastapi import APIRouter, HTTPException, Request +from starlette.responses import JSONResponse, Response, StreamingResponse + +from cache import get_llm_cache +from config import get_config +from context_window import ( + _count_message_tokens, + _trim_messages_for_context, + _calibrated_trim_target, + _endpoint_nctx, + _CTX_TRIM_SMALL_LIMIT, +) +from fingerprint import _conversation_fingerprint +from state import token_queue, default_headers +from backends.health import ( + _is_backend_connection_error, + _is_llama_model_loaded, + _is_llama_model_loaded_or_sleeping, + _mark_backend_unhealthy, +) +from backends.normalize import ( + dedupe_on_keys, + is_openai_compatible, + _normalize_llama_model_name, + _extract_llama_quant, +) +from backends.probe import fetch +from backends.sessions import _make_openai_client, get_session +from requests.chat import _make_moe_requests +from requests.messages import ( + transform_images_to_data_urls, + transform_tool_calls_to_openai, + _strip_assistant_prefill, + _strip_images_from_messages, + _accumulate_openai_tc_delta, + _build_ollama_tool_calls, +) +from requests.rechunk import rechunk +from routing import choose_endpoint, decrement_usage + + +router = APIRouter() + + +@router.post("/api/generate") +async def proxy(request: Request): + """ + 
Proxy a generate request to Ollama and stream the response back to the client. + """ + config = get_config() + try: + body_bytes = await request.body() + payload = orjson.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + prompt = payload.get("prompt") + suffix = payload.get("suffix") + system = payload.get("system") + template = payload.get("template") + context = payload.get("context") + stream = payload.get("stream") + think = payload.get("think") + raw = payload.get("raw") + _format = payload.get("format") + images = payload.get("images") + options = payload.get("options") + keep_alive = payload.get("keep_alive") + _cache_enabled = payload.get("nomyo", {}).get("cache", False) + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not prompt: + raise HTTPException( + status_code=400, detail="Missing required field 'prompt'" + ) + except orjson.JSONDecodeError as e: + error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted." 
+ raise HTTPException(status_code=400, detail=error_msg) from e + + # Cache lookup — before endpoint selection so no slot is wasted on a hit + _cache = get_llm_cache() + if _cache is not None and _cache_enabled: + _cached = await _cache.get_generate(model, prompt, system or "") + if _cached is not None: + async def _serve_cached_generate(): + yield _cached + return StreamingResponse(_serve_cached_generate(), media_type="application/json") + + _affinity_key = _conversation_fingerprint(model, None, prompt) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + params = { + "prompt": prompt, + "model": model, + } + + optional_params = { + "stream": stream, + "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, + "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, + "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, + "seed": options.get("seed") if options and "seed" in options else None, + "stop": options.get("stop") if options and "stop" in options else None, + "top_p": options.get("top_p") if options and "top_p" in options else None, + "temperature": options.get("temperature") if options and "temperature" in options else None, + "suffix": suffix, + } + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = ollama.AsyncClient(host=endpoint) + + # 4. 
Async generator that streams data and decrements the counter + async def stream_generate_response(): + try: + if use_openai: + start_ts = time.perf_counter() + async_gen = await oclient.completions.create(**params) + else: + async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive) + if stream == True: + content_parts: list[str] = [] + async for chunk in async_gen: + if use_openai: + chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) + prompt_tok = chunk.prompt_eval_count or 0 + comp_tok = chunk.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + if hasattr(chunk, "model_dump_json"): + json_line = chunk.model_dump_json() + else: + json_line = orjson.dumps(chunk) + # Accumulate and store cache on done chunk — before yield so it always runs + if _cache is not None and _cache_enabled: + if getattr(chunk, "response", None): + content_parts.append(chunk.response) + if getattr(chunk, "done", False): + assembled = orjson.dumps({ + k: v for k, v in { + "model": getattr(chunk, "model", model), + "response": "".join(content_parts), + "done": True, + "done_reason": getattr(chunk, "done_reason", "stop") or "stop", + "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), + "eval_count": getattr(chunk, "eval_count", None), + "total_duration": getattr(chunk, "total_duration", None), + "eval_duration": getattr(chunk, "eval_duration", None), + }.items() if v is not None + }) + b"\n" + try: + await _cache.set_generate(model, prompt, system or "", assembled) + except Exception as _ce: + print(f"[cache] set_generate (streaming) failed: {_ce}") + yield json_line.encode("utf-8") + b"\n" + else: + if use_openai: + response = rechunk.openai_completion2ollama(async_gen, stream, start_ts) + response = 
response.model_dump_json() + else: + response = async_gen.model_dump_json() + prompt_tok = async_gen.prompt_eval_count or 0 + comp_tok = async_gen.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + json_line = ( + response + if hasattr(async_gen, "model_dump_json") + else orjson.dumps(async_gen) + ) + cache_bytes = json_line.encode("utf-8") + b"\n" + yield cache_bytes + # Cache non-streaming response + if _cache is not None and _cache_enabled: + try: + await _cache.set_generate(model, prompt, system or "", cache_bytes) + except Exception as _ce: + print(f"[cache] set_generate (non-streaming) failed: {_ce}") + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, tracking_model) + + # 5. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_generate_response(), + media_type="application/json", + ) + + +@router.post("/api/chat") +async def chat_proxy(request: Request): + """ + Proxy a chat request to Ollama and stream the endpoint reply. + """ + config = get_config() + # 1. 
Parse and validate request + try: + body_bytes = await request.body() + payload = orjson.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + messages = payload.get("messages") + tools = payload.get("tools") + stream = payload.get("stream") + think = payload.get("think") + _format = payload.get("format") + keep_alive = payload.get("keep_alive") + options = payload.get("options") + logprobs = payload.get("logprobs") + top_logprobs = payload.get("top_logprobs") + _cache_enabled = payload.get("nomyo", {}).get("cache", False) + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not isinstance(messages, list): + raise HTTPException( + status_code=400, detail="Missing or invalid 'messages' field (must be a list)" + ) + if options is not None and not isinstance(options, dict): + raise HTTPException( + status_code=400, detail="`options` must be a JSON object" + ) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # Cache lookup — before endpoint selection, always bypassed for MOE + _is_moe = model.startswith("moe-") + _cache = get_llm_cache() + # Normalise model name for cache key: strip ":latest" suffix here so that + # get_chat and set_chat use the same model string regardless of when the + # strip happens further down (line ~1793 strips it for OpenAI endpoints). + _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model + # Snapshot original messages before any OpenAI-format transformation so that + # get_chat and set_chat always use the same key regardless of backend type. 
+ _cache_messages = messages + if _cache is not None and not _is_moe and _cache_enabled: + _cached = await _cache.get_chat("ollama_chat", _cache_model, messages) + if _cached is not None: + async def _serve_cached_chat(): + yield _cached + return StreamingResponse( + _serve_cached_chat(), + media_type="application/x-ndjson" if stream else "application/json", + ) + + # 2. Endpoint logic + if model.startswith("moe-"): + model = model.split("moe-")[1] + opt = True + else: + opt = False + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + if messages: + if any("images" in m for m in messages): + messages = await asyncio.to_thread(transform_images_to_data_urls, messages) + messages = transform_tool_calls_to_openai(messages) + messages = _strip_assistant_prefill(messages) + params = { + "messages": messages, + "model": model, + } + optional_params = { + "tools": tools, + "stream": stream, + "stream_options": {"include_usage": True} if stream else None, + "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, + "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, + "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, + "seed": options.get("seed") if options and "seed" in options else None, + "stop": options.get("stop") if options and "stop" in options else None, + "top_p": options.get("top_p") if options and "top_p" in options else None, + "temperature": options.get("temperature") if options and "temperature" in options else None, + "logprobs": logprobs if logprobs is not None else (options.get("logprobs") if options and "logprobs" in options else None), + "top_logprobs": top_logprobs if 
top_logprobs is not None else (options.get("top_logprobs") if options and "top_logprobs" in options else None), + "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None + } + params.update({k: v for k, v in optional_params.items() if v is not None}) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = ollama.AsyncClient(host=endpoint) + # For OpenAI endpoints: make the API call in handler scope + # (try/except inside async generators is unreliable with Starlette's streaming) + start_ts = None + async_gen = None + if use_openai: + start_ts = time.perf_counter() + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) + _pre_est = _count_message_tokens(params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + params = {**params, "messages": _pre_trimmed} + try: + async_gen = await oclient.chat.completions.create(**params) + except Exception as e: + _e_str = str(e) + print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}") + if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str: + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) 
+ if not n_ctx_limit: + _m = re.search(r"'n_ctx':\s*(\d+)", _e_str) + if _m: + n_ctx_limit = int(_m.group(1)) + _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) + if _m: + actual_tokens = int(_m.group(1)) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit + msgs_to_trim = params.get("messages", []) + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying") + try: + async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed}) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools") + params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_backend_connection_error(e): + print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) + await _mark_backend_unhealthy(endpoint, model, _e_str) + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))} + async_gen = await 
oclient.chat.completions.create(**params) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + + # 3. Async generator that streams chat data and decrements the counter + async def stream_chat_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + if use_openai: + _async_gen = async_gen # established in handler scope above + else: + if opt == True: + # Use the dedicated MOE helper function + _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) + else: + _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) + if stream == True: + tc_acc = {} # accumulate OpenAI tool-call deltas across chunks + content_parts: list[str] = [] + async for chunk in _async_gen: + if use_openai: + _accumulate_openai_tc_delta(chunk, tc_acc) + chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) + # Inject fully-accumulated tool calls only into the final chunk + if chunk.done and tc_acc and chunk.message: + chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc) + # `chunk` can be a dict or a pydantic model – dump to JSON safely + prompt_tok = chunk.prompt_eval_count or 0 + comp_tok = chunk.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + if hasattr(chunk, "model_dump_json"): + json_line = chunk.model_dump_json() + else: + json_line = orjson.dumps(chunk) + # Accumulate and store cache on done chunk — before yield so it always runs + # Works for both Ollama-native and OpenAI-compatible backends; chunks are + # already converted to Ollama format by rechunk before this point. 
+ if getattr(chunk, "done", False): + # Detect context exhaustion mid-generation for small-ctx models + _dr = getattr(chunk, "done_reason", None) + # Only cache when no max_tokens limit was set — otherwise + # finish_reason=length might just mean max_tokens was hit, + # not that the context window was exhausted. + _req_max_tok = ( + params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict") + if use_openai else + (options.get("num_predict") if options else None) + ) + if _dr == "length" and not _req_max_tok: + _pt = getattr(chunk, "prompt_eval_count", 0) or 0 + _ct = getattr(chunk, "eval_count", 0) or 0 + _inferred_nctx = _pt + _ct + if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = _inferred_nctx + print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) + if _cache is not None and not _is_moe and _cache_enabled: + if chunk.message and getattr(chunk.message, "content", None): + content_parts.append(chunk.message.content) + if getattr(chunk, "done", False): + assembled = orjson.dumps({ + k: v for k, v in { + "model": getattr(chunk, "model", model), + "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)), + "message": {"role": "assistant", "content": "".join(content_parts)}, + "done": True, + "done_reason": getattr(chunk, "done_reason", "stop") or "stop", + "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), + "eval_count": getattr(chunk, "eval_count", None), + "total_duration": getattr(chunk, "total_duration", None), + "eval_duration": getattr(chunk, "eval_duration", None), + }.items() if v is not None + }) + b"\n" + try: + await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled) + except Exception as _ce: + print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}") + yield json_line.encode("utf-8") + b"\n" + else: + if use_openai: + response = 
rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts) + response = response.model_dump_json() + else: + response = _async_gen.model_dump_json() + prompt_tok = _async_gen.prompt_eval_count or 0 + comp_tok = _async_gen.eval_count or 0 + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + json_line = ( + response + if hasattr(_async_gen, "model_dump_json") + else orjson.dumps(_async_gen) + ) + cache_bytes = json_line.encode("utf-8") + b"\n" + yield cache_bytes + # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends) + if _cache is not None and not _is_moe and _cache_enabled: + try: + await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes) + except Exception as _ce: + print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}") + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, tracking_model) + + # 4. Return a StreamingResponse backed by the generator + media_type = "application/x-ndjson" if stream else "application/json" + return StreamingResponse( + stream_chat_response(), + media_type=media_type, + ) + + +@router.post("/api/embeddings") +async def embedding_proxy(request: Request): + """ + Proxy an embedding request to Ollama and reply with embeddings. + + """ + config = get_config() + # 1. Parse and validate request + try: + body_bytes = await request.body() + payload = orjson.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + prompt = payload.get("prompt") + options = payload.get("options") + keep_alive = payload.get("keep_alive") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not prompt: + raise HTTPException( + status_code=400, detail="Missing required field 'prompt'" + ) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. 
Endpoint logic + endpoint, tracking_model = await choose_endpoint(model) + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = ollama.AsyncClient(host=endpoint) + # 3. Async generator that streams embedding data and decrements the counter + async def stream_embedding_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + if use_openai: + async_gen = await client.embeddings.create(input=prompt, model=model) + async_gen = rechunk.openai_embeddings2ollama(async_gen) + else: + async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive) + if hasattr(async_gen, "model_dump_json"): + json_line = async_gen.model_dump_json() + else: + json_line = orjson.dumps(async_gen) + yield json_line.encode("utf-8") + b"\n" + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, tracking_model) + + # 5. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_embedding_response(), + media_type="application/json", + ) + + +@router.post("/api/embed") +async def embed_proxy(request: Request): + """ + Proxy an embed request to Ollama and reply with embeddings. + + """ + config = get_config() + # 1. 
Parse and validate request + try: + body_bytes = await request.body() + payload = orjson.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + _input = payload.get("input") + truncate = payload.get("truncate") + options = payload.get("options") + keep_alive = payload.get("keep_alive") + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not _input: + raise HTTPException( + status_code=400, detail="Missing required field 'input'" + ) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # 2. Endpoint logic + endpoint, tracking_model = await choose_endpoint(model) + use_openai = is_openai_compatible(endpoint) + if use_openai: + if ":latest" in model: + model = model.split(":latest") + model = model[0] + client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) + else: + client = ollama.AsyncClient(host=endpoint) + # 3. Async generator that streams embed data and decrements the counter + async def stream_embedding_response(): + try: + # The chat method returns a generator of dicts (or GenerateResponse) + if use_openai: + async_gen = await client.embeddings.create(input=_input, model=model) + async_gen = rechunk.openai_embed2ollama(async_gen, model) + else: + async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive) + if hasattr(async_gen, "model_dump_json"): + json_line = async_gen.model_dump_json() + else: + json_line = orjson.dumps(async_gen) + yield json_line.encode("utf-8") + b"\n" + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, tracking_model) + + # 4. 
@router.post("/api/create")
async def create_proxy(request: Request):
    """
    Proxy a create request to all Ollama endpoints and reply with deduplicated status.

    The JSON body must contain ``model`` and at least one of ``from`` or
    ``files``; the remaining fields (``quantize``, ``adapters``, ``template``,
    ``license``, ``system``, ``parameters``, ``messages``) are passed through
    to ``ollama.AsyncClient.create`` unchanged.

    Raises:
        HTTPException: 400 when the body is not valid JSON, ``model`` is
            missing, or neither ``from`` nor ``files`` is given.
    """
    config = get_config()
    try:
        body_bytes = await request.body()
        payload = orjson.loads(body_bytes.decode("utf-8"))

        model = payload.get("model")
        quantize = payload.get("quantize")
        from_ = payload.get("from")
        files = payload.get("files")
        adapters = payload.get("adapters")
        template = payload.get("template")
        # Trailing underscore: do not shadow the ``license`` builtin.
        license_ = payload.get("license")
        system = payload.get("system")
        parameters = payload.get("parameters")
        messages = payload.get("messages")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not from_ and not files:
            # The request field is called "from"; "from_" is only the local name.
            raise HTTPException(
                status_code=400, detail="You need to provide either 'from' or 'files' parameter!"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    status_lists = []

    for endpoint in config.endpoints:
        # OpenAI-compatible ("/v1") endpoints have no Ollama /api/create
        # route; skip them like the copy/delete/pull handlers do.
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        created = await client.create(
            model=model,
            quantize=quantize,
            from_=from_,
            files=files,
            adapters=adapters,
            template=template,
            license=license_,
            system=system,
            parameters=parameters,
            messages=messages,
            stream=False,
        )
        status_lists.append(created)

    # Each response is iterated into its (field, value) pairs; identical
    # pairs reported by several endpoints collapse into one.
    combined_status = []
    for status_list in status_lists:
        combined_status += status_list

    final_status = list(dict.fromkeys(combined_status))

    return dict(final_status)
@router.post("/api/copy")
async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None):
    """
    Proxy a model copy request to each Ollama endpoint and reply with Status Code.

    ``source`` and ``destination`` are taken from the query string when either
    is present; otherwise both are read from the JSON body. Endpoints whose
    URL contains ``/v1`` are skipped.

    Returns:
        An empty response with status 404 if any endpoint reported 404 for
        the copy, otherwise 200.

    Raises:
        HTTPException: 400 on invalid JSON or a missing ``source`` /
            ``destination``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        body_bytes = await request.body()

        if not source and not destination:
            payload = orjson.loads(body_bytes.decode("utf-8"))
            src = payload.get("source")
            dst = payload.get("destination")
        else:
            src = source
            dst = destination

        if not src:
            raise HTTPException(
                status_code=400, detail="Missing required field 'source'"
            )
        if not dst:
            raise HTTPException(
                status_code=400, detail="Missing required field 'destination'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Iterate over all endpoints to copy the model on each endpoint
    status_list = []

    for endpoint in config.endpoints:
        if "/v1" in endpoint:
            continue
        client = ollama.AsyncClient(host=endpoint)
        # 3. Proxy a simple copy request
        try:
            result = await client.copy(source=src, destination=dst)
        except ollama.ResponseError as exc:
            # NOTE(review): ollama-python appears to raise ResponseError for
            # non-2xx replies instead of returning a 404 status, so without
            # this branch the documented 404 would surface as a 500.
            # Confirm against the pinned client version.
            if exc.status_code != 404:
                raise
            status_list.append(404)
        else:
            status_list.append(result.status)

    # 4. Return with 200 OK if all went well, 404 if a single endpoint failed
    return Response(status_code=404 if 404 in status_list else 200)
@router.get("/api/version")
async def version_proxy(request: Request):
    """
    Proxy a version request to Ollama and reply lowest version of all endpoints.

    Endpoints whose URL contains ``/v1`` are not queried.

    Raises:
        HTTPException: 503 when no endpoint returned a non-empty version string.
    """
    config = get_config()
    # 1. Query all endpoints for version
    tasks = [
        fetch.endpoint_details(ep, "/api/version", "version")
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    all_versions_raw = await asyncio.gather(*tasks)

    # Filter out non-string values (e.g., empty lists from failed/timeout responses)
    all_versions = [v for v in all_versions_raw if isinstance(v, str) and v]

    if not all_versions:
        raise HTTPException(status_code=503, detail="No valid version response from any endpoint")

    def version_key(v):
        # Compare on the leading digits of each dot-separated component so
        # that suffixed versions such as "0.5.0-rc1" or "v0.6.1" sort instead
        # of raising ValueError (a plain int() conversion fails on them).
        key = []
        for piece in v.strip().lstrip("vV").split("."):
            digits = ""
            for ch in piece:
                if ch not in "0123456789":
                    break
                digits += ch
            key.append(int(digits) if digits else 0)
        return tuple(key)

    # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility
    return JSONResponse(
        content={"version": str(min(all_versions, key=version_key))},
        status_code=200,
    )
@router.get("/api/ps")
async def ps_proxy(request: Request):
    """
    Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models.

    For Ollama endpoints: queries /api/ps
    For llama-server endpoints: queries /v1/models and keeps only the entries
    accepted by ``_is_llama_model_loaded``, converted to an Ollama-like shape.
    """
    config = get_config()
    # 1. Ollama endpoints report running models via /api/ps
    ollama_tasks = [
        fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)
        for ep in config.endpoints
        if "/v1" not in ep
    ]
    # 2. llama-server endpoints report models via /models. The previous
    #    union with "config.endpoints that are also llama-server endpoints"
    #    added nothing; dict.fromkeys de-duplicates while keeping the
    #    configured order, so the response order is stable between calls.
    llama_endpoints = list(dict.fromkeys(config.llama_server_endpoints))
    llama_tasks = [
        fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)
        for ep in llama_endpoints
    ]

    # Run both batches concurrently; gather preserves argument order, so the
    # results can be split back by position.
    results = await asyncio.gather(*ollama_tasks, *llama_tasks)
    ollama_loaded = results[:len(ollama_tasks)]
    llama_loaded = results[len(ollama_tasks):]

    models = []
    # Ollama entries are passed through as returned
    for modellist in ollama_loaded:
        models += modellist
    # llama-server entries: keep loaded ones and convert to Ollama-like format
    for modellist in llama_loaded:
        for item in modellist:
            if not isinstance(item, dict) or not _is_llama_model_loaded(item):
                continue
            raw_id = item.get("id", "")
            normalized = _normalize_llama_model_name(raw_id)
            quant = _extract_llama_quant(raw_id)
            models.append({
                "name": normalized,
                "id": normalized,
                "digest": "",
                "status": item.get("status"),
                "details": {"quantization_level": quant} if quant else {}
            })

    # 3. Return a JSONResponse with deduplicated currently deployed models
    # Deduplicate on 'name' rather than 'digest': llama-server models always
    # have digest="" so deduping on digest collapses all of them to one entry.
    return JSONResponse(
        content={"models": dedupe_on_keys(models, ['name'])},
        status_code=200,
    )
Query llama-server endpoints via /v1/models + # Also query endpoints from llama_server_endpoints that may not be in config.endpoints + all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) + llama_tasks = [ + (ep, fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)) + for ep in all_llama_endpoints + ] + + ollama_loaded = await asyncio.gather(*[task for _, task in ollama_tasks]) if ollama_tasks else [] + llama_loaded = await asyncio.gather(*[task for _, task in llama_tasks]) if llama_tasks else [] + + models: list[dict] = [] + + # Add Ollama models with endpoint info (if any) + if ollama_loaded: + for (endpoint, modellist) in zip([ep for ep, _ in ollama_tasks], ollama_loaded): + for model in modellist: + if isinstance(model, dict): + model_with_endpoint = dict(model) + model_with_endpoint["endpoint"] = endpoint + models.append(model_with_endpoint) + + # Add llama-server models with endpoint info and full status metadata (if any) + if llama_loaded: + # Collect (endpoint, raw_id) pairs to fetch /props in parallel + props_requests: list[tuple[str, str]] = [] + llama_models_pending: list[dict] = [] + + for (endpoint, modellist) in zip([ep for ep, _ in llama_tasks], llama_loaded): + # Include sleeping models too so _fetch_llama_props can unload them + loaded_models = [item for item in modellist if _is_llama_model_loaded_or_sleeping(item)] + for item in loaded_models: + if isinstance(item, dict) and item.get("id"): + raw_id = item["id"] + normalized = _normalize_llama_model_name(raw_id) + quant = _extract_llama_quant(raw_id) + model_with_endpoint = { + "name": normalized, + "id": normalized, + "original_name": raw_id, + "digest": "", + "details": {"quantization_level": quant} if quant else {}, + "endpoint": endpoint, + "status": item.get("status"), + "created": item.get("created"), + "owned_by": item.get("owned_by") + } + # Include full 
llama-server status details (args, preset) + status_info = item.get("status", {}) + if isinstance(status_info, dict): + model_with_endpoint["llama_status_args"] = status_info.get("args") + model_with_endpoint["llama_status_preset"] = status_info.get("preset") + llama_models_pending.append(model_with_endpoint) + props_requests.append((endpoint, raw_id)) + + # Fetch /props for each llama-server model to get context length (n_ctx) + # and unload sleeping models automatically + async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]: + client: aiohttp.ClientSession = get_session(endpoint) + base_url = endpoint.rstrip("/").removesuffix("/v1") + props_url = f"{base_url}/props?model={model_id}" + headers = None + api_key = config.api_keys.get(endpoint) + if api_key: + headers = {"Authorization": f"Bearer {api_key}"} + try: + async with client.get(props_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp: + if resp.status == 200: + data = await resp.json() + dgs = data.get("default_generation_settings", {}) + n_ctx = dgs.get("n_ctx") + is_sleeping = data.get("is_sleeping", False) + # Embedding models have no sampling params in default_generation_settings + is_generation = "temperature" in dgs + + if is_sleeping: + unload_url = f"{base_url}/models/unload" + try: + async with client.post( + unload_url, + json={"model": model_id}, + headers=headers, + ) as unload_resp: + print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}") + except Exception as ue: + print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}") + + return n_ctx, is_sleeping, is_generation + except Exception as e: + print(f"[ps_details] Failed to fetch props from {props_url}: {e}") + return None, False, False + + props_results = await asyncio.gather( + *[_fetch_llama_props(ep, mid) for ep, mid in props_requests] + ) + + for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in 
zip(props_requests, llama_models_pending, props_results): + if n_ctx is not None: + model_dict["context_length"] = n_ctx + if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: + normalized = _normalize_llama_model_name(raw_id) + _endpoint_nctx[(ep, normalized)] = n_ctx + print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True) + if not is_sleeping: + models.append(model_dict) + + return JSONResponse(content={"models": models}, status_code=200) diff --git a/api/openai.py b/api/openai.py new file mode 100644 index 0000000..4662f50 --- /dev/null +++ b/api/openai.py @@ -0,0 +1,804 @@ +"""OpenAI-compatible routes (``/v1/embeddings``, ``/v1/chat/completions``, +``/v1/completions``, ``/v1/models``, ``/v1/rerank`` and ``/rerank``). + +The chat-completions and completions handlers carry the full reactive-trim +logic for ``exceed_context_size_error`` plus connection-failure rerouting +(``_mark_backend_unhealthy``). The streaming branches assemble cached +responses on the fly so caching works for both streaming and non-streaming +clients. 
@router.post("/v1/embeddings")
async def openai_embedding_proxy(request: Request):
    """
    Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings.

    Multimodal ``input`` is reduced to its text parts before it is sent
    upstream, and non-finite floats in the returned vectors are replaced by
    0.0 before the JSON response is built.

    Raises:
        HTTPException: 400 on invalid JSON or a missing ``model`` / ``input``.
    """
    config = get_config()
    # 1. Parse and validate request
    try:
        raw_body = await request.body()
        payload = orjson.loads(raw_body.decode("utf-8"))

        model = payload.get("model")
        doc = payload.get("input")

        # Normalize multimodal input: keep text, drop image_url and other typed parts
        if isinstance(doc, list):
            texts = []
            for part in doc:
                if not isinstance(part, dict):
                    texts.append(part)
                elif part.get("type") == "text":
                    texts.append(part.get("text", ""))
            # A single remaining entry is unwrapped to a bare value
            doc = texts[0] if len(texts) == 1 else texts
        elif isinstance(doc, dict) and doc.get("type") == "text":
            doc = doc.get("text", "")

        if not model:
            raise HTTPException(
                status_code=400, detail="Missing required field 'model'"
            )
        if not doc:
            raise HTTPException(
                status_code=400, detail="Missing required field 'input'"
            )
    except orjson.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e

    # 2. Endpoint logic
    endpoint, tracking_model = await choose_endpoint(model)
    api_key = config.api_keys.get(endpoint, "no-key") if is_openai_compatible(endpoint) else "ollama"
    oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key)

    def _finite_or_zero(value):
        # NaN / inf floats become 0.0; every other value passes through
        if isinstance(value, float) and not math.isfinite(value):
            return 0.0
        return value

    # 3. Call upstream, sanitize vectors, always release the usage counter
    try:
        upstream = await oclient.embeddings.create(input=doc, model=model)
        result = upstream.model_dump()
        for entry in result.get("data", []):
            vector = entry.get("embedding")
            if vector:
                entry["embedding"] = [_finite_or_zero(v) for v in vector]
        return JSONResponse(content=result)
    finally:
        await decrement_usage(endpoint, tracking_model)
Parse and validate request + try: + body_bytes = await request.body() + payload = orjson.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + messages = payload.get("messages") + frequency_penalty = payload.get("frequency_penalty") + presence_penalty = payload.get("presence_penalty") + response_format = payload.get("response_format") + seed = payload.get("seed") + stop = payload.get("stop") + stream = payload.get("stream") + stream_options = payload.get("stream_options") + temperature = payload.get("temperature") + top_p = payload.get("top_p") + max_tokens = payload.get("max_tokens") + max_completion_tokens = payload.get("max_completion_tokens") + tools = payload.get("tools") + logprobs = payload.get("logprobs") + top_logprobs = payload.get("top_logprobs") + _cache_enabled = payload.get("nomyo", {}).get("cache", False) + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not isinstance(messages, list): + raise HTTPException( + status_code=400, detail="Missing required field 'messages' (must be a list)" + ) + + if ":latest" in model: + model = model.split(":latest") + model = model[0] + + messages = _strip_assistant_prefill(messages) + params = { + "messages": messages, + "model": model, + } + + optional_params = { + "tools": tools, + "response_format": response_format, + "stream_options": stream_options or {"include_usage": True }, + "max_completion_tokens": max_completion_tokens, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "seed": seed, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "stop": stop, + "stream": stream, + "logprobs": logprobs, + "top_logprobs": top_logprobs, + } + + params.update({k: v for k, v in optional_params.items() if v is not None}) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # Reject unsupported image formats (SVG) before doing any 
work + for _msg in messages: + for _item in (_msg.get("content") or []) if isinstance(_msg.get("content"), list) else []: + if _item.get("type") == "image_url": + _url = (_item.get("image_url") or {}).get("url", "") + if _url.startswith("data:image/svg") or _url.lower().endswith(".svg"): + raise HTTPException( + status_code=400, + detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.", + ) + + # Cache lookup — before endpoint selection + _cache = get_llm_cache() + if _cache is not None and _cache_enabled: + _cached = await _cache.get_chat("openai_chat", model, messages) + if _cached is not None: + if stream: + _sse = openai_nonstream_to_sse(_cached, model) + async def _serve_cached_ochat_stream(): + yield _sse + return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream") + else: + async def _serve_cached_ochat_json(): + yield _cached + return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json") + + # 2. Endpoint logic + _affinity_key = _conversation_fingerprint(model, messages, None) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + # 3. 
Helpers and API call — done in handler scope so try/except works reliably + async def _normalize_images_in_messages(msgs: list) -> list: + """Fetch remote image URLs and convert them to base64 data URLs so + Ollama/llama-server can handle them without making outbound HTTP requests.""" + resolved = [] + for msg in msgs: + content = msg.get("content") + if not isinstance(content, list): + resolved.append(msg) + continue + new_content = [] + for item in content: + if item.get("type") == "image_url": + url = (item.get("image_url") or {}).get("url", "") + if url and not url.startswith("data:"): + try: + http: aiohttp.ClientSession = app_state["session"] + async with http.get(url) as resp: + ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip() + img_bytes = await resp.read() + b64 = base64.b64encode(img_bytes).decode("utf-8") + new_content.append({ + "type": "image_url", + "image_url": {"url": f"data:{ctype};base64,{b64}"} + }) + except Exception as _ie: + print(f"[image] Failed to fetch image URL: {_ie}") + new_content.append(item) + else: + new_content.append(item) + else: + new_content.append(item) + resolved.append({**msg, "content": new_content}) + return resolved + + # Make the API call in handler scope — try/except inside async generators is unreliable + # with Starlette's streaming machinery, so we resolve errors here before the generator starts. 
+ send_params = params + if not is_ext_openai_endpoint(endpoint): + resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) + send_params = {**params, "messages": resolved_msgs} + # Proactive trim: only for small-ctx models we've already seen run out of space + _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model + _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) + if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: + _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) + _pre_est = _count_message_tokens(send_params.get("messages", [])) + if _pre_est > _pre_target: + _pre_msgs = send_params.get("messages", []) + _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) + _dropped = len(_pre_msgs) - len(_pre_trimmed) + print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) + send_params = {**send_params, "messages": _pre_trimmed} + try: + async_gen = await oclient.chat.completions.create(**send_params) + except Exception as e: + _e_str = str(e) + _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str + print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) + if "does not support tools" in _e_str: + # Model doesn't support tools — retry without them + print(f"[ochat] retry: no tools", flush=True) + try: + params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} + async_gen = await oclient.chat.completions.create(**params_without_tools) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_ctx_err: + # Backend context limit hit — apply sliding-window trim (context-shift at message level) + err_body = getattr(e, "body", {}) or {} + err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} + n_ctx_limit = err_detail.get("n_ctx", 
0) + actual_tokens = err_detail.get("n_prompt_tokens", 0) + # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors) + if not n_ctx_limit: + import re as _re + _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str) + if _m: + n_ctx_limit = int(_m.group(1)) + _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) + if _m: + actual_tokens = int(_m.group(1)) + print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) + if not n_ctx_limit: + await decrement_usage(endpoint, tracking_model) + raise + if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = n_ctx_limit + + msgs_to_trim = send_params.get("messages", []) + try: + cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) + trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) + except Exception as _helper_exc: + print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) + await decrement_usage(endpoint, tracking_model) + raise + dropped = len(msgs_to_trim) - len(trimmed_messages) + print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-1 ok", flush=True) + except Exception as e2: + _e2_str = str(e2) + if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: + # Still too large — tool definitions likely consuming too many tokens, strip them too + print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) + params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} + try: + async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) + print(f"[ctx-trim] retry-2 ok", flush=True) + except Exception: + 
await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + elif _is_backend_connection_error(e): + # Upstream connection failed (e.g. llama-server in router mode + # whose delegated worker died). Mark (endpoint, model) so the + # next request reroutes; the client will retry this one. + print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) + await _mark_backend_unhealthy(endpoint, model, _e_str) + await decrement_usage(endpoint, tracking_model) + raise + elif "image input is not supported" in _e_str: + # Model doesn't support images — strip and retry + print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") + try: + async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) + except Exception: + await decrement_usage(endpoint, tracking_model) + raise + else: + await decrement_usage(endpoint, tracking_model) + raise + + # 4. Async generator — only streams the already-established async_gen + async def stream_ochat_response(): + try: + if stream == True: + content_parts: list[str] = [] + usage_snapshot: dict = {} + async for chunk in async_gen: + data = ( + chunk.model_dump_json() + if hasattr(chunk, "model_dump_json") + else orjson.dumps(chunk) + ) + if chunk.choices: + delta = chunk.choices[0].delta + has_content = delta.content is not None + has_reasoning = ( + getattr(delta, "reasoning_content", None) is not None + or getattr(delta, "reasoning", None) is not None + ) + has_tool_calls = getattr(delta, "tool_calls", None) is not None + if has_content or has_reasoning or has_tool_calls: + yield f"data: {data}\n\n".encode("utf-8") + if has_content and delta.content: + content_parts.append(delta.content) + elif chunk.usage is not None: + # Forward the usage-only final chunk (e.g. 
from llama-server) + yield f"data: {data}\n\n".encode("utf-8") + prompt_tok = 0 + comp_tok = 0 + if chunk.usage is not None: + prompt_tok = chunk.usage.prompt_tokens or 0 + comp_tok = chunk.usage.completion_tokens or 0 + usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok} + else: + llama_usage = rechunk.extract_usage_from_llama_timings(chunk) + if llama_usage: + prompt_tok, comp_tok = llama_usage + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + # Detect context exhaustion mid-generation for small-ctx models. + # Guard: skip if max_tokens was set in the request — finish_reason=length + # could just mean the caller's token budget was exhausted, not the context window. + _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens") + if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok: + _inferred_nctx = (prompt_tok + comp_tok) or 0 + if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: + _endpoint_nctx[(endpoint, model)] = _inferred_nctx + print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) + # Cache assembled streaming response — before [DONE] so it always runs + if _cache is not None and _cache_enabled and content_parts: + assembled = orjson.dumps({ + "model": model, + "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}], + **({"usage": usage_snapshot} if usage_snapshot else {}), + }) + b"\n" + try: + await _cache.set_chat("openai_chat", model, messages, assembled) + except Exception as _ce: + print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}") + yield b"data: [DONE]\n\n" + else: + prompt_tok = 0 + comp_tok = 0 + if async_gen.usage is not None: + prompt_tok = async_gen.usage.prompt_tokens or 0 + comp_tok = async_gen.usage.completion_tokens or 0 
+ else: + llama_usage = rechunk.extract_usage_from_llama_timings(async_gen) + if llama_usage: + prompt_tok, comp_tok = llama_usage + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + json_line = ( + async_gen.model_dump_json() + if hasattr(async_gen, "model_dump_json") + else orjson.dumps(async_gen) + ) + cache_bytes = json_line.encode("utf-8") + b"\n" + yield cache_bytes + # Cache non-streaming response + if _cache is not None and _cache_enabled: + try: + await _cache.set_chat("openai_chat", model, messages, cache_bytes) + except Exception as _ce: + print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}") + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, tracking_model) + + # 4. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_ochat_response(), + media_type="text/event-stream" if stream else "application/json", + ) + + +@router.post("/v1/completions") +async def openai_completions_proxy(request: Request): + """ + Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response. + + """ + config = get_config() + # 1. 
Parse and validate request + try: + body_bytes = await request.body() + payload = orjson.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + prompt = payload.get("prompt") + frequency_penalty = payload.get("frequency_penalty") + presence_penalty = payload.get("presence_penalty") + seed = payload.get("seed") + stop = payload.get("stop") + stream = payload.get("stream") + stream_options = payload.get("stream_options") + temperature = payload.get("temperature") + top_p = payload.get("top_p") + max_tokens = payload.get("max_tokens") + max_completion_tokens = payload.get("max_completion_tokens") + suffix = payload.get("suffix") + _cache_enabled = payload.get("nomyo", {}).get("cache", False) + + if not model: + raise HTTPException( + status_code=400, detail="Missing required field 'model'" + ) + if not prompt: + raise HTTPException( + status_code=400, detail="Missing required field 'prompt'" + ) + + if ":latest" in model: + model = model.split(":latest") + model = model[0] + + params = { + "prompt": prompt, + "model": model, + } + + optional_params = { + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "seed": seed, + "stop": stop, + "stream": stream, + "stream_options": stream_options or {"include_usage": True }, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "max_completion_tokens": max_completion_tokens, + "suffix": suffix + } + + params.update({k: v for k, v in optional_params.items() if v is not None}) + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # Cache lookup — completions prompt mapped to a single-turn messages list + _cache = get_llm_cache() + _compl_messages = [{"role": "user", "content": prompt}] + if _cache is not None and _cache_enabled: + _cached = await _cache.get_chat("openai_completions", model, _compl_messages) + if _cached is not None: + if stream: + _sse = openai_nonstream_to_sse(_cached, model) + 
async def _serve_cached_ocompl_stream(): + yield _sse + return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream") + else: + async def _serve_cached_ocompl_json(): + yield _cached + return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json") + + # 2. Endpoint logic + _affinity_key = _conversation_fingerprint(model, None, prompt) + endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) + oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) + + # 3. Async generator that streams completions data and decrements the counter + # Make the API call in handler scope (try/except inside async generators is unreliable) + try: + async_gen = await oclient.completions.create(**params) + except Exception as e: + if _is_backend_connection_error(e): + print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) + await _mark_backend_unhealthy(endpoint, model, str(e)) + await decrement_usage(endpoint, tracking_model) + raise + + async def stream_ocompletions_response(model=model): + try: + if stream == True: + text_parts: list[str] = [] + usage_snapshot: dict = {} + async for chunk in async_gen: + data = ( + chunk.model_dump_json() + if hasattr(chunk, "model_dump_json") + else orjson.dumps(chunk) + ) + if chunk.choices: + choice = chunk.choices[0] + has_text = getattr(choice, "text", None) is not None + has_reasoning = ( + getattr(choice, "reasoning_content", None) is not None + or getattr(choice, "reasoning", None) is not None + ) + if has_text or has_reasoning or choice.finish_reason is not None: + yield f"data: {data}\n\n".encode("utf-8") + if has_text and choice.text: + text_parts.append(choice.text) + elif chunk.usage is not None: + # Forward the usage-only final chunk (e.g. 
from llama-server) + yield f"data: {data}\n\n".encode("utf-8") + prompt_tok = 0 + comp_tok = 0 + if chunk.usage is not None: + prompt_tok = chunk.usage.prompt_tokens or 0 + comp_tok = chunk.usage.completion_tokens or 0 + usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok} + else: + llama_usage = rechunk.extract_usage_from_llama_timings(chunk) + if llama_usage: + prompt_tok, comp_tok = llama_usage + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + # Cache assembled streaming response — before [DONE] so it always runs + if _cache is not None and _cache_enabled and text_parts: + assembled = orjson.dumps({ + "model": model, + "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}], + **({"usage": usage_snapshot} if usage_snapshot else {}), + }) + b"\n" + try: + await _cache.set_chat("openai_completions", model, _compl_messages, assembled) + except Exception as _ce: + print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}") + # Final DONE event + yield b"data: [DONE]\n\n" + else: + prompt_tok = 0 + comp_tok = 0 + if async_gen.usage is not None: + prompt_tok = async_gen.usage.prompt_tokens or 0 + comp_tok = async_gen.usage.completion_tokens or 0 + else: + llama_usage = rechunk.extract_usage_from_llama_timings(async_gen) + if llama_usage: + prompt_tok, comp_tok = llama_usage + if prompt_tok != 0 or comp_tok != 0: + await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) + json_line = ( + async_gen.model_dump_json() + if hasattr(async_gen, "model_dump_json") + else orjson.dumps(async_gen) + ) + cache_bytes = json_line.encode("utf-8") + b"\n" + yield cache_bytes + # Cache non-streaming response + if _cache is not None and _cache_enabled: + try: + await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes) + except Exception as _ce: + 
print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}") + + finally: + # Ensure counter is decremented even if an exception occurs + await decrement_usage(endpoint, tracking_model) + + # 4. Return a StreamingResponse backed by the generator + return StreamingResponse( + stream_ocompletions_response(), + media_type="text/event-stream" if stream else "application/json", + ) + + +@router.get("/v1/models") +async def openai_models_proxy(request: Request): + """ + Proxy an OpenAI API models request to Ollama and llama-server endpoints and reply with a unique list of models. + + For Ollama endpoints: queries /api/tags (all models) + For llama-server endpoints: queries /v1/models and filters for status.value == "loaded" + """ + config = get_config() + # 1. Query Ollama endpoints for all models via /api/tags + ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] + # 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models + ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)] + # 3. 
Query llama-server endpoints for loaded models via /v1/models + # Also query endpoints from llama_server_endpoints that may not be in config.endpoints + all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) + llama_tasks = [ + fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) + for ep in all_llama_endpoints + ] + + ollama_models = await asyncio.gather(*ollama_tasks) if ollama_tasks else [] + ext_openai_models = await asyncio.gather(*ext_openai_tasks) if ext_openai_tasks else [] + llama_models = await asyncio.gather(*llama_tasks) if llama_tasks else [] + + models = {'data': []} + + # Add Ollama models (if any) + if ollama_models: + for modellist in ollama_models: + for model in modellist: + if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name + model['id'] = model.get('name', model.get('id', '')) + else: + model['name'] = model['id'] + models['data'].append(model) + + # Add external OpenAI models (if any) + if ext_openai_models: + for modellist in ext_openai_models: + for model in modellist: + if not "id" in model.keys(): + model['id'] = model.get('name', model.get('id', '')) + else: + model['name'] = model['id'] + models['data'].append(model) + + # Add llama-server models (all available, not just loaded) + if llama_models: + for modellist in llama_models: + for model in modellist: + if not "id" in model.keys(): + model['id'] = model.get('name', model.get('id', '')) + else: + model['name'] = model['id'] + models['data'].append(model) + + # 2. 
Return a JSONResponse with a deduplicated list of unique models for inference + return JSONResponse( + content={"data": dedupe_on_keys(models['data'], ['name'])}, + status_code=200, + ) + + +@router.post("/v1/rerank") +@router.post("/rerank") +async def rerank_proxy(request: Request): + """ + Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint. + + Compatible with the Jina/Cohere rerank API convention used by llama-server, + vLLM, and services such as Cohere and Jina AI. + + Ollama does not natively support reranking; requests routed to a plain Ollama + endpoint will receive a 501 Not Implemented response. + + Request body: + model (str, required) – reranker model name + query (str, required) – search query + documents (list[str], required) – candidate documents to rank + top_n (int, optional) – limit returned results (default: all) + return_documents (bool, optional) – include document text in results + max_tokens_per_doc (int, optional) – truncation limit per document + + Response (Jina/Cohere-compatible): + { + "id": "...", + "model": "...", + "usage": {"prompt_tokens": N, "total_tokens": N}, + "results": [{"index": 0, "relevance_score": 0.95}, ...] 
+ } + """ + config = get_config() + try: + body_bytes = await request.body() + payload = orjson.loads(body_bytes.decode("utf-8")) + + model = payload.get("model") + query = payload.get("query") + documents = payload.get("documents") + + if not model: + raise HTTPException(status_code=400, detail="Missing required field 'model'") + if not query: + raise HTTPException(status_code=400, detail="Missing required field 'query'") + if not isinstance(documents, list) or not documents: + raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)") + except orjson.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e + + # Determine which endpoint serves this model + try: + endpoint, tracking_model = await choose_endpoint(model) + except RuntimeError as e: + raise HTTPException(status_code=404, detail=str(e)) + + # Ollama endpoints have no native rerank support + if not is_openai_compatible(endpoint): + await decrement_usage(endpoint, tracking_model) + raise HTTPException( + status_code=501, + detail=( + f"Endpoint '{endpoint}' is a plain Ollama instance which does not support " + "reranking. Use a llama-server or OpenAI-compatible endpoint with a " + "dedicated reranker model." 
+ ), + ) + + if ":latest" in model: + model = model.split(":latest")[0] + + # Build upstream rerank request body – forward only recognised fields + upstream_payload: dict = {"model": model, "query": query, "documents": documents} + for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"): + if optional_key in payload: + upstream_payload[optional_key] = payload[optional_key] + + # Determine upstream URL: + # llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints) + # External OpenAI endpoints expose /rerank under their /v1 base + if endpoint in config.llama_server_endpoints: + # llama-server: endpoint may or may not already contain /v1 + if "/v1" in endpoint: + rerank_url = f"{endpoint}/rerank" + else: + rerank_url = f"{endpoint}/v1/rerank" + else: + # External OpenAI-compatible: ep2base gives us the /v1 base + rerank_url = f"{ep2base(endpoint)}/rerank" + + api_key = config.api_keys.get(endpoint, "no-key") + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + } + + client: aiohttp.ClientSession = get_session(endpoint) + try: + async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp: + response_bytes = await resp.read() + if resp.status >= 400: + raise HTTPException( + status_code=resp.status, + detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")), + ) + data = orjson.loads(response_bytes) + + # Record token usage if the upstream returned a usage object + usage = data.get("usage") or {} + prompt_tok = usage.get("prompt_tokens") or 0 + total_tok = usage.get("total_tokens") or 0 + # For reranking there are no completion tokens; we record prompt tokens only + if prompt_tok or total_tok: + await token_queue.put((endpoint, tracking_model, prompt_tok, 0)) + + return JSONResponse(content=data) + finally: + await decrement_usage(endpoint, tracking_model) diff --git a/api/static.py b/api/static.py new file mode 100644 index 0000000..46cdb2d --- 
/dev/null +++ b/api/static.py @@ -0,0 +1,30 @@ +"""Static-asset and dashboard routes.""" +from pathlib import Path + +from fastapi import APIRouter, HTTPException, Request +from starlette.responses import HTMLResponse, RedirectResponse + +# Directory containing static files (resolved relative to project root). +STATIC_DIR = Path(__file__).resolve().parent.parent / "static" + +router = APIRouter() + + +@router.get("/favicon.ico") +async def redirect_favicon(): + return RedirectResponse(url="/static/favicon.ico") + + +@router.get("/", response_class=HTMLResponse) +async def index(request: Request): + """ + Render the dynamic NOMYO Router dashboard listing the configured endpoints + and the models details, availability & task status. + """ + index_path = STATIC_DIR / "index.html" + try: + return HTMLResponse(content=index_path.read_text(encoding="utf-8"), status_code=200) + except FileNotFoundError: + raise HTTPException(status_code=404, detail="Page not found") + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") diff --git a/router.py b/router.py index bf25be9..96916f4 100644 --- a/router.py +++ b/router.py @@ -271,2111 +271,28 @@ from sse import ( # ------------------------------------------------------------- from routing import get_max_connections, choose_endpoint -# ------------------------------------------------------------- -# 6. API route – Generate -# ------------------------------------------------------------- -@app.post("/api/generate") -async def proxy(request: Request): - """ - Proxy a generate request to Ollama and stream the response back to the client. 
- """ - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - prompt = payload.get("prompt") - suffix = payload.get("suffix") - system = payload.get("system") - template = payload.get("template") - context = payload.get("context") - stream = payload.get("stream") - think = payload.get("think") - raw = payload.get("raw") - _format = payload.get("format") - images = payload.get("images") - options = payload.get("options") - keep_alive = payload.get("keep_alive") - _cache_enabled = payload.get("nomyo", {}).get("cache", False) - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not prompt: - raise HTTPException( - status_code=400, detail="Missing required field 'prompt'" - ) - except orjson.JSONDecodeError as e: - error_msg = f"Invalid JSON format in request body: {str(e)}. Please ensure the request is properly formatted." - raise HTTPException(status_code=400, detail=error_msg) from e - - # Cache lookup — before endpoint selection so no slot is wasted on a hit - _cache = get_llm_cache() - if _cache is not None and _cache_enabled: - _cached = await _cache.get_generate(model, prompt, system or "") - if _cached is not None: - async def _serve_cached_generate(): - yield _cached - return StreamingResponse(_serve_cached_generate(), media_type="application/json") - - _affinity_key = _conversation_fingerprint(model, None, prompt) - endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - params = { - "prompt": prompt, - "model": model, - } - - optional_params = { - "stream": stream, - "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, - "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else 
None, - "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, - "seed": options.get("seed") if options and "seed" in options else None, - "stop": options.get("stop") if options and "stop" in options else None, - "top_p": options.get("top_p") if options and "top_p" in options else None, - "temperature": options.get("temperature") if options and "temperature" in options else None, - "suffix": suffix, - } - params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = ollama.AsyncClient(host=endpoint) - - # 4. Async generator that streams data and decrements the counter - async def stream_generate_response(): - try: - if use_openai: - start_ts = time.perf_counter() - async_gen = await oclient.completions.create(**params) - else: - async_gen = await client.generate(model=model, prompt=prompt, suffix=suffix, system=system, template=template, context=context, stream=stream, think=think, raw=raw, format=_format, images=images, options=options, keep_alive=keep_alive) - if stream == True: - content_parts: list[str] = [] - async for chunk in async_gen: - if use_openai: - chunk = rechunk.openai_completion2ollama(chunk, stream, start_ts) - prompt_tok = chunk.prompt_eval_count or 0 - comp_tok = chunk.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - if hasattr(chunk, "model_dump_json"): - json_line = chunk.model_dump_json() - else: - json_line = orjson.dumps(chunk) - # Accumulate and store cache on done chunk — before yield so it always runs - if _cache is not None and _cache_enabled: - if getattr(chunk, "response", None): - content_parts.append(chunk.response) - if getattr(chunk, "done", False): - assembled = orjson.dumps({ - k: v for k, v in { - "model": getattr(chunk, "model", model), - 
"response": "".join(content_parts), - "done": True, - "done_reason": getattr(chunk, "done_reason", "stop") or "stop", - "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), - "eval_count": getattr(chunk, "eval_count", None), - "total_duration": getattr(chunk, "total_duration", None), - "eval_duration": getattr(chunk, "eval_duration", None), - }.items() if v is not None - }) + b"\n" - try: - await _cache.set_generate(model, prompt, system or "", assembled) - except Exception as _ce: - print(f"[cache] set_generate (streaming) failed: {_ce}") - yield json_line.encode("utf-8") + b"\n" - else: - if use_openai: - response = rechunk.openai_completion2ollama(async_gen, stream, start_ts) - response = response.model_dump_json() - else: - response = async_gen.model_dump_json() - prompt_tok = async_gen.prompt_eval_count or 0 - comp_tok = async_gen.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - json_line = ( - response - if hasattr(async_gen, "model_dump_json") - else orjson.dumps(async_gen) - ) - cache_bytes = json_line.encode("utf-8") + b"\n" - yield cache_bytes - # Cache non-streaming response - if _cache is not None and _cache_enabled: - try: - await _cache.set_generate(model, prompt, system or "", cache_bytes) - except Exception as _ce: - print(f"[cache] set_generate (non-streaming) failed: {_ce}") - - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 5. Return a StreamingResponse backed by the generator - return StreamingResponse( - stream_generate_response(), - media_type="application/json", - ) - -# ------------------------------------------------------------- -# 7. API route – Chat -# ------------------------------------------------------------- -@app.post("/api/chat") -async def chat_proxy(request: Request): - """ - Proxy a chat request to Ollama and stream the endpoint reply. - """ - # 1. 
Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - messages = payload.get("messages") - tools = payload.get("tools") - stream = payload.get("stream") - think = payload.get("think") - _format = payload.get("format") - keep_alive = payload.get("keep_alive") - options = payload.get("options") - logprobs = payload.get("logprobs") - top_logprobs = payload.get("top_logprobs") - _cache_enabled = payload.get("nomyo", {}).get("cache", False) - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not isinstance(messages, list): - raise HTTPException( - status_code=400, detail="Missing or invalid 'messages' field (must be a list)" - ) - if options is not None and not isinstance(options, dict): - raise HTTPException( - status_code=400, detail="`options` must be a JSON object" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # Cache lookup — before endpoint selection, always bypassed for MOE - _is_moe = model.startswith("moe-") - _cache = get_llm_cache() - # Normalise model name for cache key: strip ":latest" suffix here so that - # get_chat and set_chat use the same model string regardless of when the - # strip happens further down (line ~1793 strips it for OpenAI endpoints). - _cache_model = model[: -len(":latest")] if model.endswith(":latest") else model - # Snapshot original messages before any OpenAI-format transformation so that - # get_chat and set_chat always use the same key regardless of backend type. 
- _cache_messages = messages - if _cache is not None and not _is_moe and _cache_enabled: - _cached = await _cache.get_chat("ollama_chat", _cache_model, messages) - if _cached is not None: - async def _serve_cached_chat(): - yield _cached - return StreamingResponse( - _serve_cached_chat(), - media_type="application/x-ndjson" if stream else "application/json", - ) - - # 2. Endpoint logic - if model.startswith("moe-"): - model = model.split("moe-")[1] - opt = True - else: - opt = False - _affinity_key = _conversation_fingerprint(model, messages, None) - endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - if messages: - if any("images" in m for m in messages): - messages = await asyncio.to_thread(transform_images_to_data_urls, messages) - messages = transform_tool_calls_to_openai(messages) - messages = _strip_assistant_prefill(messages) - params = { - "messages": messages, - "model": model, - } - optional_params = { - "tools": tools, - "stream": stream, - "stream_options": {"include_usage": True} if stream else None, - "max_tokens": options.get("num_predict") if options and "num_predict" in options else None, - "frequency_penalty": options.get("frequency_penalty") if options and "frequency_penalty" in options else None, - "presence_penalty": options.get("presence_penalty") if options and "presence_penalty" in options else None, - "seed": options.get("seed") if options and "seed" in options else None, - "stop": options.get("stop") if options and "stop" in options else None, - "top_p": options.get("top_p") if options and "top_p" in options else None, - "temperature": options.get("temperature") if options and "temperature" in options else None, - "logprobs": logprobs if logprobs is not None else (options.get("logprobs") if options and "logprobs" in options else None), - "top_logprobs": top_logprobs if 
top_logprobs is not None else (options.get("top_logprobs") if options and "top_logprobs" in options else None), - "response_format": {"type": "json_schema", "json_schema": _format} if _format is not None else None - } - params.update({k: v for k, v in optional_params.items() if v is not None}) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = ollama.AsyncClient(host=endpoint) - # For OpenAI endpoints: make the API call in handler scope - # (try/except inside async generators is unreliable with Starlette's streaming) - start_ts = None - async_gen = None - if use_openai: - start_ts = time.perf_counter() - # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model - _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) - if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: - _pre_target = int((_known_nctx - _known_nctx // 4) / 1.2) - _pre_est = _count_message_tokens(params.get("messages", [])) - if _pre_est > _pre_target: - _pre_msgs = params.get("messages", []) - _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) - _dropped = len(_pre_msgs) - len(_pre_trimmed) - print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) - params = {**params, "messages": _pre_trimmed} - try: - async_gen = await oclient.chat.completions.create(**params) - except Exception as e: - _e_str = str(e) - print(f"[chat_proxy] caught {type(e).__name__}: {_e_str[:200]}") - if "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str: - err_body = getattr(e, "body", {}) or {} - err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} - n_ctx_limit = err_detail.get("n_ctx", 0) - actual_tokens = err_detail.get("n_prompt_tokens", 0) 
- if not n_ctx_limit: - _m = re.search(r"'n_ctx':\s*(\d+)", _e_str) - if _m: - n_ctx_limit = int(_m.group(1)) - _m = re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) - if _m: - actual_tokens = int(_m.group(1)) - if not n_ctx_limit: - await decrement_usage(endpoint, tracking_model) - raise - if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: - _endpoint_nctx[(endpoint, model)] = n_ctx_limit - msgs_to_trim = params.get("messages", []) - cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) - trimmed = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) - print(f"[chat_proxy] Context exceeded ({actual_tokens}/{n_ctx_limit} tokens, tiktoken_target={cal_target}), dropped {len(msgs_to_trim) - len(trimmed)} oldest message(s) and retrying") - try: - async_gen = await oclient.chat.completions.create(**{**params, "messages": trimmed}) - except Exception as e2: - _e2_str = str(e2) - if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: - print(f"[chat_proxy] Context still exceeded after trimming messages, also stripping tools") - params_no_tools = {k: v for k, v in params.items() if k not in ("tools", "tool_choice")} - try: - async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed}) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise - elif _is_backend_connection_error(e): - print(f"[chat_proxy] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) - await _mark_backend_unhealthy(endpoint, model, _e_str) - await decrement_usage(endpoint, tracking_model) - raise - elif "image input is not supported" in _e_str: - print(f"[chat_proxy] Model {model} doesn't support images, retrying with text-only messages") - try: - params = {**params, "messages": _strip_images_from_messages(params.get("messages", []))} - async_gen = await 
oclient.chat.completions.create(**params) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise - - # 3. Async generator that streams chat data and decrements the counter - async def stream_chat_response(): - try: - # The chat method returns a generator of dicts (or GenerateResponse) - if use_openai: - _async_gen = async_gen # established in handler scope above - else: - if opt == True: - # Use the dedicated MOE helper function - _async_gen = await _make_moe_requests(model, messages, tools, think, _format, options, keep_alive) - else: - _async_gen = await client.chat(model=model, messages=messages, tools=tools, stream=stream, think=think, format=_format, options=options, keep_alive=keep_alive, logprobs=logprobs, top_logprobs=top_logprobs) - if stream == True: - tc_acc = {} # accumulate OpenAI tool-call deltas across chunks - content_parts: list[str] = [] - async for chunk in _async_gen: - if use_openai: - _accumulate_openai_tc_delta(chunk, tc_acc) - chunk = rechunk.openai_chat_completion2ollama(chunk, stream, start_ts) - # Inject fully-accumulated tool calls only into the final chunk - if chunk.done and tc_acc and chunk.message: - chunk.message.tool_calls = _build_ollama_tool_calls(tc_acc) - # `chunk` can be a dict or a pydantic model – dump to JSON safely - prompt_tok = chunk.prompt_eval_count or 0 - comp_tok = chunk.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - if hasattr(chunk, "model_dump_json"): - json_line = chunk.model_dump_json() - else: - json_line = orjson.dumps(chunk) - # Accumulate and store cache on done chunk — before yield so it always runs - # Works for both Ollama-native and OpenAI-compatible backends; chunks are - # already converted to Ollama format by rechunk before this point. 
- if getattr(chunk, "done", False): - # Detect context exhaustion mid-generation for small-ctx models - _dr = getattr(chunk, "done_reason", None) - # Only cache when no max_tokens limit was set — otherwise - # finish_reason=length might just mean max_tokens was hit, - # not that the context window was exhausted. - _req_max_tok = ( - params.get("max_tokens") or params.get("max_completion_tokens") or params.get("num_predict") - if use_openai else - (options.get("num_predict") if options else None) - ) - if _dr == "length" and not _req_max_tok: - _pt = getattr(chunk, "prompt_eval_count", 0) or 0 - _ct = getattr(chunk, "eval_count", 0) or 0 - _inferred_nctx = _pt + _ct - if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: - _endpoint_nctx[(endpoint, model)] = _inferred_nctx - print(f"[ctx-cache] done_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) - if _cache is not None and not _is_moe and _cache_enabled: - if chunk.message and getattr(chunk.message, "content", None): - content_parts.append(chunk.message.content) - if getattr(chunk, "done", False): - assembled = orjson.dumps({ - k: v for k, v in { - "model": getattr(chunk, "model", model), - "created_at": (lambda ca: ca.isoformat() if hasattr(ca, "isoformat") else ca)(getattr(chunk, "created_at", None)), - "message": {"role": "assistant", "content": "".join(content_parts)}, - "done": True, - "done_reason": getattr(chunk, "done_reason", "stop") or "stop", - "prompt_eval_count": getattr(chunk, "prompt_eval_count", None), - "eval_count": getattr(chunk, "eval_count", None), - "total_duration": getattr(chunk, "total_duration", None), - "eval_duration": getattr(chunk, "eval_duration", None), - }.items() if v is not None - }) + b"\n" - try: - await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, assembled) - except Exception as _ce: - print(f"[cache] set_chat (ollama_chat streaming) failed: {_ce}") - yield json_line.encode("utf-8") + b"\n" - else: - if use_openai: - response = 
rechunk.openai_chat_completion2ollama(_async_gen, stream, start_ts) - response = response.model_dump_json() - else: - response = _async_gen.model_dump_json() - prompt_tok = _async_gen.prompt_eval_count or 0 - comp_tok = _async_gen.eval_count or 0 - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - json_line = ( - response - if hasattr(_async_gen, "model_dump_json") - else orjson.dumps(_async_gen) - ) - cache_bytes = json_line.encode("utf-8") + b"\n" - yield cache_bytes - # Cache non-streaming response (non-MOE; works for both Ollama and OpenAI backends) - if _cache is not None and not _is_moe and _cache_enabled: - try: - await _cache.set_chat("ollama_chat", _cache_model, _cache_messages, cache_bytes) - except Exception as _ce: - print(f"[cache] set_chat (ollama_chat non-streaming) failed: {_ce}") - - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 4. Return a StreamingResponse backed by the generator - media_type = "application/x-ndjson" if stream else "application/json" - return StreamingResponse( - stream_chat_response(), - media_type=media_type, - ) - -# ------------------------------------------------------------- -# 8. API route – Embedding - deprecated -# ------------------------------------------------------------- -@app.post("/api/embeddings") -async def embedding_proxy(request: Request): - """ - Proxy an embedding request to Ollama and reply with embeddings. - - """ - # 1. 
Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - prompt = payload.get("prompt") - options = payload.get("options") - keep_alive = payload.get("keep_alive") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not prompt: - raise HTTPException( - status_code=400, detail="Missing required field 'prompt'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = ollama.AsyncClient(host=endpoint) - # 3. Async generator that streams embedding data and decrements the counter - async def stream_embedding_response(): - try: - # The chat method returns a generator of dicts (or GenerateResponse) - if use_openai: - async_gen = await client.embeddings.create(input=prompt, model=model) - async_gen = rechunk.openai_embeddings2ollama(async_gen) - else: - async_gen = await client.embeddings(model=model, prompt=prompt, options=options, keep_alive=keep_alive) - if hasattr(async_gen, "model_dump_json"): - json_line = async_gen.model_dump_json() - else: - json_line = orjson.dumps(async_gen) - yield json_line.encode("utf-8") + b"\n" - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 5. Return a StreamingResponse backed by the generator - return StreamingResponse( - stream_embedding_response(), - media_type="application/json", - ) - -# ------------------------------------------------------------- -# 9. 
API route – Embed -# ------------------------------------------------------------- -@app.post("/api/embed") -async def embed_proxy(request: Request): - """ - Proxy an embed request to Ollama and reply with embeddings. - - """ - # 1. Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - _input = payload.get("input") - truncate = payload.get("truncate") - options = payload.get("options") - keep_alive = payload.get("keep_alive") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not _input: - raise HTTPException( - status_code=400, detail="Missing required field 'input'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) - use_openai = is_openai_compatible(endpoint) - if use_openai: - if ":latest" in model: - model = model.split(":latest") - model = model[0] - client = _make_openai_client(endpoint, api_key=config.api_keys.get(endpoint, "no-key")) - else: - client = ollama.AsyncClient(host=endpoint) - # 3. Async generator that streams embed data and decrements the counter - async def stream_embedding_response(): - try: - # The chat method returns a generator of dicts (or GenerateResponse) - if use_openai: - async_gen = await client.embeddings.create(input=_input, model=model) - async_gen = rechunk.openai_embed2ollama(async_gen, model) - else: - async_gen = await client.embed(model=model, input=_input, truncate=truncate, options=options, keep_alive=keep_alive) - if hasattr(async_gen, "model_dump_json"): - json_line = async_gen.model_dump_json() - else: - json_line = orjson.dumps(async_gen) - yield json_line.encode("utf-8") + b"\n" - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 4. 
Return a StreamingResponse backed by the generator - return StreamingResponse( - stream_embedding_response(), - media_type="application/json", - ) - -# ------------------------------------------------------------- -# 10. API route – Create -# ------------------------------------------------------------- -@app.post("/api/create") -async def create_proxy(request: Request): - """ - Proxy a create request to all Ollama endpoints and reply with deduplicated status. - """ - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - quantize = payload.get("quantize") - from_ = payload.get("from") - files = payload.get("files") - adapters = payload.get("adapters") - template = payload.get("template") - license = payload.get("license") - system = payload.get("system") - parameters = payload.get("parameters") - messages = payload.get("messages") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not from_ and not files: - raise HTTPException( - status_code=400, detail="You need to provide either from_ or files parameter!" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - status_lists = [] - - for endpoint in config.endpoints: - client = ollama.AsyncClient(host=endpoint) - create = await client.create(model=model, quantize=quantize, from_=from_, files=files, adapters=adapters, template=template, license=license, system=system, parameters=parameters, messages=messages, stream=False) - status_lists.append(create) - - combined_status = [] - for status_list in status_lists: - combined_status += status_list - - final_status = list(dict.fromkeys(combined_status)) - - return dict(final_status) - -# ------------------------------------------------------------- -# 11. 
API route – Show -# ------------------------------------------------------------- -@app.post("/api/show") -async def show_proxy(request: Request, model: Optional[str] = None): - """ - Proxy a model show request to Ollama and reply with ShowResponse. - - """ - try: - body_bytes = await request.body() - - if not model: - payload = orjson.loads(body_bytes.decode("utf-8")) - model = payload.get("model") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. Endpoint logic - endpoint, _ = await choose_endpoint(model, reserve=False) - - client = ollama.AsyncClient(host=endpoint) - - # 3. Proxy a simple show request - show = await client.show(model=model) - - # 4. Return ShowResponse - return show - -# ------------------------------------------------------------- -@app.get("/api/token_counts") -async def token_counts_proxy(): - breakdown = [] - total = 0 - async for entry in db.load_token_counts(): - total += entry['total_tokens'] - breakdown.append({ - "endpoint": entry["endpoint"], - "model": entry["model"], - "input_tokens": entry["input_tokens"], - "output_tokens": entry["output_tokens"], - "total_tokens": entry["total_tokens"], - }) - return {"total_tokens": total, "breakdown": breakdown} - -@app.post("/api/aggregate_time_series_days") -async def aggregate_time_series_days_proxy(request: Request): - """ - Aggregate time_series entries older than days into daily aggregates by endpoint/model/date. 
- """ - try: - body_bytes = await request.body() - if not body_bytes: - days = 30 - trim_old = False - else: - payload = orjson.loads(body_bytes.decode("utf-8")) - days = int(payload.get("days", 30)) - trim_old = bool(payload.get("trim_old", False)) - except Exception: - days = 30 - trim_old = False - aggregated = await db.aggregate_time_series_older_than(days, trim_old=trim_old) - return {"status": "ok", "days": days, "trim_old": trim_old, "aggregated_groups": aggregated} - -# 12. API route – Stats -# ------------------------------------------------------------- -@app.post("/api/stats") -async def stats_proxy(request: Request, model: Optional[str] = None): - """ - Return token usage statistics for a specific model. - """ - try: - body_bytes = await request.body() - - if not model: - payload = orjson.loads(body_bytes.decode("utf-8")) - model = payload.get("model") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # Get token counts from database - token_data = await db.get_token_counts_for_model(model) - - if not token_data: - raise HTTPException( - status_code=404, detail="No token data found for this model" - ) - - time_series = [ - entry async for entry in db.get_time_series_for_model(model) - ] - endpoint_distribution = await db.get_endpoint_distribution_for_model(model) - - return { - 'model': model, - 'input_tokens': token_data['input_tokens'], - 'output_tokens': token_data['output_tokens'], - 'total_tokens': token_data['total_tokens'], - 'time_series': time_series, - 'endpoint_distribution': endpoint_distribution, - } - -# ------------------------------------------------------------- -# 12. 
API route – Copy -# ------------------------------------------------------------- -@app.post("/api/copy") -async def copy_proxy(request: Request, source: Optional[str] = None, destination: Optional[str] = None): - """ - Proxy a model copy request to each Ollama endpoint and reply with Status Code. - - """ - # 1. Parse and validate request - try: - body_bytes = await request.body() - - if not source and not destination: - payload = orjson.loads(body_bytes.decode("utf-8")) - src = payload.get("source") - dst = payload.get("destination") - else: - src = source - dst = destination - - if not src: - raise HTTPException( - status_code=400, detail="Missing required field 'source'" - ) - if not dst: - raise HTTPException( - status_code=400, detail="Missing required field 'destination'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 3. Iterate over all endpoints to copy the model on each endpoint - status_list = [] - - for endpoint in config.endpoints: - if "/v1" not in endpoint: - client = ollama.AsyncClient(host=endpoint) - # 4. Proxy a simple copy request - copy = await client.copy(source=src, destination=dst) - status_list.append(copy.status) - - # 4. Return with 200 OK if all went well, 404 if a single endpoint failed - return Response(status_code=404 if 404 in status_list else 200) - -# ------------------------------------------------------------- -# 13. API route – Delete -# ------------------------------------------------------------- -@app.delete("/api/delete") -async def delete_proxy(request: Request, model: Optional[str] = None): - """ - Proxy a model delete request to each Ollama endpoint and reply with Status Code. - - """ - # 1. 
Parse and validate request - try: - body_bytes = await request.body() - - if not model: - payload = orjson.loads(body_bytes.decode("utf-8")) - model = payload.get("model") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. Iterate over all endpoints to delete the model on each endpoint - status_list = [] - - for endpoint in config.endpoints: - if "/v1" not in endpoint: - client = ollama.AsyncClient(host=endpoint) - # 3. Proxy a simple copy request - copy = await client.delete(model=model) - status_list.append(copy.status) - - # 4. Return 200 0K, if a single enpoint fails, respond with 404 - return Response(status_code=404 if 404 in status_list else 200) - -# ------------------------------------------------------------- -# 14. API route – Pull -# ------------------------------------------------------------- -@app.post("/api/pull") -async def pull_proxy(request: Request, model: Optional[str] = None): - """ - Proxy a pull request to all Ollama endpoint and report status back. - """ - # 1. Parse and validate request - try: - body_bytes = await request.body() - - if not model: - payload = orjson.loads(body_bytes.decode("utf-8")) - model = payload.get("model") - insecure = payload.get("insecure") - else: - insecure = None - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. Iterate over all endpoints to pull the model - status_list = [] - - for endpoint in config.endpoints: - if "/v1" not in endpoint: - client = ollama.AsyncClient(host=endpoint) - # 3. 
Proxy a simple pull request - pull = await client.pull(model=model, insecure=insecure, stream=False) - status_list.append(pull) - - combined_status = [] - for status in status_list: - combined_status += status - - # 4. Report back a deduplicated status message - final_status = list(dict.fromkeys(combined_status)) - - return dict(final_status) - -# ------------------------------------------------------------- -# 15. API route – Push -# ------------------------------------------------------------- -@app.post("/api/push") -async def push_proxy(request: Request): - """ - Proxy a push request to Ollama and respond the deduplicated Ollama endpoint replies. - """ - # 1. Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - insecure = payload.get("insecure") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. Iterate over all endpoints - status_list = [] - - for endpoint in config.endpoints: - client = ollama.AsyncClient(host=endpoint) - # 3. Proxy a simple push request - push = await client.push(model=model, insecure=insecure, stream=False) - status_list.append(push) - - combined_status = [] - for status in status_list: - combined_status += status - - # 4. Report a deduplicated status - final_status = list(dict.fromkeys(combined_status)) - - return dict(final_status) - - -# ------------------------------------------------------------- -# 16. API route – Version -# ------------------------------------------------------------- -@app.get("/api/version") -async def version_proxy(request: Request): - """ - Proxy a version request to Ollama and reply lowest version of all endpoints. - - """ - # 1. 
Query all endpoints for version - tasks = [fetch.endpoint_details(ep, "/api/version", "version") for ep in config.endpoints if "/v1" not in ep] - all_versions_raw = await asyncio.gather(*tasks) - - # Filter out non-string values (e.g., empty lists from failed/timeout responses) - all_versions = [v for v in all_versions_raw if isinstance(v, str) and v] - - if not all_versions: - raise HTTPException(status_code=503, detail="No valid version response from any endpoint") - - def version_key(v): - return tuple(map(int, v.split('.'))) - - # 2. Return a JSONResponse with the min Version of all endpoints to maintain compatibility - return JSONResponse( - content={"version": str(min(all_versions, key=version_key))}, - status_code=200, - ) - -# ------------------------------------------------------------- -# 17. API route – tags -# ------------------------------------------------------------- -@app.get("/api/tags") -async def tags_proxy(request: Request): - """ - Proxy a tags request to Ollama endpoints and reply with a unique list of all models. - - """ - - # 1. 
Query all endpoints for models - tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] - tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys[ep], skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" in ep] - # Also query llama-server endpoints not already covered by config.endpoints - llama_eps_for_tags = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] - tasks += [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in llama_eps_for_tags] - all_models = await asyncio.gather(*tasks) - - models = {'models': []} - for modellist in all_models: - for model in modellist: - if not "model" in model.keys(): # Relable OpenAI models with Ollama Model.model from Model.id - model['model'] = model['id'] + ":latest" - else: - model['id'] = model['model'] - if not "name" in model.keys(): # Relable OpenAI models with Ollama Model.name from Model.model to have model,name keys - model['name'] = model['model'] - else: - model['id'] = model['model'] - models['models'] += modellist - - # 2. Return a JSONResponse with a deduplicated list of unique models for inference - return JSONResponse( - content={"models": dedupe_on_keys(models['models'], ['digest','name','id'])}, - status_code=200, - ) - -# ------------------------------------------------------------- -# 18. API route – ps -# ------------------------------------------------------------- -@app.get("/api/ps") -async def ps_proxy(request: Request): - """ - Proxy a ps request to all Ollama and llama-server endpoints and reply a unique list of all running models. - - For Ollama endpoints: queries /api/ps - For llama-server endpoints: queries /v1/models with status.value == "loaded" - """ - # 1. 
Query Ollama endpoints for running models via /api/ps - ollama_tasks = [fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] - # 2. Query llama-server endpoints for loaded models via /v1/models - # Also query endpoints from llama_server_endpoints that may not be in config.endpoints - all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) - llama_tasks = [ - fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) - for ep in all_llama_endpoints - ] - - ollama_loaded = await asyncio.gather(*ollama_tasks) if ollama_tasks else [] - llama_loaded = await asyncio.gather(*llama_tasks) if llama_tasks else [] - - models = {'models': []} - # Add Ollama models (if any) - if ollama_loaded: - for modellist in ollama_loaded: - models['models'] += modellist - # Add llama-server models (filter for loaded only, if any) - if llama_loaded: - for modellist in llama_loaded: - loaded_models = [item for item in modellist if _is_llama_model_loaded(item)] - # Convert llama-server format to Ollama-like format for consistency - for item in loaded_models: - raw_id = item.get("id", "") - normalized = _normalize_llama_model_name(raw_id) - quant = _extract_llama_quant(raw_id) - models['models'].append({ - "name": normalized, - "id": normalized, - "digest": "", - "status": item.get("status"), - "details": {"quantization_level": quant} if quant else {} - }) - - # 3. Return a JSONResponse with deduplicated currently deployed models - # Deduplicate on 'name' rather than 'digest': llama-server models always - # have digest="" so deduping on digest collapses all of them to one entry. - return JSONResponse( - content={"models": dedupe_on_keys(models['models'], ['name'])}, - status_code=200, - ) - -# ------------------------------------------------------------- -# 18b. 
API route – ps details (backwards compatible) -# ------------------------------------------------------------- -@app.get("/api/ps_details") -async def ps_details_proxy(request: Request): - """ - Proxy a ps request to all Ollama and llama-server endpoints and reply with per-endpoint instances. - This keeps /api/ps backward compatible while providing richer data. - - For Ollama endpoints: queries /api/ps - For llama-server endpoints: queries /v1/models with status info - """ - # 1. Query Ollama endpoints via /api/ps - ollama_tasks = [(ep, fetch.endpoint_details(ep, "/api/ps", "models", skip_error_cache=True, timeout=8)) for ep in config.endpoints if "/v1" not in ep] - # 2. Query llama-server endpoints via /v1/models - # Also query endpoints from llama_server_endpoints that may not be in config.endpoints - all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) - llama_tasks = [ - (ep, fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8)) - for ep in all_llama_endpoints - ] - - ollama_loaded = await asyncio.gather(*[task for _, task in ollama_tasks]) if ollama_tasks else [] - llama_loaded = await asyncio.gather(*[task for _, task in llama_tasks]) if llama_tasks else [] - - models: list[dict] = [] - - # Add Ollama models with endpoint info (if any) - if ollama_loaded: - for (endpoint, modellist) in zip([ep for ep, _ in ollama_tasks], ollama_loaded): - for model in modellist: - if isinstance(model, dict): - model_with_endpoint = dict(model) - model_with_endpoint["endpoint"] = endpoint - models.append(model_with_endpoint) - - # Add llama-server models with endpoint info and full status metadata (if any) - if llama_loaded: - # Collect (endpoint, raw_id) pairs to fetch /props in parallel - props_requests: list[tuple[str, str]] = [] - llama_models_pending: list[dict] = [] - - for (endpoint, modellist) in zip([ep for ep, _ in llama_tasks], 
llama_loaded): - # Include sleeping models too so _fetch_llama_props can unload them - loaded_models = [item for item in modellist if _is_llama_model_loaded_or_sleeping(item)] - for item in loaded_models: - if isinstance(item, dict) and item.get("id"): - raw_id = item["id"] - normalized = _normalize_llama_model_name(raw_id) - quant = _extract_llama_quant(raw_id) - model_with_endpoint = { - "name": normalized, - "id": normalized, - "original_name": raw_id, - "digest": "", - "details": {"quantization_level": quant} if quant else {}, - "endpoint": endpoint, - "status": item.get("status"), - "created": item.get("created"), - "owned_by": item.get("owned_by") - } - # Include full llama-server status details (args, preset) - status_info = item.get("status", {}) - if isinstance(status_info, dict): - model_with_endpoint["llama_status_args"] = status_info.get("args") - model_with_endpoint["llama_status_preset"] = status_info.get("preset") - llama_models_pending.append(model_with_endpoint) - props_requests.append((endpoint, raw_id)) - - # Fetch /props for each llama-server model to get context length (n_ctx) - # and unload sleeping models automatically - async def _fetch_llama_props(endpoint: str, model_id: str) -> tuple[int | None, bool, bool]: - client: aiohttp.ClientSession = get_session(endpoint) - base_url = endpoint.rstrip("/").removesuffix("/v1") - props_url = f"{base_url}/props?model={model_id}" - headers = None - api_key = config.api_keys.get(endpoint) - if api_key: - headers = {"Authorization": f"Bearer {api_key}"} - try: - async with client.get(props_url, headers=headers, timeout=aiohttp.ClientTimeout(total=5)) as resp: - if resp.status == 200: - data = await resp.json() - dgs = data.get("default_generation_settings", {}) - n_ctx = dgs.get("n_ctx") - is_sleeping = data.get("is_sleeping", False) - # Embedding models have no sampling params in default_generation_settings - is_generation = "temperature" in dgs - - if is_sleeping: - unload_url = 
f"{base_url}/models/unload" - try: - async with client.post( - unload_url, - json={"model": model_id}, - headers=headers, - ) as unload_resp: - print(f"[ps_details] Unloaded sleeping model {model_id} from {endpoint}: {unload_resp.status}") - except Exception as ue: - print(f"[ps_details] Failed to unload sleeping model {model_id} from {endpoint}: {ue}") - - return n_ctx, is_sleeping, is_generation - except Exception as e: - print(f"[ps_details] Failed to fetch props from {props_url}: {e}") - return None, False, False - - props_results = await asyncio.gather( - *[_fetch_llama_props(ep, mid) for ep, mid in props_requests] - ) - - for (ep, raw_id), model_dict, (n_ctx, is_sleeping, is_generation) in zip(props_requests, llama_models_pending, props_results): - if n_ctx is not None: - model_dict["context_length"] = n_ctx - if is_generation and 0 < n_ctx <= _CTX_TRIM_SMALL_LIMIT: - normalized = _normalize_llama_model_name(raw_id) - _endpoint_nctx[(ep, normalized)] = n_ctx - print(f"[ctx-cache/ps] cached n_ctx={n_ctx} for ({ep},{normalized})", flush=True) - if not is_sleeping: - models.append(model_dict) - - return JSONResponse(content={"models": models}, status_code=200) - +# (Ollama /api/* routes — moved to api/ollama.py) # ------------------------------------------------------------- # 18b. Conversation-affinity stats – feeds the PS-table dot matrix # ------------------------------------------------------------- -@app.get("/api/affinity_stats") -async def affinity_stats(request: Request): - """ - Aggregate live conversation-affinity pins, one entry per pinned conversation. - Each entry exposes only the endpoint, model, and remaining TTL in seconds — - no fingerprints or content. When conversation_affinity is disabled the - `entries` list is always empty. 
- """ - if not config.conversation_affinity: - return {"enabled": False, "ttl": config.conversation_affinity_ttl, "entries": []} - - now = time.monotonic() - entries: list[dict] = [] - llama_eps = set(config.llama_server_endpoints) - async with _affinity_lock: - for fp, (ep, mdl, expires_at) in list(_affinity_map.items()): - remaining = expires_at - now - if remaining <= 0: - _affinity_map.pop(fp, None) - continue - # Mirror the normalisation used by /api/ps_details so the dashboard - # can join affinity entries to PS rows by (endpoint, model). - display_model = _normalize_llama_model_name(mdl) if ep in llama_eps else mdl - entries.append({ - "endpoint": ep, - "model": display_model, - "remaining": round(remaining, 2), - }) - return { - "enabled": True, - "ttl": config.conversation_affinity_ttl, - "entries": entries, - } - -# ------------------------------------------------------------- -# 19. Proxy usage route – for monitoring -# ------------------------------------------------------------- -@app.get("/api/usage") -async def usage_proxy(request: Request): - """ - Return a snapshot of the usage counter for each endpoint. - Useful for debugging / monitoring. - """ - return {"usage_counts": usage_counts, - "token_usage_counts": token_usage_counts} - -from backends.probe import _raw_probe, _endpoint_health - - -# ------------------------------------------------------------- -# 20b. Proxy config route – for monitoring and frontend usage -# ------------------------------------------------------------- -@app.get("/api/config") -async def config_proxy(request: Request): - """ - Return a simple JSON object that contains the configured - Ollama endpoints and llama_server_endpoints. The front‑end uses this - to display which endpoints are being proxied and their health. - Status is "error" when either liveness (/api/version) or routing - health (/api/ps) fails — see issue #83. 
- """ - async def check(url: str) -> dict: - return {"url": url, **(await _endpoint_health(url, timeout=5))} - - ollama_results = await asyncio.gather(*[check(ep) for ep in config.endpoints]) - llama_results = [] - if config.llama_server_endpoints: - llama_results = await asyncio.gather( - *[check(ep) for ep in config.llama_server_endpoints] - ) - - return { - "endpoints": ollama_results, - "llama_server_endpoints": llama_results, - "require_router_api_key": bool(config.router_api_key), - } - -# ------------------------------------------------------------- -# 21. API route – OpenAI compatible Embedding -# ------------------------------------------------------------- -@app.post("/v1/embeddings") -async def openai_embedding_proxy(request: Request): - """ - Proxy an OpenAI API compatible embedding request to Ollama and reply with embeddings. - - """ - # 1. Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - doc = payload.get("input") - - # Normalize multimodal input: extract only text parts for embedding models - if isinstance(doc, list): - normalized = [] - for item in doc: - if isinstance(item, dict): - # Multimodal content part - extract text only, skip images - if item.get("type") == "text": - normalized.append(item.get("text", "")) - # Skip image_url and other non-text types - else: - normalized.append(item) - doc = normalized if len(normalized) != 1 else normalized[0] - elif isinstance(doc, dict) and doc.get("type") == "text": - doc = doc.get("text", "") - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not doc: - raise HTTPException( - status_code=400, detail="Missing required field 'input'" - ) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # 2. 
Endpoint logic - endpoint, tracking_model = await choose_endpoint(model) - if is_openai_compatible(endpoint): - api_key = config.api_keys.get(endpoint, "no-key") - else: - api_key = "ollama" - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=api_key) - - try: - async_gen = await oclient.embeddings.create(input=doc, model=model) - result = async_gen.model_dump() - for item in result.get("data", []): - emb = item.get("embedding") - if emb: - item["embedding"] = [0.0 if isinstance(v, float) and not math.isfinite(v) else v for v in emb] - return JSONResponse(content=result) - finally: - await decrement_usage(endpoint, tracking_model) - -# ------------------------------------------------------------- -# 22. API route – OpenAI compatible Chat Completions -# ------------------------------------------------------------- -@app.post("/v1/chat/completions") -async def openai_chat_completions_proxy(request: Request): - """ - Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response. - - """ - # 1. 
Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - messages = payload.get("messages") - frequency_penalty = payload.get("frequency_penalty") - presence_penalty = payload.get("presence_penalty") - response_format = payload.get("response_format") - seed = payload.get("seed") - stop = payload.get("stop") - stream = payload.get("stream") - stream_options = payload.get("stream_options") - temperature = payload.get("temperature") - top_p = payload.get("top_p") - max_tokens = payload.get("max_tokens") - max_completion_tokens = payload.get("max_completion_tokens") - tools = payload.get("tools") - logprobs = payload.get("logprobs") - top_logprobs = payload.get("top_logprobs") - _cache_enabled = payload.get("nomyo", {}).get("cache", False) - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not isinstance(messages, list): - raise HTTPException( - status_code=400, detail="Missing required field 'messages' (must be a list)" - ) - - if ":latest" in model: - model = model.split(":latest") - model = model[0] - - messages = _strip_assistant_prefill(messages) - params = { - "messages": messages, - "model": model, - } - - optional_params = { - "tools": tools, - "response_format": response_format, - "stream_options": stream_options or {"include_usage": True }, - "max_completion_tokens": max_completion_tokens, - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "seed": seed, - "presence_penalty": presence_penalty, - "frequency_penalty": frequency_penalty, - "stop": stop, - "stream": stream, - "logprobs": logprobs, - "top_logprobs": top_logprobs, - } - - params.update({k: v for k, v in optional_params.items() if v is not None}) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # Reject unsupported image formats (SVG) before doing any 
work - for _msg in messages: - for _item in (_msg.get("content") or []) if isinstance(_msg.get("content"), list) else []: - if _item.get("type") == "image_url": - _url = (_item.get("image_url") or {}).get("url", "") - if _url.startswith("data:image/svg") or _url.lower().endswith(".svg"): - raise HTTPException( - status_code=400, - detail="SVG images are not supported. Please convert the image to PNG or JPEG before sending.", - ) - - # Cache lookup — before endpoint selection - _cache = get_llm_cache() - if _cache is not None and _cache_enabled: - _cached = await _cache.get_chat("openai_chat", model, messages) - if _cached is not None: - if stream: - _sse = openai_nonstream_to_sse(_cached, model) - async def _serve_cached_ochat_stream(): - yield _sse - return StreamingResponse(_serve_cached_ochat_stream(), media_type="text/event-stream") - else: - async def _serve_cached_ochat_json(): - yield _cached - return StreamingResponse(_serve_cached_ochat_json(), media_type="application/json") - - # 2. Endpoint logic - _affinity_key = _conversation_fingerprint(model, messages, None) - endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - # 3. 
Helpers and API call — done in handler scope so try/except works reliably - async def _normalize_images_in_messages(msgs: list) -> list: - """Fetch remote image URLs and convert them to base64 data URLs so - Ollama/llama-server can handle them without making outbound HTTP requests.""" - resolved = [] - for msg in msgs: - content = msg.get("content") - if not isinstance(content, list): - resolved.append(msg) - continue - new_content = [] - for item in content: - if item.get("type") == "image_url": - url = (item.get("image_url") or {}).get("url", "") - if url and not url.startswith("data:"): - try: - http: aiohttp.ClientSession = app_state["session"] - async with http.get(url) as resp: - ctype = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip() - img_bytes = await resp.read() - b64 = base64.b64encode(img_bytes).decode("utf-8") - new_content.append({ - "type": "image_url", - "image_url": {"url": f"data:{ctype};base64,{b64}"} - }) - except Exception as _ie: - print(f"[image] Failed to fetch image URL: {_ie}") - new_content.append(item) - else: - new_content.append(item) - else: - new_content.append(item) - resolved.append({**msg, "content": new_content}) - return resolved - - # Make the API call in handler scope — try/except inside async generators is unreliable - # with Starlette's streaming machinery, so we resolve errors here before the generator starts. 
- send_params = params - if not is_ext_openai_endpoint(endpoint): - resolved_msgs = await _normalize_images_in_messages(params.get("messages", [])) - send_params = {**params, "messages": resolved_msgs} - # Proactive trim: only for small-ctx models we've already seen run out of space - _lookup_model = _normalize_llama_model_name(model) if endpoint in config.llama_server_endpoints else model - _known_nctx = _endpoint_nctx.get((endpoint, _lookup_model)) - if _known_nctx and _known_nctx <= _CTX_TRIM_SMALL_LIMIT: - _pre_target = int(((_known_nctx - _known_nctx // 4)) / 1.2) - _pre_est = _count_message_tokens(send_params.get("messages", [])) - if _pre_est > _pre_target: - _pre_msgs = send_params.get("messages", []) - _pre_trimmed = _trim_messages_for_context(_pre_msgs, _known_nctx, target_tokens=_pre_target) - _dropped = len(_pre_msgs) - len(_pre_trimmed) - print(f"[ctx-pre] n_ctx={_known_nctx} est={_pre_est} target={_pre_target} dropped={_dropped}", flush=True) - send_params = {**send_params, "messages": _pre_trimmed} - try: - async_gen = await oclient.chat.completions.create(**send_params) - except Exception as e: - _e_str = str(e) - _is_ctx_err = "exceed_context_size_error" in _e_str or "exceeds the available context size" in _e_str - print(f"[ochat] caught={type(e).__name__} ctx={_is_ctx_err} msg={_e_str[:120]}", flush=True) - if "does not support tools" in _e_str: - # Model doesn't support tools — retry without them - print(f"[ochat] retry: no tools", flush=True) - try: - params_without_tools = {k: v for k, v in send_params.items() if k != "tools"} - async_gen = await oclient.chat.completions.create(**params_without_tools) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - elif _is_ctx_err: - # Backend context limit hit — apply sliding-window trim (context-shift at message level) - err_body = getattr(e, "body", {}) or {} - err_detail = err_body.get("error", {}) if isinstance(err_body, dict) else {} - n_ctx_limit = err_detail.get("n_ctx", 
0) - actual_tokens = err_detail.get("n_prompt_tokens", 0) - # Fallback: parse from string if body parsing yielded nothing (SDK may not parse llama-server errors) - if not n_ctx_limit: - import re as _re - _m = _re.search(r"'n_ctx':\s*(\d+)", _e_str) - if _m: - n_ctx_limit = int(_m.group(1)) - _m = _re.search(r"'n_prompt_tokens':\s*(\d+)", _e_str) - if _m: - actual_tokens = int(_m.group(1)) - print(f"[ctx-trim] n_ctx={n_ctx_limit} actual={actual_tokens}", flush=True) - if not n_ctx_limit: - await decrement_usage(endpoint, tracking_model) - raise - if n_ctx_limit <= _CTX_TRIM_SMALL_LIMIT: - _endpoint_nctx[(endpoint, model)] = n_ctx_limit - - msgs_to_trim = send_params.get("messages", []) - try: - cal_target = _calibrated_trim_target(msgs_to_trim, n_ctx_limit, actual_tokens) - trimmed_messages = _trim_messages_for_context(msgs_to_trim, n_ctx_limit, target_tokens=cal_target) - except Exception as _helper_exc: - print(f"[ctx-trim] helper crash: {type(_helper_exc).__name__}: {str(_helper_exc)[:100]}", flush=True) - await decrement_usage(endpoint, tracking_model) - raise - dropped = len(msgs_to_trim) - len(trimmed_messages) - print(f"[ctx-trim] target={cal_target} dropped={dropped} remaining={len(trimmed_messages)} retrying-1", flush=True) - try: - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": trimmed_messages}) - print(f"[ctx-trim] retry-1 ok", flush=True) - except Exception as e2: - _e2_str = str(e2) - if "exceed_context_size_error" in _e2_str or "exceeds the available context size" in _e2_str: - # Still too large — tool definitions likely consuming too many tokens, strip them too - print(f"[ctx-trim] retry-1 still exceeded, stripping tools retrying-2", flush=True) - params_no_tools = {k: v for k, v in send_params.items() if k not in ("tools", "tool_choice")} - try: - async_gen = await oclient.chat.completions.create(**{**params_no_tools, "messages": trimmed_messages}) - print(f"[ctx-trim] retry-2 ok", flush=True) - except Exception: - 
await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise - elif _is_backend_connection_error(e): - # Upstream connection failed (e.g. llama-server in router mode - # whose delegated worker died). Mark (endpoint, model) so the - # next request reroutes; the client will retry this one. - print(f"[ochat] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) - await _mark_backend_unhealthy(endpoint, model, _e_str) - await decrement_usage(endpoint, tracking_model) - raise - elif "image input is not supported" in _e_str: - # Model doesn't support images — strip and retry - print(f"[openai_chat_completions_proxy] Model {model} doesn't support images, retrying with text-only messages") - try: - async_gen = await oclient.chat.completions.create(**{**send_params, "messages": _strip_images_from_messages(send_params.get("messages", []))}) - except Exception: - await decrement_usage(endpoint, tracking_model) - raise - else: - await decrement_usage(endpoint, tracking_model) - raise - - # 4. Async generator — only streams the already-established async_gen - async def stream_ochat_response(): - try: - if stream == True: - content_parts: list[str] = [] - usage_snapshot: dict = {} - async for chunk in async_gen: - data = ( - chunk.model_dump_json() - if hasattr(chunk, "model_dump_json") - else orjson.dumps(chunk) - ) - if chunk.choices: - delta = chunk.choices[0].delta - has_content = delta.content is not None - has_reasoning = ( - getattr(delta, "reasoning_content", None) is not None - or getattr(delta, "reasoning", None) is not None - ) - has_tool_calls = getattr(delta, "tool_calls", None) is not None - if has_content or has_reasoning or has_tool_calls: - yield f"data: {data}\n\n".encode("utf-8") - if has_content and delta.content: - content_parts.append(delta.content) - elif chunk.usage is not None: - # Forward the usage-only final chunk (e.g. 
from llama-server) - yield f"data: {data}\n\n".encode("utf-8") - prompt_tok = 0 - comp_tok = 0 - if chunk.usage is not None: - prompt_tok = chunk.usage.prompt_tokens or 0 - comp_tok = chunk.usage.completion_tokens or 0 - usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok} - else: - llama_usage = rechunk.extract_usage_from_llama_timings(chunk) - if llama_usage: - prompt_tok, comp_tok = llama_usage - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - # Detect context exhaustion mid-generation for small-ctx models. - # Guard: skip if max_tokens was set in the request — finish_reason=length - # could just mean the caller's token budget was exhausted, not the context window. - _req_max_tok = send_params.get("max_tokens") or send_params.get("max_completion_tokens") - if chunk.choices and chunk.choices[0].finish_reason == "length" and not _req_max_tok: - _inferred_nctx = (prompt_tok + comp_tok) or 0 - if 0 < _inferred_nctx <= _CTX_TRIM_SMALL_LIMIT: - _endpoint_nctx[(endpoint, model)] = _inferred_nctx - print(f"[ctx-cache] finish_reason=length → cached n_ctx={_inferred_nctx} for ({endpoint},{model})", flush=True) - # Cache assembled streaming response — before [DONE] so it always runs - if _cache is not None and _cache_enabled and content_parts: - assembled = orjson.dumps({ - "model": model, - "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(content_parts)}, "finish_reason": "stop"}], - **({"usage": usage_snapshot} if usage_snapshot else {}), - }) + b"\n" - try: - await _cache.set_chat("openai_chat", model, messages, assembled) - except Exception as _ce: - print(f"[cache] set_chat (openai_chat streaming) failed: {_ce}") - yield b"data: [DONE]\n\n" - else: - prompt_tok = 0 - comp_tok = 0 - if async_gen.usage is not None: - prompt_tok = async_gen.usage.prompt_tokens or 0 - comp_tok = async_gen.usage.completion_tokens or 0 
- else: - llama_usage = rechunk.extract_usage_from_llama_timings(async_gen) - if llama_usage: - prompt_tok, comp_tok = llama_usage - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - json_line = ( - async_gen.model_dump_json() - if hasattr(async_gen, "model_dump_json") - else orjson.dumps(async_gen) - ) - cache_bytes = json_line.encode("utf-8") + b"\n" - yield cache_bytes - # Cache non-streaming response - if _cache is not None and _cache_enabled: - try: - await _cache.set_chat("openai_chat", model, messages, cache_bytes) - except Exception as _ce: - print(f"[cache] set_chat (openai_chat non-streaming) failed: {_ce}") - - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 4. Return a StreamingResponse backed by the generator - return StreamingResponse( - stream_ochat_response(), - media_type="text/event-stream" if stream else "application/json", - ) - -# ------------------------------------------------------------- -# 23. API route – OpenAI compatible Completions -# ------------------------------------------------------------- -@app.post("/v1/completions") -async def openai_completions_proxy(request: Request): - """ - Proxy an OpenAI API compatible chat completions request to Ollama and reply with a streaming response. - - """ - # 1. 
Parse and validate request - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - prompt = payload.get("prompt") - frequency_penalty = payload.get("frequency_penalty") - presence_penalty = payload.get("presence_penalty") - seed = payload.get("seed") - stop = payload.get("stop") - stream = payload.get("stream") - stream_options = payload.get("stream_options") - temperature = payload.get("temperature") - top_p = payload.get("top_p") - max_tokens = payload.get("max_tokens") - max_completion_tokens = payload.get("max_completion_tokens") - suffix = payload.get("suffix") - _cache_enabled = payload.get("nomyo", {}).get("cache", False) - - if not model: - raise HTTPException( - status_code=400, detail="Missing required field 'model'" - ) - if not prompt: - raise HTTPException( - status_code=400, detail="Missing required field 'prompt'" - ) - - if ":latest" in model: - model = model.split(":latest") - model = model[0] - - params = { - "prompt": prompt, - "model": model, - } - - optional_params = { - "frequency_penalty": frequency_penalty, - "presence_penalty": presence_penalty, - "seed": seed, - "stop": stop, - "stream": stream, - "stream_options": stream_options or {"include_usage": True }, - "temperature": temperature, - "top_p": top_p, - "max_tokens": max_tokens, - "max_completion_tokens": max_completion_tokens, - "suffix": suffix - } - - params.update({k: v for k, v in optional_params.items() if v is not None}) - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # Cache lookup — completions prompt mapped to a single-turn messages list - _cache = get_llm_cache() - _compl_messages = [{"role": "user", "content": prompt}] - if _cache is not None and _cache_enabled: - _cached = await _cache.get_chat("openai_completions", model, _compl_messages) - if _cached is not None: - if stream: - _sse = openai_nonstream_to_sse(_cached, model) - 
async def _serve_cached_ocompl_stream(): - yield _sse - return StreamingResponse(_serve_cached_ocompl_stream(), media_type="text/event-stream") - else: - async def _serve_cached_ocompl_json(): - yield _cached - return StreamingResponse(_serve_cached_ocompl_json(), media_type="application/json") - - # 2. Endpoint logic - _affinity_key = _conversation_fingerprint(model, None, prompt) - endpoint, tracking_model = await choose_endpoint(model, affinity_key=_affinity_key) - oclient = _make_openai_client(endpoint, default_headers=default_headers, api_key=config.api_keys.get(endpoint, "no-key")) - - # 3. Async generator that streams completions data and decrements the counter - # Make the API call in handler scope (try/except inside async generators is unreliable) - try: - async_gen = await oclient.completions.create(**params) - except Exception as e: - if _is_backend_connection_error(e): - print(f"[ocompl] backend connection error → marking ({endpoint}, {model}) unhealthy", flush=True) - await _mark_backend_unhealthy(endpoint, model, str(e)) - await decrement_usage(endpoint, tracking_model) - raise - - async def stream_ocompletions_response(model=model): - try: - if stream == True: - text_parts: list[str] = [] - usage_snapshot: dict = {} - async for chunk in async_gen: - data = ( - chunk.model_dump_json() - if hasattr(chunk, "model_dump_json") - else orjson.dumps(chunk) - ) - if chunk.choices: - choice = chunk.choices[0] - has_text = getattr(choice, "text", None) is not None - has_reasoning = ( - getattr(choice, "reasoning_content", None) is not None - or getattr(choice, "reasoning", None) is not None - ) - if has_text or has_reasoning or choice.finish_reason is not None: - yield f"data: {data}\n\n".encode("utf-8") - if has_text and choice.text: - text_parts.append(choice.text) - elif chunk.usage is not None: - # Forward the usage-only final chunk (e.g. 
from llama-server) - yield f"data: {data}\n\n".encode("utf-8") - prompt_tok = 0 - comp_tok = 0 - if chunk.usage is not None: - prompt_tok = chunk.usage.prompt_tokens or 0 - comp_tok = chunk.usage.completion_tokens or 0 - usage_snapshot = {"prompt_tokens": prompt_tok, "completion_tokens": comp_tok, "total_tokens": prompt_tok + comp_tok} - else: - llama_usage = rechunk.extract_usage_from_llama_timings(chunk) - if llama_usage: - prompt_tok, comp_tok = llama_usage - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - # Cache assembled streaming response — before [DONE] so it always runs - if _cache is not None and _cache_enabled and text_parts: - assembled = orjson.dumps({ - "model": model, - "choices": [{"index": 0, "message": {"role": "assistant", "content": "".join(text_parts)}, "finish_reason": "stop"}], - **({"usage": usage_snapshot} if usage_snapshot else {}), - }) + b"\n" - try: - await _cache.set_chat("openai_completions", model, _compl_messages, assembled) - except Exception as _ce: - print(f"[cache] set_chat (openai_completions streaming) failed: {_ce}") - # Final DONE event - yield b"data: [DONE]\n\n" - else: - prompt_tok = 0 - comp_tok = 0 - if async_gen.usage is not None: - prompt_tok = async_gen.usage.prompt_tokens or 0 - comp_tok = async_gen.usage.completion_tokens or 0 - else: - llama_usage = rechunk.extract_usage_from_llama_timings(async_gen) - if llama_usage: - prompt_tok, comp_tok = llama_usage - if prompt_tok != 0 or comp_tok != 0: - await token_queue.put((endpoint, tracking_model, prompt_tok, comp_tok)) - json_line = ( - async_gen.model_dump_json() - if hasattr(async_gen, "model_dump_json") - else orjson.dumps(async_gen) - ) - cache_bytes = json_line.encode("utf-8") + b"\n" - yield cache_bytes - # Cache non-streaming response - if _cache is not None and _cache_enabled: - try: - await _cache.set_chat("openai_completions", model, _compl_messages, cache_bytes) - except Exception as _ce: - 
print(f"[cache] set_chat (openai_completions non-streaming) failed: {_ce}") - - finally: - # Ensure counter is decremented even if an exception occurs - await decrement_usage(endpoint, tracking_model) - - # 4. Return a StreamingResponse backed by the generator - return StreamingResponse( - stream_ocompletions_response(), - media_type="text/event-stream" if stream else "application/json", - ) - -# ------------------------------------------------------------- -# 24. OpenAI API compatible models endpoint -# ------------------------------------------------------------- -@app.get("/v1/models") -async def openai_models_proxy(request: Request): - """ - Proxy an OpenAI API models request to Ollama and llama-server endpoints and reply with a unique list of models. - - For Ollama endpoints: queries /api/tags (all models) - For llama-server endpoints: queries /v1/models and filters for status.value == "loaded" - """ - # 1. Query Ollama endpoints for all models via /api/tags - ollama_tasks = [fetch.endpoint_details(ep, "/api/tags", "models", skip_error_cache=True, timeout=8) for ep in config.endpoints if "/v1" not in ep] - # 2. Query external OpenAI endpoints (Groq, OpenAI, etc.) via /models - ext_openai_tasks = [fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) for ep in config.endpoints if is_ext_openai_endpoint(ep)] - # 3. 
Query llama-server endpoints for loaded models via /v1/models - # Also query endpoints from llama_server_endpoints that may not be in config.endpoints - all_llama_endpoints = set(config.llama_server_endpoints) | set(ep for ep in config.endpoints if ep in config.llama_server_endpoints) - llama_tasks = [ - fetch.endpoint_details(ep, "/models", "data", config.api_keys.get(ep), skip_error_cache=True, timeout=8) - for ep in all_llama_endpoints - ] - - ollama_models = await asyncio.gather(*ollama_tasks) if ollama_tasks else [] - ext_openai_models = await asyncio.gather(*ext_openai_tasks) if ext_openai_tasks else [] - llama_models = await asyncio.gather(*llama_tasks) if llama_tasks else [] - - models = {'data': []} - - # Add Ollama models (if any) - if ollama_models: - for modellist in ollama_models: - for model in modellist: - if not "id" in model.keys(): # Relable Ollama models with OpenAI Model.id from Model.name - model['id'] = model.get('name', model.get('id', '')) - else: - model['name'] = model['id'] - models['data'].append(model) - - # Add external OpenAI models (if any) - if ext_openai_models: - for modellist in ext_openai_models: - for model in modellist: - if not "id" in model.keys(): - model['id'] = model.get('name', model.get('id', '')) - else: - model['name'] = model['id'] - models['data'].append(model) - - # Add llama-server models (all available, not just loaded) - if llama_models: - for modellist in llama_models: - for model in modellist: - if not "id" in model.keys(): - model['id'] = model.get('name', model.get('id', '')) - else: - model['name'] = model['id'] - models['data'].append(model) - - # 2. Return a JSONResponse with a deduplicated list of unique models for inference - return JSONResponse( - content={"data": dedupe_on_keys(models['data'], ['name'])}, - status_code=200, - ) - -# ------------------------------------------------------------- -# 25. 
API route – OpenAI/Jina/Cohere compatible Rerank -# ------------------------------------------------------------- -@app.post("/v1/rerank") -@app.post("/rerank") -async def rerank_proxy(request: Request): - """ - Proxy a rerank request to a llama-server or external OpenAI-compatible endpoint. - - Compatible with the Jina/Cohere rerank API convention used by llama-server, - vLLM, and services such as Cohere and Jina AI. - - Ollama does not natively support reranking; requests routed to a plain Ollama - endpoint will receive a 501 Not Implemented response. - - Request body: - model (str, required) – reranker model name - query (str, required) – search query - documents (list[str], required) – candidate documents to rank - top_n (int, optional) – limit returned results (default: all) - return_documents (bool, optional) – include document text in results - max_tokens_per_doc (int, optional) – truncation limit per document - - Response (Jina/Cohere-compatible): - { - "id": "...", - "model": "...", - "usage": {"prompt_tokens": N, "total_tokens": N}, - "results": [{"index": 0, "relevance_score": 0.95}, ...] 
- } - """ - try: - body_bytes = await request.body() - payload = orjson.loads(body_bytes.decode("utf-8")) - - model = payload.get("model") - query = payload.get("query") - documents = payload.get("documents") - - if not model: - raise HTTPException(status_code=400, detail="Missing required field 'model'") - if not query: - raise HTTPException(status_code=400, detail="Missing required field 'query'") - if not isinstance(documents, list) or not documents: - raise HTTPException(status_code=400, detail="Missing or empty required field 'documents' (must be a non-empty list)") - except orjson.JSONDecodeError as e: - raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}") from e - - # Determine which endpoint serves this model - try: - endpoint, tracking_model = await choose_endpoint(model) - except RuntimeError as e: - raise HTTPException(status_code=404, detail=str(e)) - - # Ollama endpoints have no native rerank support - if not is_openai_compatible(endpoint): - await decrement_usage(endpoint, tracking_model) - raise HTTPException( - status_code=501, - detail=( - f"Endpoint '{endpoint}' is a plain Ollama instance which does not support " - "reranking. Use a llama-server or OpenAI-compatible endpoint with a " - "dedicated reranker model." 
- ), - ) - - if ":latest" in model: - model = model.split(":latest")[0] - - # Build upstream rerank request body – forward only recognised fields - upstream_payload: dict = {"model": model, "query": query, "documents": documents} - for optional_key in ("top_n", "return_documents", "max_tokens_per_doc"): - if optional_key in payload: - upstream_payload[optional_key] = payload[optional_key] - - # Determine upstream URL: - # llama-server exposes /v1/rerank (base already contains /v1 for llama_server_endpoints) - # External OpenAI endpoints expose /rerank under their /v1 base - if endpoint in config.llama_server_endpoints: - # llama-server: endpoint may or may not already contain /v1 - if "/v1" in endpoint: - rerank_url = f"{endpoint}/rerank" - else: - rerank_url = f"{endpoint}/v1/rerank" - else: - # External OpenAI-compatible: ep2base gives us the /v1 base - rerank_url = f"{ep2base(endpoint)}/rerank" - - api_key = config.api_keys.get(endpoint, "no-key") - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - } - - client: aiohttp.ClientSession = get_session(endpoint) - try: - async with client.post(rerank_url, json=upstream_payload, headers=headers) as resp: - response_bytes = await resp.read() - if resp.status >= 400: - raise HTTPException( - status_code=resp.status, - detail=_mask_secrets(response_bytes.decode("utf-8", errors="replace")), - ) - data = orjson.loads(response_bytes) - - # Record token usage if the upstream returned a usage object - usage = data.get("usage") or {} - prompt_tok = usage.get("prompt_tokens") or 0 - total_tok = usage.get("total_tokens") or 0 - # For reranking there are no completion tokens; we record prompt tokens only - if prompt_tok or total_tok: - await token_queue.put((endpoint, tracking_model, prompt_tok, 0)) - - return JSONResponse(content=data) - finally: - await decrement_usage(endpoint, tracking_model) - -# ------------------------------------------------------------- -# 25b. 
Cache management endpoints -# ------------------------------------------------------------- -@app.get("/api/cache/stats") -async def cache_stats(): - """Return hit/miss counters and configuration for the LLM response cache.""" - c = get_llm_cache() - if c is None: - return {"enabled": False} - return {"enabled": True, **c.stats()} - - -@app.post("/api/cache/invalidate") -async def cache_invalidate(): - """Clear all entries from the LLM response cache and reset counters.""" - c = get_llm_cache() - if c is None: - return {"enabled": False, "cleared": False} - await c.clear() - return {"enabled": True, "cleared": True} - - +# (affinity_stats, usage, config — moved to api/management.py) +# (v1/* routes — moved to api/openai.py) +# (cache routes — moved to api/management.py) # ------------------------------------------------------------- # 26. Serve the static front‑end # ------------------------------------------------------------- app.mount("/static", StaticFiles(directory="static"), name="static") -@app.get("/favicon.ico") -async def redirect_favicon(): - return RedirectResponse(url="/static/favicon.ico") - -@app.get("/", response_class=HTMLResponse) -async def index(request: Request): - """ - Render the dynamic NOMYO Router dashboard listing the configured endpoints - and the models details, availability & task status. - """ - index_path = STATIC_DIR / "index.html" - try: - return HTMLResponse(content=index_path.read_text(encoding="utf-8"), status_code=200) - except FileNotFoundError: - raise HTTPException(status_code=404, detail="Page not found") - except Exception: - raise HTTPException(status_code=500, detail="Internal server error") - -# ------------------------------------------------------------- -# 26. Healthendpoint -# ------------------------------------------------------------- -@app.get("/health") -async def health_proxy(request: Request): - """ - Health‑check endpoint for monitoring the proxy. 
- - * Queries each configured endpoint for both liveness and routing health: - Ollama endpoints are probed at `/api/version` AND `/api/ps`, - OpenAI-compatible endpoints at `/models`. - * Returns a JSON object containing: - - `status`: "ok" if every endpoint replied to every probe, otherwise "error". - - `endpoints`: a mapping of endpoint URL → `{status, version|detail}`. - * The HTTP status code is 200 when everything is healthy, 503 otherwise. - """ - # Run all health checks in parallel. - # Ollama endpoints expose /api/version (liveness) and /api/ps (routing - # health — required by `choose_endpoint`). OpenAI-compatible endpoints - # (vLLM, llama-server, external) expose /models, which serves both - # purposes. Probing /api/version alone would miss the case where the - # Ollama process is up but /api/ps is failing — see issue #83. - all_endpoints = list(config.endpoints) - llama_eps_extra = [ep for ep in config.llama_server_endpoints if ep not in config.endpoints] - all_endpoints += llama_eps_extra - - probe_results = await asyncio.gather( - *(_endpoint_health(ep) for ep in all_endpoints), - ) - - health_summary = dict(zip(all_endpoints, probe_results)) - overall_ok = all(entry.get("status") == "ok" for entry in probe_results) - - response_payload = { - "status": "ok" if overall_ok else "error", - "endpoints": health_summary, - } - - http_status = 200 if overall_ok else 503 - return JSONResponse(content=response_payload, status_code=http_status) - -# ------------------------------------------------------------- -# 27. Hostname endpoint -# ------------------------------------------------------------- -@app.get("/api/hostname") -async def get_hostname(): - """Return the hostname of the machine running the router.""" - return JSONResponse(content={"hostname": socket.gethostname()}) - -# ------------------------------------------------------------- -# 28. 
SSE route for usage broadcasts -# ------------------------------------------------------------- -@app.get("/api/usage-stream") -async def usage_stream(request: Request): - """ - Server‑Sent‑Events that emits a JSON payload every time the - global `usage_counts` dictionary changes. - """ - async def event_generator(): - # The queue that receives *every* new snapshot - queue = await subscribe() - try: - while True: - # If the client disconnects, cancel the loop - if await request.is_disconnected(): - break - data = await queue.get() - if data is None: - break - # Send the data as a single SSE message - yield f"data: {data}\n\n" - finally: - # Clean‑up: unsubscribe from the broadcast channel - await unsubscribe(queue) - - return StreamingResponse(event_generator(), media_type="text/event-stream") +from api.static import router as static_router +app.include_router(static_router) +from api.management import router as management_router +app.include_router(management_router) +from api.openai import router as openai_router +app.include_router(openai_router) +from api.ollama import router as ollama_router +app.include_router(ollama_router) +# (health, hostname, usage-stream — moved to api/management.py) # ------------------------------------------------------------- # 28. FastAPI startup/shutdown events # ------------------------------------------------------------- diff --git a/test/test_openai_proxies.py b/test/test_openai_proxies.py index 8a56c91..894f4a1 100644 --- a/test/test_openai_proxies.py +++ b/test/test_openai_proxies.py @@ -10,6 +10,7 @@ import pytest from fastapi import HTTPException import router +from api import openai as api_openai _BYPASS = HTTPException(status_code=599, detail="bypassed") @@ -47,8 +48,8 @@ class TestOpenAIChatCompletionsCacheHit: # Patch the route's references to both helpers — they're imported by name # into router's namespace at module load time. 
with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", + patch.object(api_openai, "get_llm_cache", return_value=fake), + patch.object(api_openai, "choose_endpoint", AsyncMock(side_effect=AssertionError("backend must not be reached"))), ): resp = await client.post( @@ -70,8 +71,8 @@ class TestOpenAIChatCompletionsCacheHit: async def test_stream_cache_hit_returns_sse(self, client, cache_hit_payload): fake = _FakeCache(cache_hit_payload) with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", + patch.object(api_openai, "get_llm_cache", return_value=fake), + patch.object(api_openai, "choose_endpoint", AsyncMock(side_effect=AssertionError("backend must not be reached"))), ): resp = await client.post( @@ -98,8 +99,8 @@ class TestOpenAIChatCompletionsCacheHit: """When nomyo.cache=False, get_chat is never called even if a cache exists.""" fake = _FakeCache(b"") # has a response, but should never be consulted with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", + patch.object(api_openai, "get_llm_cache", return_value=fake), + patch.object(api_openai, "choose_endpoint", AsyncMock(side_effect=_BYPASS)), ): resp = await client.post( @@ -117,8 +118,8 @@ class TestOpenAIChatCompletionsCacheHit: async def test_no_cache_configured_bypasses_cache_check(self, client): """get_llm_cache() returning None should not break the route.""" with ( - patch.object(router, "get_llm_cache", return_value=None), - patch.object(router, "choose_endpoint", + patch.object(api_openai, "get_llm_cache", return_value=None), + patch.object(api_openai, "choose_endpoint", AsyncMock(side_effect=_BYPASS)), ): resp = await client.post( @@ -140,8 +141,8 @@ class TestOpenAICompletionsCacheHit: async def test_nonstream_cache_hit(self, client, cache_hit_payload): fake = _FakeCache(cache_hit_payload) with ( - patch.object(router, "get_llm_cache", 
return_value=fake), - patch.object(router, "choose_endpoint", + patch.object(api_openai, "get_llm_cache", return_value=fake), + patch.object(api_openai, "choose_endpoint", AsyncMock(side_effect=AssertionError("backend must not be reached"))), ): resp = await client.post( @@ -163,8 +164,8 @@ class TestOpenAICompletionsCacheHit: async def test_stream_cache_hit(self, client, cache_hit_payload): fake = _FakeCache(cache_hit_payload) with ( - patch.object(router, "get_llm_cache", return_value=fake), - patch.object(router, "choose_endpoint", + patch.object(api_openai, "get_llm_cache", return_value=fake), + patch.object(api_openai, "choose_endpoint", AsyncMock(side_effect=AssertionError("backend must not be reached"))), ): resp = await client.post(